import sys

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import ConcatDataset, DataLoader

from dataset import *
from methods import *
from models import *
def train_generator_supervised(gen, gen_opt, train_set, epochs):
    """
    Supervised pre-training for the generator.

    For each input token, the generator must predict the next token.
    """
    train_dl = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_collate,
                          num_workers=DATALOADER_NUM_WORKERS)
    nll = nn.NLLLoss()
    for epoch in range(epochs):
        sys.stdout.flush()
        total_loss = 0
        supervised_batch_count = len(train_dl)
        for i, (sequence, rfcv, length) in enumerate(train_dl):
            inp, target = one_token_shift_sequences(sequence,
                                                    start_letter=START_LETTER,
                                                    gpu=CUDA)
            gen_opt.zero_grad()
            hidden = None
            loss = 0
            # Unroll the recurrence one token at a time, accumulating the NLL
            # of each next-token prediction; the loop is kept explicit so the
            # structure of the computation is visible right here.
            for j in range(inp.shape[1]):
                out, hidden = gen(inp[:, j], hidden)
                loss += nll(out, target[:, j])
            loss.backward()
            gen_opt.step()
            total_loss += loss.item() / inp.shape[1]
        average_loss = total_loss / supervised_batch_count
        print(f"Generator supervised epoch {epoch}, average_train_NLL = {average_loss:.4f}")
def train_generator_PG(gen, gen_opt, dis, num_batches):
    """
    Policy-gradient training: the generator is trained using rewards predicted
    by the discriminator. No supervised data is used.
    """
    batch_size = BATCH_SIZE * 2
    pg_loss = PolicyGradientLoss()  # log(P(y_t|Y_1:Y_{t-1})) * Q
    total_loss = 0
    for batch in range(num_batches):
        sequences, logits = gen.generate_sequences(batch_size)
        sequences, length = auto_batch_target_length(sequences)
        # The discriminator only supplies rewards here and is never updated:
        # gen_opt holds only generator parameters, and dis_opt.zero_grad() is
        # called before each discriminator update.
        dis_preds = dis(sequences, length)
        gen_opt.zero_grad()
        loss = pg_loss(logits, dis_preds)
        loss.backward()
        gen_opt.step()
        total_loss += loss.item()
    total_loss /= num_batches
    print(f"Generator policy gradient loss: {total_loss:.4f}")
def train_discriminator(discriminator, dis_opt, real_train_set, real_valid_set, generator, d_steps, epochs):
    """
    Trains the discriminator on real samples (positive) and generated samples
    from the generator (negative). Fresh negative samples are drawn d_steps
    times, and the discriminator is trained for `epochs` epochs on each draw.
    """
    # Build a fixed validation set of real (positive) and generated (negative) samples.
    fake_validation_sequences, _ = generator.generate_sequences(len(real_valid_set))
    fake_valid_set = GeneratedDataset(fake_validation_sequences, random_truncate=True, coverage_vector=True)
    valid_set = ConcatDataset([real_valid_set, fake_valid_set])
    valid_dl = DataLoader(dataset=valid_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=cov_vector_pad_collate,
                          num_workers=DATALOADER_NUM_WORKERS)
    val_interval = 5
    for d_step in range(d_steps):
        # Draw a fresh set of negative training samples for this d_step.
        fake_training_sequences, _ = generator.generate_sequences(SUPERVISED_SIZE)
        fake_train_set = GeneratedDataset(fake_training_sequences, random_truncate=True, coverage_vector=True)
        train_set = ConcatDataset([real_train_set, fake_train_set])
        train_dl = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True,
                              collate_fn=cov_vector_pad_collate, num_workers=DATALOADER_NUM_WORKERS)
        for epoch in range(epochs):
            sys.stdout.flush()
            total_loss = 0
            total_acc = 0
            discrim_loss = DiscrimLoss()
            for i, (sequence, target, length) in enumerate(train_dl):
                if CUDA:
                    sequence = sequence.cuda()
                    target = target.cuda()
                dis_opt.zero_grad()
                out = discriminator(sequence, length)
                loss = discrim_loss(out, target)
                loss.backward()
                dis_opt.step()
                total_loss += loss.item()
                # Accuracy on the real/fake class column.
                total_acc += torch.sum((out[:, 0] > 0.5) == (target[:, 0] > 0.5)).item()
            avg_loss = total_loss / len(train_dl)
            avg_acc = total_acc / len(train_set)
            print(f"Discriminator dstep {d_step}, epoch {epoch}, avg training loss = {avg_loss:.4f}, "
                  f"accuracy = {avg_acc:.4f}")
            if epoch % val_interval == 0:
                with torch.no_grad():
                    val_loss = 0
                    val_acc = 0
                    pos_count = 0
                    pos_stats = torch.zeros((3,))
                    neg_stats = torch.zeros((3,))
                    val_discrim_loss = DiscrimLoss(reduction="none")
                    for i, (sequence, target, length) in enumerate(valid_dl):
                        if CUDA:
                            sequence, target = sequence.cuda(), target.cuda()
                            pos_stats, neg_stats = pos_stats.cuda(), neg_stats.cuda()
                        out = discriminator(sequence, length)
                        loss = val_discrim_loss(out, target)
                        # Accumulate per-component losses separately for the
                        # positive (real) and negative (generated) classes.
                        pos_stats += (loss * target[:, 0].unsqueeze(1)).sum(0)
                        neg_stats += (loss * (1 - target[:, 0].unsqueeze(1))).sum(0)
                        pos_count += target[:, 0].sum(0)
                        val_loss += loss.mean().item()
                        val_acc += torch.sum((out[:, 0] > 0.5) == (target[:, 0] > 0.5)).item()
                    val_loss /= len(valid_dl)
                    val_acc /= len(valid_set)
                    pos_stats /= pos_count
                    neg_stats /= len(valid_set) - pos_count
                    print(f"Discriminator validation average_loss = {val_loss:.4f}, val_acc = {val_acc:.4f}")
                    print(f"    val_pos_class = {pos_stats[0]:.4f}, val_pos_valid = {pos_stats[1]:.4f}, "
                          f"val_pos_cov = {pos_stats[2]:.4f}")
                    print(f"    val_neg_class = {neg_stats[0]:.4f}, val_neg_valid = {neg_stats[1]:.4f}, "
                          f"val_neg_cov = {neg_stats[2]:.4f}")
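
# A minimal sketch of the assumed DiscrimLoss (imported from methods). Judging
# from the validation stats above, `out` and `target` are assumed to carry
# three columns per sample: real/fake class, validity, and coverage, each
# scored with binary cross-entropy. The real implementation may weight or
# combine these differently; this sketch is illustrative only.
class _DiscrimLossSketch(nn.Module):
    def __init__(self, reduction="mean"):
        super().__init__()
        self.bce = nn.BCELoss(reduction="none")
        self.reduction = reduction

    def forward(self, out, target):
        loss = self.bce(out, target)  # shape (batch, 3): class, valid, cov
        return loss.mean() if self.reduction == "mean" else loss
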
def main1(load=False):
    supervised_samples = torch.load(supervised_sequences_path)
    supervised_targets = torch.load(supervised_val_cov_path)
    zork_data = SupervisedData(supervised_samples, supervised_targets)
    train_set = zork_data.training_set
    valid_set = zork_data.validation_set

    gen = Generator(GEN_EMBEDDING_DIM, GEN_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)
    dis = Discriminator(DIS_EMBEDDING_DIM, DIS_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)
    if load:
        gen.load_state_dict(torch.load(pretrained_gen_path))
        dis.load_state_dict(torch.load(pretrained_dis_path))

    # The original listing did not show the optimizer setup; plain Adam with
    # default settings is assumed here.
    gen_optimizer = optim.Adam(gen.parameters())
    dis_optimizer = optim.Adam(dis.parameters())

    print('Starting Generator Supervised Training...')
    train_generator_supervised(gen, gen_optimizer, train_set,
                               SUPERVISED_TRAIN_EPOCHS)
    # torch.save(gen.state_dict(), pretrained_gen_path)

    print('Starting Discriminator Training...')
    train_discriminator(dis, dis_optimizer, train_set, valid_set, gen, 10, 3)
    # torch.save(dis.state_dict(), pretrained_dis_path)

    print('\nStarting Adversarial Training...')
    # ADV_TRAIN_EPOCHS is assumed to be defined with the other config constants.
    for epoch in range(ADV_TRAIN_EPOCHS):
        print('\n--------\nEPOCH %d\n--------' % (epoch + 1))
        print('\nAdversarial Training Generator : ')
        train_generator_PG(gen, gen_optimizer, dis, 1)
        # TRAIN DISCRIMINATOR
        print('\nAdversarial Training Discriminator : ')
        train_discriminator(dis, dis_optimizer, train_set, valid_set, gen, 2, 2)
        # torch.save(dis.state_dict(), pretrained_dis_path)
        # torch.save(gen.state_dict(), pretrained_gen_path)

    # generate and save examples as desired
    # Load the vocabulary, stripping the newline character from each entry.
    vocab_list = []
    with open(vocab_path, 'r') as file1:
        for line in file1.readlines():
            vocab_list.append(line.strip())
    # print(vocab_list)

    # # TEST SAMPLE QUALITY
    # test_inp = gen.generate_sequences(batch_size=20)
    # # h = dis.init_hidden(test_inp.size()[0])
    # test_out = dis(test_inp)
    # print(test_out)

    # # SAVE SAMPLES TO FILES
    # test_samples = gen.generate_sequences(batch_size=128)
    # test_out = dis(test_samples)
    # test_samples = test_samples.cpu()
    # samples_list = test_samples.numpy()

    # ## OUTPUT GENERATED SAMPLES
    # for i in range(len(samples_list)):
    #     # if test_out[i][2] > 0.1:
    #     with open(output_path + str(i) + '.txt', 'w') as vocab_file:
    #         for j in samples_list[i]:
    #             if j == 1:
    #                 break
    #             if j >= len(vocab_list):
    #                 break
    #             vocab_file.write('%s\n' % vocab_list[j])

# MAIN
if __name__ == '__main__':
    main1()