import copy
import csv
import os
import random
import time

import numpy as np
import torch
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data.sampler import SubsetRandomSampler

try:
    # torchtext is not referenced by the code in this module, so a missing
    # install should not prevent the dataset below from loading.
    from torchtext import data
    from torchtext import datasets
except ImportError:
    data = None
    datasets = None


class zork_dataset(torch.utils.data.Dataset):
    """Dataset of recorded Zork games, one sub-folder ``Game<N>`` per sample.

    Construction scans ``Game<low_bound>`` .. ``Game<high_bound>`` (inclusive)
    under *folder_path*, derives a label for each game from its gcov coverage
    files and its ``ZorkOut.gameOut`` transcript, and writes that label to
    ``label.csv`` inside the game's folder. ``__getitem__`` then reads the
    game's ``input`` file together with the stored label.
    """

    # gcov count-column markers for lines that contain code but never ran;
    # '=====' marks lines reachable only through exceptional paths.
    _UNEXECUTED_MARKERS = frozenset({"#####", "====="})

    # Transcript lines counted as failed commands by parse_output(). They are
    # matched verbatim; the "could't" spelling presumably mirrors the game's
    # own text -- verify against the game before "correcting" it.
    _ERROR_MSGS = frozenset({
        ">I can't see one here.",
        ">You can't go that way.",
        ">The window is closed.",
        ">I could't find anything.",
        ">You are empty handed.",
        ">You are in perfect health.",
        ">You cannot climb any higher.",
    })

    def __init__(self, folder_path, low_bound, high_bound,
                 coverage_files=("rooms.c.gcov", "nrooms.c.gcov")):
        """Scan the game folders and write one ``label.csv`` per game.

        Args:
            folder_path: directory containing the ``Game<N>`` sub-folders.
            low_bound: first game number to include.
            high_bound: last game number to include (inclusive).
            coverage_files: gcov file names, relative to each game folder,
                whose line counts make up the coverage score. Defaults to
                the two files this dataset has always used.
        """
        self.path = folder_path
        self.game_ns = []  # never populated here; kept for external readers
        self.step_ns = []  # game numbers as strings, in dataset index order
        for game_num in range(low_bound, high_bound + 1):
            self.step_ns.append(str(game_num))
            game_dir = os.path.join(folder_path, "Game" + str(game_num))

            counts = self.get_coverage(
                [os.path.join(game_dir, name) for name in coverage_files])
            # Fraction of executable lines run at least once (counts are ints,
            # so "> 0.5" means "nonzero"); max(1, ...) guards against 0/0.
            covered = sum(1.0 for count in counts if count > 0.5)
            coverage_score = covered / max(1, len(counts))

            output_validity = self.parse_output(
                os.path.join(game_dir, "ZorkOut.gameOut"))

            # NOTE(review): the leading 1.0 is a constant column; its meaning
            # is not established in this file.
            full_label = [1.0, output_validity, coverage_score]
            label_path = os.path.join(game_dir, "label.csv")
            with open(label_path, "w", newline="") as label_file:
                csv.writer(label_file).writerow(full_label)
        self.len = len(self.step_ns)

    def __getitem__(self, index):
        """Return ``(commands, label)`` for the game at position *index*.

        ``commands`` is the list of non-blank lines of the game's ``input``
        file with line endings removed; ``label`` is the float32 array read
        from the ``label.csv`` written by ``__init__``.
        """
        item_path = os.path.join(self.path, "Game" + self.step_ns[index])
        input_path = os.path.join(item_path, "input")
        label_path = os.path.join(item_path, "label.csv")
        # Read lines directly rather than via np.loadtxt(delimiter='\n'):
        # NumPy >= 1.23 rejects a newline delimiter, older versions collapse a
        # one-line file into a bare string instead of a one-element list, and
        # loadtxt would also cut every line at a '#' comment character.
        with open(input_path) as input_file:
            stripped = [line.rstrip("\r\n") for line in input_file]
        x = [line for line in stripped if line]  # loadtxt skipped blank lines too
        y = np.loadtxt(label_path, delimiter=",", dtype=np.float32)
        return x, y

    def __len__(self):
        """Return the number of games in the dataset."""
        return self.len

    def get_coverage(self, filepaths):
        """Collect per-line execution counts from gcov ``.gcov`` files.

        Args:
            filepaths: paths of the gcov output files to read, in order.

        Returns:
            A flat list of ints, one per executable source line across all
            files: the execution count, or 0 for lines that never ran. Lines
            without executable code ('-') are omitted.

        Raises:
            ValueError: if a count field is neither a known marker nor an
                integer (optionally with gcov's trailing '*').
        """
        label = []
        for filepath in filepaths:
            # Only the ASCII count column is inspected, so undecodable bytes
            # in the quoted source text must not abort parsing.
            with open(filepath, "r", errors="replace") as info_f:
                for line in info_f:
                    fields = line.split()
                    # Source lines start with "<count>:"; skip blank lines and
                    # extra records (e.g. the 'function'/'branch'/'call' lines
                    # of gcov -b) whose first field has no ':' suffix.
                    if not fields or not fields[0].endswith(":"):
                        continue
                    count = fields[0][:-1]
                    if count == "-":  # no executable code on this line
                        continue
                    if count in self._UNEXECUTED_MARKERS:
                        label.append(0)
                    else:
                        # A trailing '*' flags a line with some unexecuted
                        # blocks; the number itself is still the line count.
                        label.append(int(count.replace("*", "")))
        return label

    def parse_output(self, filepath):
        """Return the fraction of '>' transcript lines that are not errors.

        Every line of *filepath* starting with '>' is counted; those whose
        right-stripped text is not in ``_ERROR_MSGS`` count as valid.

        Returns:
            ``valid / total`` as a float, or 0.0 when the file contains no
            '>' lines (previously a ZeroDivisionError).
        """
        with open(filepath, "r", errors="replace") as info_f:
            prompts = [line.rstrip() for line in info_f if line.startswith(">")]
        if not prompts:
            return 0.0
        valid = sum(1 for prompt in prompts if prompt not in self._ERROR_MSGS)
        return valid / len(prompts)


print("Export complete.")