Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
from torchtext import data
from torchtext import datasets
import random
import time
import os
import copy
import csv
import numpy as np
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data.sampler import SubsetRandomSampler
# Dataset for the layout where each sample (game) lives in its own "Game<N>" subfolder.
class zork_dataset(torch.utils.data.Dataset):
    """Dataset of recorded Zork games, one ``Game<N>`` subfolder per sample.

    For every ``N`` in ``[low_bound, high_bound]`` the folder
    ``<folder_path>/Game<N>`` is read for:

    * ``input`` -- the commands fed to the game, one per line;
    * ``ZorkOut.gameOut`` -- the transcript produced by the game;
    * ``rooms.c.gcov`` / ``nrooms.c.gcov`` -- gcov line-coverage reports.

    Constructing the dataset has a side effect: it (re)writes ``label.csv``
    in each game folder with three values -- a constant ``1.0``, the output
    validity from :meth:`parse_output`, and the fraction of instrumented
    source lines that were executed at least once.
    """

    # gcov reports whose per-line execution counts feed the coverage score.
    _COVERAGE_FILES = ("rooms.c.gcov", "nrooms.c.gcov")

    # Transcript lines that count as a rejected / no-op command.
    # NOTE(review): "could't" looks like a typo for "couldn't"; left untouched
    # because the game's actual message is not visible here -- confirm.
    _ERROR_MSGS = frozenset({
        ">I can't see one here.",
        ">You can't go that way.",
        ">The window is closed.",
        ">I could't find anything.",
        ">You are empty handed.",
        ">You are in perfect health.",
        ">You cannot climb any higher.",
    })

    # gcov markers for instrumented lines that were never executed.
    _UNEXECUTED_MARKERS = ("#####", "=====")

    def __init__(self, folder_path, low_bound, high_bound):
        """Index games ``low_bound..high_bound`` (inclusive) and write labels.

        Args:
            folder_path: Directory containing the ``Game<N>`` subfolders.
            low_bound: First game number to include.
            high_bound: Last game number to include (inclusive).
        """
        super().__init__()
        self.path = folder_path
        self.game_ns = []  # never filled; kept so the attribute still exists
        self.step_ns = []
        for game_n in range(low_bound, high_bound + 1):
            self.step_ns.append(str(game_n))
            game_dir = os.path.join(folder_path, "Game" + str(game_n))

            # Coverage score: share of instrumented lines with a non-zero count.
            info_paths = [
                os.path.join(game_dir, name) for name in self._COVERAGE_FILES
            ]
            counts = self.get_coverage(info_paths)
            covered = sum(1 for count in counts if count > 0.5)
            coverage_score = covered / max(1, len(counts))

            output_validity = self.parse_output(
                os.path.join(game_dir, "ZorkOut.gameOut")
            )

            full_label = [1.0, output_validity, coverage_score]
            label_path = os.path.join(game_dir, "label.csv")
            with open(label_path, "w", newline="") as label_file:
                csv.writer(label_file).writerow(full_label)
        self.len = len(self.step_ns)

    def __getitem__(self, index):
        """Return ``(commands, label)`` for the sample at ``index``.

        ``commands`` is always a list of the non-empty lines of the game's
        ``input`` file (a single-command game yields a one-element list).
        ``label`` is the ``label.csv`` row loaded as a ``float32`` array.
        """
        item_path = os.path.join(self.path, "Game" + self.step_ns[index])
        input_path = os.path.join(item_path, "input")
        label_path = os.path.join(item_path, "label.csv")

        # Read the commands by hand: np.loadtxt(delimiter='\n') is rejected by
        # recent NumPy releases. The '#' handling and blank-line skipping
        # mirror what loadtxt did, so existing input files parse the same way.
        x = []
        with open(input_path, "r") as input_file:
            for raw_line in input_file:
                command = raw_line.split("#", 1)[0].rstrip("\r\n")
                if command:
                    x.append(command)

        y = np.loadtxt(label_path, delimiter=",", dtype=np.float32)
        return x, y

    def __len__(self):
        """Return the number of indexed games."""
        return self.len

    def get_coverage(self, filepaths):
        """Collect per-line execution counts from gcov reports.

        Args:
            filepaths: Iterable of paths to ``.gcov`` files.

        Returns:
            A flat list of ints, one per instrumented source line across all
            files, in file order. Non-executable lines (``-``) and blank lines
            are skipped; unexecuted lines (``#####`` / ``=====``) count as 0.

        Raises:
            ValueError: If a count field is neither a marker nor an integer.
        """
        label = []
        for filepath in filepaths:
            with open(filepath, "r") as info_f:
                for line in info_f:
                    fields = line.split()
                    if not fields:
                        continue  # blank line: nothing to parse
                    count = fields[0][:-1]  # first column minus trailing ':'
                    if count == "-":
                        continue
                    if count in self._UNEXECUTED_MARKERS:
                        label.append(0)
                    else:
                        # gcov appends '*' when a line has unexecuted blocks.
                        label.append(int(count.replace("*", "")))
        return label

    def parse_output(self, filepath):
        """Return the fraction of game responses that are not error messages.

        Only transcript lines starting with ``>`` are counted. Returns ``0.0``
        when the transcript contains no such line, instead of dividing by zero.
        """
        valid = 0
        total = 0
        with open(filepath, "r") as info_f:
            for line in info_f:
                response = line.rstrip()
                if response.startswith(">"):
                    total += 1
                    if response not in self._ERROR_MSGS:
                        valid += 1
        if total == 0:
            return 0.0
        return valid / total
# Status message printed once the definitions above have been executed.
# NOTE(review): this appears to be a module-level statement, so it would also
# fire when the file is imported, not only when run as a script -- confirm.
print("Export complete.")