# main_665.py
import sys

import torch.optim as optim

from dataset import *
from methods import *
from models import *
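
# Every DataLoader below relies on `pad_collate` from the local `dataset`
# module. A minimal sketch of the padding technique it presumably implements,
# assuming each dataset item is a (token_sequence, target) pair of variable
# length; the name `pad_collate_sketch` and the exact return layout are
# assumptions, not the project's actual implementation.
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate_sketch(batch):
    seqs, targets = zip(*batch)
    lengths = torch.tensor([len(s) for s in seqs])
    # Right-pad every sequence in the batch to the length of the longest one.
    padded = pad_sequence(list(seqs), batch_first=True, padding_value=0)
    return padded, torch.stack(targets), lengths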


def train_generator_supervised(gen, gen_opt, train_set, epochs):
    """
    Supervised pre-training for the generator
    For each input token, the generator needs to predict the next token.
zaglanz's avatar
zaglanz committed
    """
    train_dl = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_collate,
                          num_workers=DATALOADER_NUM_WORKERS)

    nll = nn.NLLLoss()  # criterion created once, outside the training loop

    for epoch in range(epochs):
        sys.stdout.flush()
        total_loss = 0
        supervised_batch_count = len(train_dl)

        for i, (sequence, rfcv, length) in enumerate(train_dl):
            inp, target = one_token_shift_sequences(sequence,
                                                    start_letter=START_LETTER,
                                                    gpu=CUDA)

            gen_opt.zero_grad()
            hidden = None
            loss = 0
            for j in range(inp.shape[1]):
                out, hidden = gen(inp[:, j], hidden)
                loss += nll(out, target[:, j])

            # NOTE: the per-step loss is accumulated explicitly above (rather than
            # hidden inside a helper) so the teacher-forcing structure stays visible here.
            loss.backward()
            gen_opt.step()

            total_loss += loss.item() / inp.shape[1]

        average_loss = total_loss / supervised_batch_count
        print(f"Generator supervised epoch {epoch}, average_train_NLL = {average_loss:.4f}")


def train_generator_PG(gen, gen_opt, dis, num_batches):
    """
    The generator is trained using predictions from the discriminator. No supervised data is used.
zaglanz's avatar
zaglanz committed
    """
    adv_loss = PolicyGradientLoss()  # REINFORCE objective: log P(y_t | y_1:t-1) * Q
    pg_loss = 0
    for batch in range(num_batches):
        gen_opt.zero_grad()
        batch_size = BATCH_SIZE * 2  # doubled batch size, carried over from the old implementation
        sequences, logits = gen.generate_sequences(batch_size)
        sequences, length = auto_batch_target_length(sequences)
        dis_preds = dis(sequences, length)
        loss = adv_loss(logits, dis_preds)
        # The discriminator itself is not updated here: gen_opt holds only
        # generator parameters, and dis_opt.zero_grad() runs before every
        # discriminator update, discarding any gradients accumulated in this pass.
        pg_loss += loss.item()
        loss.backward()
        gen_opt.step()

    pg_loss /= num_batches
    print(f"Generator policy gradient loss: {pg_loss:.4f}")


def train_discriminator(discriminator, dis_opt, real_train_set, real_valid_set, generator, d_steps, epochs):
    """
    Training the discriminator on real_data_samples (positive) and generated samples from generator (negative).
    Samples are drawn d_steps times, and the discriminator is trained for epochs epochs.
    """

    # create validation set
    fake_validation_sequences, _ = generator.generate_sequences(len(real_valid_set))
    fake_valid_set = GeneratedDataset(fake_validation_sequences, random_truncate=True)
    valid_set = ConcatDataset(real_valid_set, fake_valid_set)
    valid_dl = DataLoader(dataset=valid_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_collate,
                          num_workers=DATALOADER_NUM_WORKERS)

    val_interval = 5
    mse = nn.MSELoss()  # TODO fix -- criterion created once, outside the loops
    mse_per_element = nn.MSELoss(reduction="none")  # per-element loss for validation stats

    for d_step in range(d_steps):
        fake_training_sequences, _ = generator.generate_sequences(SUPERVISED_SIZE)
        fake_train_set = GeneratedDataset(fake_training_sequences, random_truncate=True)
        train_set = ConcatDataset(fake_train_set, real_train_set)
        train_dl = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_collate,
                              num_workers=DATALOADER_NUM_WORKERS)

        for epoch in range(epochs):
            sys.stdout.flush()
            total_loss = 0
            total_acc = 0

            for i, (sequence, target, length) in enumerate(train_dl):
                if CUDA:
                    sequence = sequence.cuda()
                    target = target.cuda()

                dis_opt.zero_grad()
                out = discriminator(sequence, length)
                loss = mse(out, target)
                loss.backward()
                dis_opt.step()
                total_loss += loss.item()
                total_acc += torch.sum((out[:, 0] > 0.5) == (target[:, 0] > 0.5)).item()  # accuracy on the real/fake column

            avg_loss = total_loss / len(train_dl)
            avg_acc = total_acc / len(train_set)
            print(
                f"Discriminator dstep {d_step}, epoch {epoch}, avg training loss = {avg_loss:.4f}, accuracy = {avg_acc:.4f}")

            if epoch % val_interval == 0:
                with torch.no_grad():
                    val_loss = 0
                    val_acc = 0
                    pos_count = 0
                    pos_stats = torch.zeros(3)
                    neg_stats = torch.zeros(3)
                    if CUDA:
                        pos_stats, neg_stats = pos_stats.cuda(), neg_stats.cuda()
                    for sequence, target, length in valid_dl:
                        if CUDA:
                            sequence, target = sequence.cuda(), target.cuda()
                        out = discriminator(sequence, length)
                        loss = mse_per_element(out, target)
                        # Accumulate per-column loss separately for real (pos)
                        # and fake (neg) samples; column 0 is the real/fake class.
                        pos_stats += (loss * target[:, 0].unsqueeze(1)).sum(0)
                        neg_stats += (loss * (1 - target[:, 0].unsqueeze(1))).sum(0)
                        pos_count += target[:, 0].sum(0)
                        val_loss += loss.mean().item()
                        val_acc += torch.sum((out[:, 0] > 0.5) == (target[:, 0] > 0.5)).item()

                    val_loss /= len(valid_dl)
                    val_acc /= len(valid_set)
                    pos_stats /= pos_count
                    neg_stats /= len(valid_set) - pos_count

                    print('Discriminator validation average_loss = %.4f, val_acc = %.4f' % (
                        val_loss, val_acc))
                    print('                         val_pos_class = %.4f, val_pos_valid = %.4f, val_pos_cov = %.4f' % (
                        pos_stats[0], pos_stats[1], pos_stats[2]))
                    print('                         val_neg_class = %.4f, val_neg_valid = %.4f, val_neg_cov = %.4f' % (
                        neg_stats[0], neg_stats[1], neg_stats[2]))
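
# `GeneratedDataset` comes from the local `dataset` module. A minimal sketch
# of what it presumably does, assuming fake samples carry an all-zero target
# vector (matching the class/validity/coverage columns printed above) and that
# `random_truncate` keeps a random-length prefix of each sequence so fake
# lengths roughly match the real data; the target layout is an assumption.
import random
from torch.utils.data import Dataset

class GeneratedDatasetSketch(Dataset):
    def __init__(self, sequences, random_truncate=False):
        self.sequences = list(sequences)
        self.random_truncate = random_truncate

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        if self.random_truncate:
            seq = seq[: random.randint(1, len(seq))]  # random-length prefix
        return seq, torch.zeros(3)  # all-zero target marks the sample as fake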



def main1(load=False):
    supervised_samples = torch.load(supervised_sequences_path)
    supervised_targets = torch.load(supervised_val_cov_path)

    zork_data = SupervisedData(supervised_samples, supervised_targets)
    train_set = zork_data.training_set
    valid_set = zork_data.validation_set

    gen = Generator(GEN_EMBEDDING_DIM, GEN_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)
    dis = Discriminator(DIS_EMBEDDING_DIM, DIS_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)

    if load:
        gen.load_state_dict(torch.load(pretrained_gen_path))
        dis.load_state_dict(torch.load(pretrained_dis_path))

    if CUDA:
        gen = gen.cuda()
        dis = dis.cuda()

    print('Starting Generator Supervised Training...')
    gen_optimizer = optim.Adam(gen.parameters(), lr=1e-2)
    train_generator_supervised(gen, gen_optimizer, train_set,
                               SUPERVISED_TRAIN_EPOCHS)
    # torch.save(gen.state_dict(), pretrained_gen_path)

    print('Starting Discriminator Training...')
    dis_optimizer = optim.Adagrad(dis.parameters())
    train_discriminator(dis, dis_optimizer, train_set, valid_set, gen, 10, 3)
    # torch.save(dis.state_dict(), pretrained_dis_path)

    print('\nStarting Adversarial Training...')
    for epoch in range(ADV_TRAIN_EPOCHS):
        print('\n--------\nEPOCH %d\n--------' % (epoch + 1))
        # TRAIN GENERATOR
        print('\nAdversarial Training Generator : ')
        sys.stdout.flush()
        train_generator_PG(gen, gen_optimizer, dis, 1)

        # TRAIN DISCRIMINATOR
        print('\nAdversarial Training Discriminator : ')
        train_discriminator(dis, dis_optimizer, train_set, valid_set, gen, 2, 2)

    # torch.save(dis.state_dict(), pretrained_dis_path)
    # torch.save(gen.state_dict(), pretrained_gen_path)
    # # gen.load_state_dict(torch.load(pretrained_gen_path))
    # # dis.load_state_dict(torch.load(pretrained_dis_path))

    # generate and save examples as desired

    # LOAD VOCAB
    vocab_list = []
    with open(vocab_path, 'r') as f:
        for line in f:
            vocab_list.append(line.strip())  # strip the trailing newline
    # print(vocab_list)

    # #  TEST SAMPLE QUALITY
    # test_inp = gen.generate_sequences(batch_size=20)
    # # h = dis.init_hidden(test_inp.size()[0])
    # test_out = dis(test_inp)
    # print(test_out)
    #
    # #   SAVE SAMPLES TO FILES
    # test_samples = gen.generate_sequences(batch_size=128)
    # test_out = dis(test_samples)
    # test_samples = test_samples.cpu()
    # samples_list = test_samples.numpy()
    #
    # ## OUTPUT GENERATED SAMPLES
    #
    # for i in range(len(samples_list)):
    #     # print(test_out[i][2])
    #     # if test_out[i][2] > 0.1:
    #     with open(output_path + str(i) + '.txt', 'w') as vocab_file:
    #         for j in samples_list[i]:
    #             if (j == 1):
    #                 break
    #             if (j >= len(vocab_list)):
    #                 break
    #             vocab_file.write('%s\n' % vocab_list[j])
    #

# MAIN
if __name__ == '__main__':
    main1()