import torch
from torchtext import data
from torchtext import datasets
import random
import time
import os
import copy
import csv
import numpy as np
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data.sampler import SubsetRandomSampler

# for when samples are each in their own subfolder, 

class zork_dataset(torch.utils.data.Dataset):
    """Dataset of Zork game runs, one per ``Game<n>`` subfolder of *folder_path*.

    For every game number in ``[low_bound, high_bound]`` the constructor
    parses that game's gcov coverage files and its transcript, then writes a
    three-value ``label.csv`` into the game's folder:

        [1.0 (constant placeholder), output validity, coverage score]

    ``__getitem__`` returns ``(input_commands, label_vector)`` for one game.
    """

    def __init__(self, folder_path, low_bound, high_bound):
        self.path = folder_path
        self.game_ns = []  # kept for interface compatibility; never populated here
        self.step_ns = []
        for game_n in range(low_bound, high_bound + 1):
            self.step_ns.append(str(game_n))

            # Parse and store labels for this game's folder.
            full_path = os.path.join(folder_path, "Game" + str(game_n))

            # Coverage sources; extend this list to include more .gcov files.
            info_paths = [
                os.path.join(full_path, "rooms.c.gcov"),
                os.path.join(full_path, "nrooms.c.gcov"),
            ]
            counts = self.get_coverage(info_paths)  # per-line execution counts

            # Fraction of instrumented lines executed at least once
            # (max(1, ...) guards against an empty coverage list).
            covered = sum(1.0 for c in counts if c > 0.5)
            coverage_score = covered / max(1, len(counts))

            output_validity = self.parse_output(
                os.path.join(full_path, "ZorkOut.gameOut"))

            full_label = [1.0, output_validity, coverage_score]

            label_path = os.path.join(full_path, "label.csv")
            with open(label_path, 'w', newline='') as myfile:
                csv.writer(myfile).writerow(full_label)

        self.len = len(self.step_ns)

    def __getitem__(self, index):
        """Return ``(commands, label)`` for game *index*.

        ``commands`` is a list of input-line strings; ``label`` is the
        float32 vector written to ``label.csv`` by ``__init__``.
        """
        item_path = os.path.join(self.path, "Game" + self.step_ns[index])
        input_path = os.path.join(item_path, "input")
        label_path = os.path.join(item_path, "label.csv")
        # Was np.loadtxt(input_path, delimiter='\n', dtype='str'): numpy >= 1.23
        # rejects '\n' as a delimiter.  This read reproduces the old loadtxt
        # semantics: '#' starts a comment, tokens are whitespace-stripped, and
        # blank lines are skipped.
        with open(input_path) as f:
            x = [s for s in (ln.split('#', 1)[0].strip() for ln in f) if s]
        y = np.loadtxt(label_path, delimiter=',', dtype=np.float32)
        return x, y

    def __len__(self):
        """Number of games indexed by this dataset."""
        return self.len

    def get_coverage(self, filepaths):
        """Parse gcov-annotated sources into a flat list of line-execution counts.

        Each gcov line begins with a count field terminated by ':'.  A '-'
        field (non-executable line) is skipped, '#####' (never executed) maps
        to 0, and numeric fields (possibly '*'-suffixed) map to their int value.
        """
        label = []
        for filepath in filepaths:
            with open(filepath, "r") as info_f:  # was left open: leak fixed
                for line in info_f:
                    fields = line.rstrip().split()
                    if not fields:
                        continue  # blank line (original crashed on split()[0])
                    count_field = fields[0][:-1]  # drop the trailing ':'
                    if count_field == '-':
                        continue  # line not instrumentable
                    if count_field == '#####':
                        label.append(0)
                    else:
                        # gcov suffixes some counts with '*'; strip it.
                        label.append(int(count_field.replace("*", "")))
        return label

    def parse_output(self, filepath):
        """Return the fraction of game responses (lines starting with '>')
        that are NOT known canned error messages.

        Returns 0.0 for a transcript containing no responses (the original
        raised ZeroDivisionError in that case).
        """
        # Canned Zork failure responses; a response matching one of these is
        # treated as an invalid/wasted command.
        # NOTE(review): ">I could't find anything." looks like a typo, but it
        # must match the game's literal output byte-for-byte — verify against
        # an actual transcript before changing it.
        error_msgs = {">I can't see one here.", ">You can't go that way.", ">The window is closed.",
                      ">I could't find anything.", ">You are empty handed.", ">You are in perfect health.",
                     ">You cannot climb any higher."}
        valid = 0
        total = 0
        with open(filepath, "r") as info_f:  # was left open: leak fixed
            for line in info_f:
                resp = line.rstrip()
                if resp.startswith('>'):
                    total += 1
                    if resp not in error_msgs:
                        valid += 1
        return valid / total if total else 0.0
# Module-level side effect: this runs at import time, BEFORE any dataset is
# constructed — the label-export work happens in zork_dataset.__init__.
# NOTE(review): the message therefore prints before anything is exported;
# confirm this placement is intended.
print("Export complete.")