from __future__ import print_function
from math import ceil
import numpy as np
import sys
import pdb

import torch
import torch.optim as optim
import torch.nn as nn

from models_665 import Generator
from models_665 import Discriminator
import methods_665

CUDA = True
VOCAB_SIZE = 500                                                                                                ## adjust hyperparameters as necessary
MAX_SEQ_LEN = 500
START_LETTER = 0
BATCH_SIZE = 32
MLE_TRAIN_EPOCHS = 30
ADV_TRAIN_EPOCHS = 20
POS_NEG_SAMPLES = 704

GEN_EMBEDDING_DIM = 32
GEN_HIDDEN_DIM = 32
DIS_EMBEDDING_DIM = 64
DIS_HIDDEN_DIM = 32

#oracle_samples_path = './oracle_samples.trc'
oracle_samples_path = './zorkdata_long_500.pt' ## INPUT DATA                                                                    ## set dataset and model parameter paths
oracle_targets_path = './zorkdata_long_500_target.pt' ## LABEL DATA
oracle_lengths_path = './zorkdata_lengths_500.pt' ## INPUT LENGTHS
oracle_state_dict_path = './oracle_ZORK20_60_v2.trc' ## can ignore
pretrained_gen_path = './gen_ZORK500_MLEtrain_v2.trc' ## GENERATOR SAVE FILE
pretrained_dis_path = './dis_ZORK500_pretrain_v2.trc' ## DISCRIMINATOR SAVE FILE
vocab_path = 'zork_vocab.txt' ## VOCAB LIST
output_path = 'generated_inputs/gen_input' ## output generated samples

def train_generator_MLE(gen, gen_opt, oracle, real_data_samples, real_data_lengths, epochs):
    """
    Max Likelihood Pretraining for the generator
    """
    for epoch in range(epochs):
        print('epoch %d : ' % (epoch + 1), end='')
        sys.stdout.flush()
        total_loss = 0

        for i in range(0, POS_NEG_SAMPLES, BATCH_SIZE):
            inp, inp_lengths, target = methods_665.prepare_generator_batch(real_data_samples[i:i + BATCH_SIZE], real_data_lengths[i:i + BATCH_SIZE], start_letter=START_LETTER,
                                                          gpu=CUDA)

            # if (epoch == 0) and (i == 0):
            #     print(inp, inp_lengths, target)
            
            gen_opt.zero_grad()
            loss = gen.batchNLLLoss(inp, inp_lengths, target)
            loss.backward()
            gen_opt.step()

            total_loss += loss.data.item()

            if (i / BATCH_SIZE) % ceil(
                            ceil(POS_NEG_SAMPLES / float(BATCH_SIZE)) / 10.) == 0:  # roughly every 10% of an epoch
                print('.', end='')
                sys.stdout.flush()

        # each loss in a batch is loss per sample
        total_loss = total_loss / ceil(POS_NEG_SAMPLES / float(BATCH_SIZE)) / MAX_SEQ_LEN

        # sample from generator and compute oracle NLL
        oracle_loss = methods_665.batchwise_oracle_nll(gen, oracle, POS_NEG_SAMPLES, BATCH_SIZE, MAX_SEQ_LEN,
                                                   start_letter=START_LETTER, gpu=CUDA)

        print(' average_train_NLL = %.4f, oracle_sample_NLL = %.4f' % (total_loss, oracle_loss))


def train_generator_PG(gen, gen_opt, oracle, dis, num_batches):
    """
    The generator is trained using policy gradients, using the reward from the discriminator.
    Training is done for num_batches batches.
    """
    seq_len = 20

    for batch in range(num_batches):
        s = gen.sample(BATCH_SIZE*2)        # 64 works best
        inp, inp_lengths, target = methods_665.prepare_generator_batch(s, seq_len, start_letter=START_LETTER, gpu=CUDA)
        rewards = dis.batchClassify(target)

        gen_opt.zero_grad()
        pg_loss = gen.batchPGLoss(inp, target, rewards)
        pg_loss.backward()
        gen_opt.step()

    # sample from generator and compute oracle NLL
    #oracle_loss = batchwise_oracle_nll(gen, oracle, POS_NEG_SAMPLES, BATCH_SIZE, MAX_SEQ_LEN,
                               #                    start_letter=START_LETTER, gpu=CUDA)

    #print(' oracle_sample_NLL = %.4f' % oracle_loss)


def train_discriminator(discriminator, dis_opt, real_data_samples, generator, oracle, d_steps, epochs, real_data_targets, gan_training=False):
    """
    Training the discriminator on real_data_samples (positive) and generated samples from generator (negative).
    Samples are drawn d_steps times, and the discriminator is trained for epochs epochs.
    """

    # generating a small validation set before training (using oracle and generator)
    pos_val = real_data_samples[-64:]
    seq_len = len(pos_val[0])
    neg_val = generator.sample(64)
    pos_targ = real_data_targets[-64:]
   # print(pos_val, pos_targ)
    val_inp, val_target = methods_665.prepare_discriminator_data(pos_val, neg_val, gpu=CUDA, real_target=pos_targ)
    
    
    for d_step in range(d_steps):
        s = methods_665.batchwise_sample(generator, POS_NEG_SAMPLES, BATCH_SIZE)
        dis_inp, dis_target = methods_665.prepare_discriminator_data(real_data_samples, torch.empty(0, seq_len).cuda(), gpu=CUDA, real_target=real_data_targets)
        dis_ninp, dis_ntarget = methods_665.prepare_discriminator_data(torch.empty(0, seq_len).cuda(), s, gpu=CUDA)
        for epoch in range(epochs):
            print('d-step %d epoch %d : ' % (d_step + 1, epoch + 1), end='')
            sys.stdout.flush()
            total_loss = 0
            total_acc = 0

            for i in range(0, POS_NEG_SAMPLES, BATCH_SIZE):
                inp, target = dis_inp[i:i + BATCH_SIZE], dis_target[i:i + BATCH_SIZE]
                ninp, ntarget = dis_ninp[i:i + BATCH_SIZE], dis_ntarget[i:i + BATCH_SIZE]
                dis_opt.zero_grad()
                out = discriminator.batchClassify(inp)
                nout = discriminator.batchClassify(ninp)
                loss_fn = nn.MSELoss()                                      ## NOTE: verify that MSE is the right loss for these multi-valued targets
              #  if(gan_training):                    # band-aid due to lack of effective oracle
              #      for x in target:
              #          if x[0] == 0:
              #              target[1:] = out[1:]
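                # Only a few outputs contribute to the loss each step: the first
                # three positive-sample predictions and the first negative-sample
                # prediction are stacked and compared against their targets.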
                t_out = torch.stack((out[0], out[1], out[2], nout[0]))
                t_target = torch.stack((target[0], target[1], target[2], ntarget[0]))
                loss = loss_fn(t_out, t_target)
             #   loss += loss_fn(nout[0], ntarget[0])
                loss.backward()
                dis_opt.step()

                total_loss += loss.data.item()
   
                
                total_acc += torch.sum((out>0.5)==(target>0.5)).data.item() # ZACH TODO: only use first column here? 

                if (i / BATCH_SIZE) % ceil(ceil(2 * POS_NEG_SAMPLES / float(
                        BATCH_SIZE)) / 10.) == 0:  # roughly every 10% of an epoch
                    print('.', end='')
                    sys.stdout.flush()

            total_loss /= ceil(2 * POS_NEG_SAMPLES / float(BATCH_SIZE))
            total_acc /= float(2 * POS_NEG_SAMPLES)

            val_pred = discriminator.batchClassify(val_inp)
            pred_pos = val_pred[0:64]
            pred_neg = val_pred[-64:]
            targ_pos = val_target[0:64]
            targ_neg = val_target[-64:]
            val_pos_class, val_pos_valid, val_pos_cov, val_neg_class, val_neg_valid, val_neg_cov = 0, 0, 0, 0, 0, 0
            for x in range(len(pred_pos)):
                val_pos_class += abs(pred_pos[x][0] - targ_pos[x][0])
                val_pos_valid += abs(pred_pos[x][1] - targ_pos[x][1])
                val_pos_cov += abs(pred_pos[x][2] - targ_pos[x][2])
                val_neg_class += abs(pred_neg[x][0] - targ_neg[x][0])
                val_neg_valid += abs(pred_neg[x][1] - targ_neg[x][1])
                val_neg_cov += abs(pred_neg[x][2] - targ_neg[x][2])                
            
            print(' average_loss = %.4f, train_acc = %.4f' % (
                total_loss, total_acc))
            print(' val_pos_class = %.4f, val_pos_valid = %.4f, val_pos_cov = %.4f' % (val_pos_class / 64., val_pos_valid / 64., val_pos_cov / 64.))
            print(' val_neg_class = %.4f, val_neg_valid = %.4f, val_neg_cov = %.4f' % (val_neg_class / 64., val_neg_valid / 64., val_neg_cov / 64.))

# MAIN
if __name__ == '__main__':

    
    oracle = Generator(GEN_EMBEDDING_DIM, GEN_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA, oracle_init=True)
    #oracle.load_state_dict(torch.load(oracle_state_dict_path))
    oracle_samples = torch.load(oracle_samples_path)
    oracle_targets = torch.load(oracle_targets_path)
    oracle_lengths = torch.load(oracle_lengths_path)
    # a new oracle can be generated by passing oracle_init=True in the generator constructor
    # samples for the new oracle can be generated using methods_665.batchwise_sample()

    gen = Generator(GEN_EMBEDDING_DIM, GEN_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)
    dis = Discriminator(DIS_EMBEDDING_DIM, DIS_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)

    if CUDA:
        oracle = oracle.cuda()
        gen = gen.cuda()
        dis = dis.cuda()
        oracle_samples = oracle_samples.cuda()
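        # NOTE: oracle_targets and oracle_lengths stay on the CPU here; the
        # methods_665 helpers called with gpu=CUDA are assumed to move the
        # per-batch tensors they build onto the GPU.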


    # GENERATOR MLE TRAINING
    print('Starting Generator MLE Training...')
    gen_optimizer = optim.Adam(gen.parameters(), lr=1e-2)
    
    train_generator_MLE(gen, gen_optimizer, oracle, oracle_samples, oracle_lengths, MLE_TRAIN_EPOCHS)                   ## comment out if loading gen
    torch.save(gen.state_dict(), pretrained_gen_path)                                                                   ## comment out if loading gen
    # gen.load_state_dict(torch.load(pretrained_gen_path))                                                              ## uncomment if loading pre-trained gen


    # PRETRAIN DISCRIMINATOR
    print('\nStarting Discriminator Training...')
    dis_optimizer = optim.Adagrad(dis.parameters())

    train_discriminator(dis, dis_optimizer, oracle_samples, gen, oracle, 30, 3, oracle_targets, gan_training=True)      ## comment out if loading dis
    torch.save(dis.state_dict(), pretrained_dis_path)                                                                   ## comment out if loading dis
    # dis.load_state_dict(torch.load(pretrained_dis_path))                                                              ## uncomment if loading pre-trained dis                    



    # ADVERSARIAL TRAINING                                                                                  
    print('\nStarting Adversarial Training...')                                                                         ## adv training, comment out if you're loading an 
    oracle_loss = methods_665.batchwise_oracle_nll(gen, oracle, POS_NEG_SAMPLES, BATCH_SIZE, MAX_SEQ_LEN,                           ## already adv-trained model
                                               start_letter=START_LETTER, gpu=CUDA)                                     ## oracle loss is a metric I was not using and can be ignored
    print('\nInitial Oracle Sample Loss : %.4f' % oracle_loss)
    for epoch in range(ADV_TRAIN_EPOCHS):
        print('\n--------\nEPOCH %d\n--------' % (epoch+1))
        # TRAIN GENERATOR
        print('\nAdversarial Training Generator : ', end='')
        sys.stdout.flush()
        train_generator_PG(gen, gen_optimizer, oracle, dis, 1)

        # TRAIN DISCRIMINATOR
        print('\nAdversarial Training Discriminator : ')
        train_discriminator(dis, dis_optimizer, oracle_samples, gen, oracle, 2, 2, oracle_targets)
        
    torch.save(dis.state_dict(), pretrained_dis_path)                                                           ## saving post-adv training if desired
    torch.save(gen.state_dict(), pretrained_gen_path)                                                           
    # gen.load_state_dict(torch.load(pretrained_gen_path))                                                      ## load adv trained models if desired
    # dis.load_state_dict(torch.load(pretrained_dis_path))                                                                       


    #generate and save examples as desired 
    
    #    LOAD VOCAB
    
    vocab_list = []
    with open(vocab_path, 'r') as vocab_file:
        # strip the newline character from each vocab entry
        for line in vocab_file:
            vocab_list.append(line.strip())
    #print(vocab_list)

    #  TEST SAMPLE QUALITY                                       ## predicted labels should have decent coverage/validity
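    # dis.batchClassify returns multi-valued predictions per sample; columns
    # 0-2 correspond to the class / validity / coverage targets used in the
    # validation metrics above.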
    test_inp = gen.sample(20)
    #h = dis.init_hidden(test_inp.size()[0])
    test_out = dis.batchClassify(test_inp)
    print(test_out)
    
    
    #   SAVE SAMPLES TO FILES
    test_samples = gen.sample(128)                             ## choose number of samples to save
    test_out = dis.batchClassify(test_samples)
    test_samples = test_samples.cpu()
    samples_list = test_samples.numpy()
            
    ## OUTPUT GENERATED SAMPLES
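    # For each generated sample, write one token per line, stopping at the first
    # token id 1 or at any id that falls outside the loaded vocabulary.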
            
    for i in range(len(samples_list)):
        #print(test_out[i][2])
        #if test_out[i][2] > 0.1:                              ## can assign test to only save good (predicted) files
        with open(output_path + str(i) + '.txt', 'w') as out_file:                            ## name output sequence files
            for j in samples_list[i]:
                if j == 1:                               # stop at token id 1
                    break
                if j >= len(vocab_list):                 # guard against ids outside the vocab
                    break
                out_file.write('%s\n' % vocab_list[j])
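
    ## Example (optional, commented out): reload the saved generator later and
    ## draw fresh samples without retraining. Assumes the same hyperparameters
    ## as above so the checkpoint shapes match.
    # gen = Generator(GEN_EMBEDDING_DIM, GEN_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)
    # if CUDA:
    #     gen = gen.cuda()
    # gen.load_state_dict(torch.load(pretrained_gen_path))
    # new_samples = gen.sample(128)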