import sys

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from dataset import *
from methods import *
from models import *
def train_generator_supervised(gen, gen_opt, train_set, epochs):
    """
    Supervised pre-training of the generator: for each input token, the
    generator must predict the next token of the ground-truth sequence.
    """
train_dl = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_collate,
num_workers=DATALOADER_NUM_WORKERS)
for epoch in range(epochs):
sys.stdout.flush()
total_loss = 0
supervised_batch_count = len(train_dl)
for i, (sequence, rfcv, length) in enumerate(train_dl):
inp, target = one_token_shift_sequences(sequence,
start_letter=START_LETTER,
gpu=CUDA)
nll = nn.NLLLoss()
hidden = None
loss = 0
for j in range(inp.shape[1]):
out, hidden = gen(inp[:, j], hidden)
loss += nll(out, target[:, j])
            # the per-step loss is accumulated inline above so the training structure stays visible here
            gen_opt.zero_grad()
            loss.backward()
            gen_opt.step()
            total_loss += loss.item() / inp.shape[1]
average_loss = total_loss / supervised_batch_count
print(f"Generator supervised epoch {epoch}, average_train_NLL = {average_loss:.4f}")
def train_generator_PG(gen, gen_opt, dis, num_batches):
    """
    Policy-gradient training: the generator is updated using the discriminator's
    predictions as the reward. No supervised data is used.
    """
    batch_size = BATCH_SIZE * 2  # kept from the old implementation
    adv_loss = PolicyGradientLoss()  # log(P(y_t|Y_1:Y_{t-1})) * Q
    pg_loss = 0
    for _ in range(num_batches):
        sequences, logits = gen.generate_sequences(batch_size)
        sequences, length = auto_batch_target_length(sequences)
        # The discriminator only provides rewards here and is never updated:
        # gen_opt holds only generator parameters, and dis_opt.zero_grad() is
        # called before the next discriminator update.
        dis_preds = dis(sequences, length)
        gen_opt.zero_grad()
        loss = adv_loss(logits, dis_preds)
        loss.backward()
        gen_opt.step()
        pg_loss += loss.item()
    pg_loss /= num_batches
    print(f"Generator policy gradient loss: {pg_loss:.4f}")
def train_discriminator(discriminator, dis_opt, real_train_set, real_valid_set, generator, d_steps, epochs):
    """
    Trains the discriminator on real_train_set samples (positive) and sequences
    generated by the generator (negative). Fresh negative samples are drawn
    d_steps times, and the discriminator is trained for the given number of
    epochs on each draw.
    """
# create validation set
fake_validation_sequences, _ = generator.generate_sequences(len(real_valid_set))
fake_valid_set = GeneratedDataset(fake_validation_sequences, random_truncate=True)
valid_set = ConcatDataset(real_valid_set, fake_valid_set)
valid_dl = DataLoader(dataset=valid_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_collate,
num_workers=DATALOADER_NUM_WORKERS)
val_interval = 5
    # draw a fresh set of negative (generated) samples d_steps times,
    # training the discriminator for `epochs` epochs on each draw
    for d_step in range(d_steps):
        fake_training_sequences, _ = generator.generate_sequences(SUPERVISED_SIZE)
        fake_train_set = GeneratedDataset(fake_training_sequences, random_truncate=True)
        train_set = ConcatDataset(fake_train_set, real_train_set)
        train_dl = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_collate,
                              num_workers=DATALOADER_NUM_WORKERS)
        for epoch in range(epochs):
sys.stdout.flush()
total_loss = 0
total_acc = 0
for i, (sequence, target, length) in enumerate(train_dl):
if CUDA:
sequence = sequence.cuda()
target = target.cuda()
out = discriminator(sequence, length)
mse = nn.MSELoss() # TODO fix
loss = mse(out, target)
                # update the discriminator on this batch
                dis_opt.zero_grad()
                loss.backward()
                dis_opt.step()
total_loss += loss.item()
                total_acc += torch.sum((out[:, 0] > 0.5) == (target[:, 0] > 0.5)).item()
avg_loss = total_loss / len(train_dl)
avg_acc = total_acc / len(train_set)
print(
f"Discriminator dstep {d_step}, epoch {epoch}, avg training loss = {avg_loss:.4f}, accuracy = {avg_acc:.4f}")
if epoch % val_interval == 0:
with torch.no_grad():
val_loss = 0
val_acc = 0
pos_count = 0
pos_stats = torch.zeros((3,))
neg_stats = torch.zeros((3,))
for i, (sequence, target, length) in enumerate(valid_dl):
if CUDA:
sequence, target, pos_stats, neg_stats = sequence.cuda(), target.cuda(), \
pos_stats.cuda(), neg_stats.cuda()
out = discriminator(sequence, length)
mse = nn.MSELoss(reduction="none")
loss = mse(out, target)
pos_stats += (loss * target[:, 0].unsqueeze(1)).sum(0)
neg_stats += (loss * (1 - target[:, 0].unsqueeze(1))).sum(0)
pos_count += target[:, 0].sum(0)
val_loss += loss.mean().item()
                        val_acc += torch.sum((out[:, 0] > 0.5) == (target[:, 0] > 0.5)).item()
val_loss /= len(valid_dl)
val_acc /= len(valid_set)
pos_stats /= pos_count
neg_stats /= len(valid_set) - pos_count
                    print('Discriminator validation average_loss = %.4f, val_acc = %.4f' % (
                        val_loss, val_acc))
print(' val_pos_class = %.4f, val_pos_valid = %.4f, val_pos_cov = %.4f' % (
pos_stats[0], pos_stats[1], pos_stats[2]))
print(' val_neg_class = %.4f, val_neg_valid = %.4f, val_neg_cov = %.4f' % (
neg_stats[0], neg_stats[1], neg_stats[2]))
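
# --- Reference sketch (assumption, not the project's implementation) ---------
# `pad_collate` is imported from dataset and used by every DataLoader above to
# batch variable-length token sequences. A typical version pads each sequence to
# the longest one in the batch and returns the padded tokens, the targets, and
# the original lengths (which the discriminator can use to pack its recurrent
# input). The (sequence, target) item layout and PAD id 0 are assumptions.
from torch.nn.utils.rnn import pad_sequence

def _pad_collate_sketch(batch):
    sequences, targets = zip(*batch)
    lengths = torch.tensor([len(seq) for seq in sequences])
    padded = pad_sequence(list(sequences), batch_first=True, padding_value=0)
    return padded, torch.stack(targets), lengths
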
def main1(load=False):
supervised_samples = torch.load(supervised_sequences_path)
supervised_targets = torch.load(supervised_val_cov_path)
zork_data = SupervisedData(supervised_samples, supervised_targets)
train_set = zork_data.training_set
valid_set = zork_data.validation_set
    gen = Generator(GEN_EMBEDDING_DIM, GEN_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)
    dis = Discriminator(DIS_EMBEDDING_DIM, DIS_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)
    # assumption: Adam optimisers with default hyper-parameters; the actual
    # optimiser construction is not shown in this listing
    gen_optimizer = torch.optim.Adam(gen.parameters())
    dis_optimizer = torch.optim.Adam(dis.parameters())
if load:
gen.load_state_dict(torch.load(pretrained_gen_path))
dis.load_state_dict(torch.load(pretrained_dis_path))
print('Starting Generator Supervised Training...')
train_generator_supervised(gen, gen_optimizer, train_set,
SUPERVISED_TRAIN_EPOCHS)
# torch.save(gen.state_dict(), pretrained_gen_path)
print('Starting Discriminator Training...')
train_discriminator(dis, dis_optimizer, train_set, valid_set, gen, 10, 3)
# torch.save(dis.state_dict(), pretrained_dis_path)
print('\nStarting Adversarial Training...')
    for epoch in range(ADV_TRAIN_EPOCHS):  # assumption: ADV_TRAIN_EPOCHS is defined with the other constants
        print('\n--------\nEPOCH %d\n--------' % (epoch + 1))
print('\nAdversarial Training Generator : ')
train_generator_PG(gen, gen_optimizer, dis, 1)
# TRAIN DISCRIMINATOR
print('\nAdversarial Training Discriminator : ')
train_discriminator(dis, dis_optimizer, train_set, valid_set, gen, 2, 2)
# torch.save(dis.state_dict(), pretrained_dis_path)
# torch.save(gen.state_dict(), pretrained_gen_path)
# # gen.load_state_dict(torch.load(pretrained_gen_path))
# # dis.load_state_dict(torch.load(pretrained_dis_path))
# generate and save examples as desired
    # load the vocabulary used to decode generated token sequences into text
    vocab_list = []
    with open(vocab_path, 'r') as vocab_file:
        for line in vocab_file:
            # strip the newline character from each entry
            vocab_list.append(line.strip())
# print(vocab_list)
# # TEST SAMPLE QUALITY
# test_inp = gen.generate_sequences(batch_size=20)
# # h = dis.init_hidden(test_inp.size()[0])
# test_out = dis(test_inp)
# print(test_out)
#
# # SAVE SAMPLES TO FILES
# test_samples = gen.generate_sequences(batch_size=128)
# test_out = dis(test_samples)
# test_samples = test_samples.cpu()
# samples_list = test_samples.numpy()
#
# ## OUTPUT GENERATED SAMPLES
#
# for i in range(len(samples_list)):
# # print(test_out[i][2])
# # if test_out[i][2] > 0.1:
# with open(output_path + str(i) + '.txt', 'w') as vocab_file:
# for j in samples_list[i]:
# if (j == 1):
# break
# if (j > len(vocab_list)):
# break
# vocab_file.write('%s\n' % vocab_list[j])
#
# MAIN
if __name__ == '__main__':
main1()