Commit 593ecbca authored by fuchai

Code connected.

A few issues are left unaddressed:
The generator should not use the discriminator's output directly as its loss.
Coverage collection should be done in parallel.
The supervised data does not have coverage vectors yet; they still need to be collected.
parent d37d781c
Merge requests: !2 (Merge for work. Discard merge branch.), !1 (Code not finished, but runnable. Use it as a base for work.)
This commit is part of merge request !2.
.*
GanCoverage/
GanSamples/
Temporary/
zork/
\ No newline at end of file
File moved
from __future__ import print_function
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import ConcatDataset, DataLoader
from dataset import *
from methods import *
from models import *
from parameter import *
def train_generator_supervised(gen, gen_opt, train_set, epochs):
"""
Supervised pre-training for the generator
For each input token, the generator needs to predict the next token.
"""
train_dl = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_collate,
num_workers=DATALOADER_NUM_WORKERS)
for epoch in range(epochs):
sys.stdout.flush()
total_loss = 0
supervised_batch_count = len(train_dl)
for i, (sequence, rfcv, length) in enumerate(train_dl):
inp, target = one_token_shift_sequences(sequence,
start_letter=START_LETTER,
gpu=CUDA)
gen_opt.zero_grad()
nll = nn.NLLLoss()
hidden = None
loss = 0
for j in range(inp.shape[1]):
out, hidden = gen(inp[:, j], hidden)
loss += nll(out, target[:, j])
            # Keep the loss computation explicit here rather than hiding it in a helper;
            # the structure of the training step must stay visible at this point.
            loss.backward()
            gen_opt.step()
            total_loss += loss.item() / inp.shape[1]
average_loss = total_loss / supervised_batch_count
print(f"Generator supervised epoch {epoch}, average_train_NLL = {average_loss:.4f}")
def train_generator_PG(gen, gen_opt, dis, num_batches):
"""
The generator is trained using predictions from the discriminator. No supervised data is used.
"""
# seq_len = 20
total_loss = 0
for batch in range(num_batches):
gen_opt.zero_grad()
batch_size = BATCH_SIZE * 2 # old implementation
sequences, logits = gen.generate_sequences(batch_size)
sequences, length = auto_batch_target_length(sequences)
dis_preds = dis(sequences, length)
pg_loss = PolicyGradientLoss() # log(P(y_t|Y_1:Y_{t-1})) * Q
loss = pg_loss(logits, dis_preds)
# Ensure that the discriminator is not trained: gen_opt only has generator loss, and dis_opt.zero_grad() must
# be called before the discriminator update.
total_loss += loss.item()
loss.backward()
gen_opt.step()
total_loss /= num_batches
print(f"Generator policy gradient loss: {total_loss:.4f}")
def train_discriminator(discriminator, dis_opt, real_train_set, real_valid_set, generator, d_steps, epochs):
"""
Training the discriminator on real_data_samples (positive) and generated samples from generator (negative).
Samples are drawn d_steps times, and the discriminator is trained for epochs epochs.
"""
# create validation set
fake_validation_sequences, _ = generator.generate_sequences(len(real_valid_set))
fake_valid_set = GeneratedDataset(fake_validation_sequences, random_truncate=True, coverage_vector=True)
print(next(iter(fake_valid_set)))
    # TODO: mix real_valid_set back in once the supervised data has coverage vectors (see commit message)
    valid_set = ConcatDataset(fake_valid_set, fake_valid_set)
valid_dl = DataLoader(dataset=valid_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=cov_vector_pad_collate,
num_workers=DATALOADER_NUM_WORKERS)
val_interval = 5
for d_step in range(d_steps):
fake_training_sequences, _ = generator.generate_sequences(SUPERVISED_SIZE)
fake_train_set = GeneratedDataset(fake_training_sequences, random_truncate=True, coverage_vector=True)
        # TODO: mix real_train_set back in once the supervised data has coverage vectors
        train_set = ConcatDataset(fake_train_set, fake_train_set)
train_dl = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=cov_vector_pad_collate,
num_workers=DATALOADER_NUM_WORKERS)
for epoch in range(epochs):
sys.stdout.flush()
total_loss = 0
total_acc = 0
for i, (sequence, target, length) in enumerate(train_dl):
if CUDA:
sequence = sequence.cuda()
target = target.cuda()
dis_opt.zero_grad()
out = discriminator(sequence, length)
discrim_loss = DiscrimLoss()
loss = discrim_loss(out, target)
# mse = nn.MSELoss() # TODO fix
# loss = mse(out, target)
loss.backward()
dis_opt.step()
total_loss += loss.item()
                total_acc += torch.sum((out[:, 0] > 0.5) == (target[:, 0] > 0.5)).item()
avg_loss = total_loss / len(train_dl)
avg_acc = total_acc / len(train_set)
print(
f"Discriminator dstep {d_step}, epoch {epoch}, avg training loss = {avg_loss:.4f}, accuracy = {avg_acc:.4f}")
if epoch % val_interval == 0:
with torch.no_grad():
val_loss = 0
val_acc = 0
pos_count = 0
pos_stats = torch.zeros((3,))
neg_stats = torch.zeros((3,))
for i, (sequence, target, length) in enumerate(valid_dl):
if CUDA:
sequence, target, pos_stats, neg_stats = sequence.cuda(), target.cuda(), \
pos_stats.cuda(), neg_stats.cuda()
out = discriminator(sequence, length)
discrim_loss = DiscrimLoss(reduction="none")
loss = discrim_loss(out, target)
pos_stats += (loss * target[:, 0].unsqueeze(1)).sum(0)
neg_stats += (loss * (1 - target[:, 0].unsqueeze(1))).sum(0)
pos_count += target[:, 0].sum(0)
val_loss += loss.mean().item()
                        val_acc += torch.sum((out[:, 0] > 0.5) == (target[:, 0] > 0.5)).item()
val_loss /= len(valid_dl)
val_acc /= len(valid_set)
pos_stats /= pos_count
neg_stats /= len(valid_set) - pos_count
                    print('Discriminator validation average_loss = %.4f, val_acc = %.4f' % (
                        val_loss, val_acc))
print(' val_pos_class = %.4f, val_pos_valid = %.4f, val_pos_cov = %.4f' % (
pos_stats[0], pos_stats[1], pos_stats[2]))
print(' val_neg_class = %.4f, val_neg_valid = %.4f, val_neg_cov = %.4f' % (
neg_stats[0], neg_stats[1], neg_stats[2]))
def main1(load=False):
# TODO the supervised targets do not contain coverage vectors
supervised_samples = torch.load(supervised_sequences_path)
supervised_targets = torch.load(supervised_val_cov_path)
zork_data = SupervisedData(supervised_samples, supervised_targets)
train_set = zork_data.training_set
valid_set = zork_data.validation_set
gen = Generator(GEN_EMBEDDING_DIM, GEN_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)
dis = Discriminator(DIS_EMBEDDING_DIM, DIS_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)
if load:
gen.load_state_dict(torch.load(pretrained_gen_path))
dis.load_state_dict(torch.load(pretrained_dis_path))
if CUDA:
gen = gen.cuda()
dis = dis.cuda()
print('Starting Generator Supervised Training...')
gen_optimizer = optim.Adam(gen.parameters(), lr=1e-2)
train_generator_supervised(gen, gen_optimizer, train_set,
SUPERVISED_TRAIN_EPOCHS)
# torch.save(gen.state_dict(), pretrained_gen_path)
print('Starting Discriminator Training...')
dis_optimizer = optim.Adagrad(dis.parameters())
train_discriminator(dis, dis_optimizer, train_set, valid_set, gen, 10, 3)
# torch.save(dis.state_dict(), pretrained_dis_path)
print('\nStarting Adversarial Training...')
for epoch in range(ADV_TRAIN_EPOCHS):
print('\n--------\nEPOCH %d\n--------' % (epoch + 1))
# TRAIN GENERATOR
print('\nAdversarial Training Generator : ')
sys.stdout.flush()
train_generator_PG(gen, gen_optimizer, dis, 1)
# TRAIN DISCRIMINATOR
print('\nAdversarial Training Discriminator : ')
train_discriminator(dis, dis_optimizer, train_set, valid_set, gen, 2, 2)
# torch.save(dis.state_dict(), pretrained_dis_path)
# torch.save(gen.state_dict(), pretrained_gen_path)
# # gen.load_state_dict(torch.load(pretrained_gen_path))
# # dis.load_state_dict(torch.load(pretrained_dis_path))
# generate and save examples as desired
    # LOAD VOCAB
    vocab_list = []
    with open(vocab_path, 'r') as f:
        for line in f:
            # strip the trailing newline character
            vocab_list.append(line.strip())
    # print(vocab_list)
# # TEST SAMPLE QUALITY
# test_inp = gen.generate_sequences(batch_size=20)
# # h = dis.init_hidden(test_inp.size()[0])
# test_out = dis(test_inp)
# print(test_out)
#
# # SAVE SAMPLES TO FILES
# test_samples = gen.generate_sequences(batch_size=128)
# test_out = dis(test_samples)
# test_samples = test_samples.cpu()
# samples_list = test_samples.numpy()
#
# ## OUTPUT GENERATED SAMPLES
#
# for i in range(len(samples_list)):
# # print(test_out[i][2])
# # if test_out[i][2] > 0.1:
# with open(output_path + str(i) + '.txt', 'w') as vocab_file:
# for j in samples_list[i]:
# if (j == 1):
# break
# if (j > len(vocab_list)):
# break
# vocab_file.write('%s\n' % vocab_list[j])
#
# MAIN
if __name__ == '__main__':
main1()
File moved
File moved
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# TODO check parameter initialization
class Discriminator(nn.Module):
def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, gpu=False, dropout=0.2):
super(Discriminator, self).__init__()
self.hidden_dim = hidden_dim
self.embedding_dim = embedding_dim
self.max_seq_len = max_seq_len
self.gpu = gpu
self.embeddings = nn.Embedding(vocab_size, embedding_dim)
self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=2, bidirectional=True, dropout=dropout)
self.gru2hidden = nn.Linear(2 * 2 * hidden_dim, hidden_dim)
self.gru_out_linear = nn.Linear(2 * hidden_dim, hidden_dim)
self.dropout_linear = nn.Dropout(p=dropout)
        self.hidden2out = nn.Linear(hidden_dim, 3)  # three outputs: real/fake, validity, coverage (matches the forward docstring)
def forward(self, input, length, hidden=None):
"""
:param input: whole sequence, batch_size x seq_len
:param hidden: default to zeros
:return: (batch_size, 3) probability, validity, coverage, [0,1] interval
"""
emb = self.embeddings(input) # batch_size x seq_len x embedding_dim
        # emb = emb.permute(1, 0, 2)  # seq_len x batch_size x embedding_dim
        emb_packed = pack_padded_sequence(emb, length, batch_first=True, enforce_sorted=False)
        # self.gru initializes the hidden state to zeros when hidden is None
out, _ = self.gru(emb_packed, hidden)
out, lengths = pad_packed_sequence(out, batch_first=True)
outt = out[torch.arange(0, lengths.shape[0]), lengths - 1, :] # (batch_size, 2*hidden_dim)
        out = self.gru_out_linear(outt)  # batch_size x hidden_dim
out = torch.tanh(out)
out = self.dropout_linear(out)
out = self.hidden2out(out) # batch_size x 3
out = torch.sigmoid(out)
return out
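# Shape walkthrough for the discriminator, as a quick sanity check (the sizes here are
# made up for illustration, not taken from parameter.py):
#   dis = Discriminator(embedding_dim=32, hidden_dim=64, vocab_size=100, max_seq_len=12)
#   tokens = torch.randint(0, 100, (4, 12))       # batch_size=4, seq_len=12
#   scores = dis(tokens, torch.tensor([12, 12, 12, 12]))  # length of each padded sequence
# scores should come out as (4, 3) with values in [0, 1] after the sigmoid, per the
# forward docstring (probability, validity, coverage).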
class Generator(nn.Module):
def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, gpu=False, oracle_init=False):
super(Generator, self).__init__()
self.hidden_dim = hidden_dim
self.embedding_dim = embedding_dim
self.max_seq_len = max_seq_len
self.vocab_size = vocab_size
self.gpu = gpu
self.embeddings = nn.Embedding(vocab_size, embedding_dim)
self.gru = nn.GRU(embedding_dim, hidden_dim)
self.gru2out = nn.Linear(hidden_dim, vocab_size)
# initialise oracle network with N(0,1)
# otherwise variance of initialisation is very small => high NLL for data sampled from the same model
if oracle_init:
for p in self.parameters():
init.normal_(p, 0, 1)
def forward(self, inp, hidden=None):
"""
Embeds input and applies GRU one token at a time (seq_len = 1)
:param inp: a batch of tokens, (batch_size,)
:param hidden:
:return: out: log probability for each vocab class, (batch_size, vocab_size)
hidden: hidden vector to be passed to the next time step, (1, batch_size, hidden_dim)
"""
if hidden is None:
# if without xavier_uniform, all sequences generated will be the same
hidden = torch.empty((1, inp.size(0), self.hidden_dim), device=inp.device)
init.xavier_uniform_(hidden)
emb = self.embeddings(inp) # batch_size x embedding_dim
emb = emb.unsqueeze(0) # 1 x batch_size x embedding_dim
out, hidden = self.gru(emb, hidden) # 1 x batch_size x hidden_dim (out)
out = self.gru2out(out.view(-1, self.hidden_dim)) # batch_size x vocab_size
out = F.log_softmax(out, dim=1)
return out, hidden
def generate_sequences(self, batch_size, start_letter=0):
"""
        Samples the network and returns batch_size sequences, each of length self.max_seq_len.
:param batch_size:
:param start_letter:
:return: sequences: batch_size x self.max_seq_length (a sampled sequence in each row)
logits: the log probability of the sampled token, used for PG loss
"""
inp = torch.LongTensor([start_letter] * batch_size)
sequences = torch.zeros(batch_size, self.max_seq_len).long()
logits = torch.zeros(batch_size, self.max_seq_len)
        if self.gpu:  # self.cuda is the inherited nn.Module method (always truthy); use the stored flag
sequences = sequences.cuda()
inp = inp.cuda()
logits = logits.cuda()
h = None
for i in range(self.max_seq_len):
out, h = self(inp, h) # out: num_samples x vocab_size
            # torch.multinomial is not differentiable, so autograd does not pass through the
            # sampling itself; the log-probabilities gathered below are what carry the
            # gradient for the policy-gradient update.
            tokens = torch.multinomial(torch.exp(out), 1).squeeze(1)  # (batch_size,) sampled token per row
sequences[:, i] = tokens
logits[:, i] = out[torch.arange(0, batch_size), tokens]
inp = tokens
return sequences, logits.contiguous()
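# Usage sketch for generate_sequences (hypothetical batch size): with max_seq_len = 20,
#   seqs, logps = gen.generate_sequences(batch_size=8)
# returns seqs of shape (8, 20) holding sampled token ids and logps of shape (8, 20)
# holding log P(token) at each step; these stored log-probabilities are the only path
# through which the policy-gradient loss reaches the generator parameters.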
class PolicyGradientLoss(nn.Module):
def __init__(self):
super(PolicyGradientLoss, self).__init__()
        # weights for the three discriminator outputs, in order: real/fake, validity, coverage
        self.weights = [0.4, 0.10, 0.50]
def forward(self, logits, dis_preds):
# TODO punish repeat targets?
# how to define repeat target?
sum_logits = logits.sum(1)
        loss = torch.matmul(sum_logits.to(dis_preds.device), dis_preds)  # (3,): one term per discriminator output
weighted_sum = self.weights[0] * loss[0] + self.weights[1] * loss[1] + self.weights[2] * loss[2]
weighted_sum = - weighted_sum
return weighted_sum / logits.shape[0] / logits.shape[1]
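# Restating the forward pass above as an equation:
#   L = -(1 / (B * T)) * sum_k w_k * sum_b ( sum_t log p_{b,t} ) * D_k(b)
# where B is the batch size, T the sequence length, w the three weights above,
# D_k(b) the k-th discriminator output for sequence b, and log p_{b,t} the stored
# log-probability of the token sampled at step t.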
class DiscrimLoss(nn.Module):
"""
The real, fake predictions are log
The validation and coverage are L1
If validation or coverage skew towards 0 or 1, consider log scale.
"""
def __init__(self, reduction="mean"):
super(DiscrimLoss, self).__init__()
self.reduction = reduction
self.l1 = nn.SmoothL1Loss(reduction=self.reduction)
self.bce = nn.BCELoss(reduction=self.reduction)
def forward(self, out, target):
"""
        :param out: (batch_size, 3) real/fake, validity and coverage predictions in [0, 1]
        :param target: (batch_size, 3) matching targets
        :return: element-wise loss if reduction="none", otherwise a scalar
"""
nll = self.bce(out, target)
return nll
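# Sketch of the split objective described in the DiscrimLoss docstring (BCE for the
# real/fake score, smooth L1 for validity and coverage). The column order and the
# plain sum of the two terms are assumptions; the class above currently applies BCE
# to all three outputs.
class SplitDiscrimLoss(nn.Module):
    def __init__(self, reduction="mean"):
        super(SplitDiscrimLoss, self).__init__()
        self.reduction = reduction
        self.l1 = nn.SmoothL1Loss(reduction=reduction)
        self.bce = nn.BCELoss(reduction=reduction)

    def forward(self, out, target):
        class_loss = self.bce(out[:, :1], target[:, :1])    # real/fake column: log (BCE) loss
        regress_loss = self.l1(out[:, 1:], target[:, 1:])   # validity, coverage: smooth L1
        if self.reduction == "none":
            return torch.cat([class_loss, regress_loss], dim=1)  # preserve the (batch_size, 3) layout
        return class_loss + regress_loss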
\ No newline at end of file
File moved