Skip to content
Snippets Groups Projects
zork_dataset.py 3.48 KiB
Newer Older
zaglanz's avatar
zaglanz committed
import torch
from torchtext import data
from torchtext import datasets
import random
import time
import os
import copy
import csv
import numpy as np
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data.sampler import SubsetRandomSampler

# Dataset for the layout where each sample (one game) lives in its own Game<N> subfolder.

class zork_dataset(torch.utils.data.Dataset):
    """Dataset of recorded Zork games stored one per ``Game<N>`` subfolder.

    For every N in ``[low_bound, high_bound]`` (inclusive) the folder
    ``<folder_path>/Game<N>`` must contain the gcov reports listed in
    ``_GCOV_FILES``, the game transcript ``ZorkOut.gameOut`` and a text file
    ``input`` (presumably the commands fed to the game -- confirm).
    Construction computes a three-value label per game and writes it to
    ``Game<N>/label.csv``, overwriting any existing file. ``__getitem__``
    reads ``input`` and that label back.
    """

    # gcov reports whose per-line execution counts feed the coverage score.
    # Extend this tuple to take more source files into account.
    _GCOV_FILES = ("rooms.c.gcov", "nrooms.c.gcov")

    # Transcript lines (leading '>' included) that parse_output counts as
    # failed commands.
    # NOTE(review): ">I could't find anything." is kept byte-for-byte from the
    # original; confirm the game really prints this spelling before fixing it.
    _ERROR_MSGS = frozenset({
        ">I can't see one here.",
        ">You can't go that way.",
        ">The window is closed.",
        ">I could't find anything.",
        ">You are empty handed.",
        ">You are in perfect health.",
        ">You cannot climb any higher.",
    })

    def __init__(self, folder_path, low_bound, high_bound):
        """Index games ``low_bound..high_bound`` and write their label files.

        Args:
            folder_path: Directory containing the ``Game<N>`` subfolders.
            low_bound: First game number (inclusive).
            high_bound: Last game number (inclusive).

        Raises:
            OSError: If a required gcov report or transcript is missing or
                ``label.csv`` cannot be written.
        """
        self.path = folder_path
        # Never populated by this class; kept so existing readers still find it.
        self.game_ns = []
        # Game numbers as strings, in dataset-index order (used by __getitem__).
        self.step_ns = []
        for game_n in range(low_bound, high_bound + 1):
            self.step_ns.append(str(game_n))
            self._write_label(os.path.join(folder_path, "Game" + str(game_n)))
        self.len = len(self.step_ns)

    def _write_label(self, full_path):
        """Compute ``[1.0, output_validity, coverage_score]`` for one game folder and write it to label.csv."""
        counts = self.get_coverage(
            [os.path.join(full_path, name) for name in self._GCOV_FILES]
        )
        # Counts are ints, so "> 0" means "executed at least once"; the
        # max(1, ...) guard yields 0.0 instead of dividing by zero.
        covered = sum(1 for count in counts if count > 0)
        coverage_score = covered / max(1, len(counts))

        output_validity = self.parse_output(os.path.join(full_path, "ZorkOut.gameOut"))

        # The leading 1.0 is a constant; its meaning is not visible in this file.
        full_label = [1.0, output_validity, coverage_score]
        label_path = os.path.join(full_path, "label.csv")
        with open(label_path, "w", newline="") as label_f:
            csv.writer(label_f).writerow(full_label)

    def __getitem__(self, index):
        """Return ``(lines, label)`` for the game at position ``index``.

        ``lines`` is always a list with one string per non-empty line of the
        game's ``input`` file (line endings removed). ``label`` is the float32
        array read from ``label.csv``.
        """
        item_path = os.path.join(self.path, "Game" + self.step_ns[index])
        input_path = os.path.join(item_path, "input")
        label_path = os.path.join(item_path, "label.csv")
        # Read lines directly: np.loadtxt(delimiter='\n') collapsed a one-line
        # file to a 0-d array (tolist() -> bare str) and cut text after '#'.
        with open(input_path, "r") as input_f:
            x = [cmd for cmd in (line.rstrip("\r\n") for line in input_f) if cmd]
        y = np.loadtxt(label_path, delimiter=",", dtype=np.float32)
        return x, y

    def __len__(self):
        """Return the number of games indexed at construction."""
        return self.len

    def get_coverage(self, filepaths):
        """Return the execution count of every executable line in the gcov reports.

        Reports are read in the order given, and their counts are concatenated.
        ``#####`` and ``=====`` (gcov's unexecuted markers) give 0. A trailing
        ``*`` on a count is dropped. Non-executable lines (``-``), blank lines,
        and records without a count column (e.g. ``branch``/``call`` lines
        from ``gcov -b``) are skipped.

        Raises:
            OSError: If a report cannot be opened.
            ValueError: If a count field is not one of the forms above.
        """
        label = []
        for filepath in filepaths:
            with open(filepath, "r") as info_f:
                for line in info_f:
                    fields = line.split(":", 1)
                    if len(fields) < 2:
                        # No "count:" prefix, so the line carries no line count.
                        continue
                    count = fields[0].strip()
                    if count == "-":
                        continue  # non-executable source line
                    if count in ("#####", "====="):
                        label.append(0)
                    else:
                        label.append(int(count.replace("*", "")))
        return label

    def parse_output(self, filepath):
        """Return the fraction of ``>`` transcript lines that are not known error responses.

        A line counts as a response if, after trailing whitespace is stripped,
        it starts with ``>``. A response is a failure if it is in
        ``_ERROR_MSGS``. Returns 0.0 when the transcript has no such lines.
        """
        total = 0
        valid = 0
        with open(filepath, "r") as out_f:
            for line in out_f:
                text = line.rstrip()
                if not text.startswith(">"):
                    continue
                total += 1
                if text not in self._ERROR_MSGS:
                    valid += 1
        return valid / max(1, total)
# Runs at import time, right after the class definition above.
print("Export complete.", end="\n")