"""
crazyflie_connection handles the actual interaction with the crazyflie.
The only reason it exists is to serve as an intermediary between the groundstation
and the crazyflie itself so that it can handle all interactions with the
cflib library.
"""

from datetime import datetime
from email.utils import localtime
from threading import Thread
from time import sleep, time
from typing import List
import time
import struct
import os

import cflib.crtp
from cflib.crazyflie import Crazyflie
from cflib.crazyflie.syncCrazyflie import SyncCrazyflie
from cflib.crazyflie.syncLogger import SyncLogger
from queue import Queue
#import groundstation_socket as gs
import uCartCommander
from groundstation_socket import MessageTypeID
from SetpointHandler import SetpointHandler, FlightMode
from cflib.crazyflie.log import LogConfig
import numpy as np
from LogfileHandler import LogfileHandler


class CrazyflieConnection:

    """
    Handles all interactions with cflib.

    Groundstation commands arrive as dicts with a ``'data'`` field (indexed
    and sliced as raw byte values — presumably ``bytes`` or a list of ints)
    and, for requests that are answered, a ``'msg_id'`` field that is echoed
    back in the response dict put on the caller-supplied output queue.
    """

    # Value reported by GetParam when a parameter cannot be read (no link,
    # unknown parameter id). Chosen to be obviously bogus on the receiving side.
    _PARAM_ERROR_VALUE = -1.234567

    # Setpoint-handler method selected by the mode byte of an override
    # command. Mode 4 (position) is recognised but not implemented.
    _OVERRIDE_MODE_SETTERS = {
        0: "stopFlying",            # stop
        1: "setAttitudeMode",       # angle
        2: "setRateMode",           # rate
        3: "setMixedAttitudeMode",  # mixed
    }

    # Continuous log writer: number of queued data points written per batch,
    # and how long (seconds) the writer sleeps while waiting for a full batch.
    _LOG_BATCH_SIZE = 12
    _LOG_POLL_INTERVAL = 0.01

    def __init__(self):
        # Get the start time, so that the timestamp returned to the user will
        # be in seconds since startup.
        self.start_time = time.time()

        self.scf = None
        self.is_connected = False
        self.param_callback_count = 0
        self.logging_configs = []
        self.logging_thread = None
        self.stop_thread = False
        self.setpoint_handler = SetpointHandler()
        self.logfile_handler = LogfileHandler()

        self.timestamp = 0

    @staticmethod
    def _unpack_float(raw):
        """Decode 4 raw bytes as a little-endian IEEE-754 float32.

        Raises ``struct.error`` if ``raw`` is not exactly 4 bytes long and
        ``ValueError`` if it contains values outside 0..255.
        """
        return struct.unpack('<f', bytes(raw))[0]

    def connect(self, uri: str):
        """
        Handles connecting to a crazyflie. Bitcraze has excellent
        documentation on how to use the synchronous crazyflie object in order
        to send setpoints, set parameters or retrieve logging.
        :param uri: Radio channel
        """
        cflib.crtp.init_drivers()
        self.scf = SyncCrazyflie(uri, cf=Crazyflie(rw_cache='./cache'))
        self.scf.open_link()
        self.scf.wait_for_params()
        # disconnect() only closes the link when this flag is set.
        self.is_connected = True
        print("Connect quad")

        # Replace the stock commander with the project's uCart commander.
        self.scf.cf.commander = uCartCommander.Commander(self.scf.cf)

        self.logging_configs = self.logfile_handler.read_all_active_blocks(self.scf)

        # Hand the commander to the setpoint handler and start its sender thread.
        self.setpoint_handler.setCommander(self.scf.cf.commander)
        self.setpoint_handler.startSetpointThread()

    def disconnect(self):
        """ Disconnect from crazyflie. """
        print("Disconnect quad")
        if self.is_connected:
            self.scf.close_link()
            self.scf = None
            self.is_connected = False

    # --- Commands that are not implemented yet -------------------------------
    # They accept any arguments so that calling them on an instance raises a
    # clear NotImplementedError instead of a signature TypeError.

    def Debug(self, *args, **kwargs):
        """Placeholder; not implemented."""
        raise NotImplementedError("Debug")

    def PacketLog(self, *args, **kwargs):
        """Placeholder; not implemented."""
        raise NotImplementedError("PacketLog")

    def GetPacketLogs(self, *args, **kwargs):
        """Placeholder; not implemented."""
        raise NotImplementedError("GetPacketLogs")

    def Update(self, *args, **kwargs):
        """Placeholder; not implemented."""
        raise NotImplementedError("Update")

    def BeginUpdate(self, *args, **kwargs):
        """Placeholder; not implemented."""
        raise NotImplementedError("BeginUpdate")

    def OverrideOuput(self, command):
        """Apply an override setpoint for a given amount of time.

        ``command['data']`` layout: byte 0 is the flight mode (0 stop,
        1 angle, 2 rate, 3 mixed, 4 position), followed by five little-endian
        float32 values: duration, thrust, pitch, roll, yaw (bytes 1-20).

        Nothing is applied when the setpoint handler has no commander.

        :raises ValueError: the payload is malformed or the mode is unknown.
        :raises NotImplementedError: position mode (4) was requested.
        """
        payload = command['data']
        mode = payload[0]
        try:
            duration, thrust, pitch, roll, yaw = [
                self._unpack_float(payload[start:start + 4])
                for start in range(1, 21, 4)
            ]
        except (struct.error, ValueError) as err:
            raise ValueError("malformed override payload") from err

        handler = self.setpoint_handler
        if not handler.commander:
            return

        # Validate the mode before telling the crazyflie to start flying, so
        # a bad command has no side effects.
        if mode == 4:
            raise NotImplementedError("position mode is not implemented")
        setter_name = self._OVERRIDE_MODE_SETTERS.get(mode)
        if setter_name is None:
            raise ValueError("unknown flight mode: {}".format(mode))

        # Lets crazyflie know we are about to start flying
        handler.commander.start_flying()
        getattr(handler, setter_name)()

        # A zero duration is bumped to 0.1 (units as used by setpoint_handler;
        # the original note says setpoints are sent every 20ms).
        if duration == 0.0:
            duration = 0.1
        current = handler.setpoint
        if (current.pitch, current.yaw, current.roll, current.thrust) != (pitch, yaw, roll, thrust):
            handler.setSetpoint(yaw, pitch, roll, thrust)
        handler.setpoint_time = duration
        handler.curr_time = 0

    def GetNodeIds(self, *args, **kwargs):
        """Placeholder; not implemented."""
        raise NotImplementedError("GetNodeIds")

    def SetParam(self, command):
        """ Set a crazyflie parameter value.

        ``command['data']``: bytes 0-1 node (group) id, bytes 2-3 little-endian
        parameter id used for the TOC lookup, bytes 4-7 little-endian float32
        value. Prints "Nothing connected" when there is no crazyflie object.
        """
        param_id = int.from_bytes(command['data'][2:4], 'little')
        value = self._unpack_float(command['data'][4:8])
        try:
            if self.scf.is_link_open():
                element = self.scf.cf.param.toc.get_element_by_id(param_id)
                full_name = element.group + "." + element.name
                self.scf.cf.param.set_value(full_name, value)
        except AttributeError:
            # self.scf is None when not connected (or the id is not in the TOC).
            print("Nothing connected")

    def GetParam(self, command, outputQueue: Queue):
        """ Retrieve a parameter value from the crazyflie TOC and queue a reply.

        ``command['data']``: bytes 0-1 node (group) id, bytes 2-3 little-endian
        parameter id. The reply echoes those 4 bytes followed by the value as a
        little-endian float32. If the value cannot be read (not connected, link
        closed, unknown id) ``_PARAM_ERROR_VALUE`` is sent instead.
        """
        print("Getting Param...")
        param_id = int.from_bytes(command['data'][2:4], 'little')

        # Initialised up front so a closed link cannot leave it unbound.
        actual = self._PARAM_ERROR_VALUE
        try:
            if self.scf.is_link_open():
                element = self.scf.cf.param.toc.get_element_by_id(param_id)
                full_name = element.group + "." + element.name
                actual = self.scf.cf.param.get_value(full_name)
        except AttributeError:
            # Not connected (self.scf is None) or id not in the TOC.
            actual = self._PARAM_ERROR_VALUE

        print(actual)
        data = bytearray(command['data'][0:4])
        data += struct.pack('<f', float(actual))
        responsedata = {
            "msg_type": (MessageTypeID.RESPPARAM_ID),
            "msg_id": command['msg_id'],
            "data_len": 8,
            "data": data,
        }
        outputQueue.put(responsedata)

    def SetSource(self, *args, **kwargs):
        """Placeholder; not implemented."""
        raise NotImplementedError("SetSource")

    def GetSource(self, *args, **kwargs):
        """Placeholder; not implemented."""
        raise NotImplementedError("GetSource")

    def RespSource(self, *args, **kwargs):
        """Placeholder; not implemented."""
        raise NotImplementedError("RespSource")

    def GetOutput(self, *args, **kwargs):
        """Placeholder; not implemented."""
        raise NotImplementedError("GetOutput")

    def GetNodes(self, *args, **kwargs):
        """Placeholder; not implemented."""
        raise NotImplementedError("GetNodes")

    def AddNode(self, *args, **kwargs):
        """Placeholder; not implemented."""
        raise NotImplementedError("AddNode")

    def get_logging_toc(self):
        """ Retrieve entire logging table of contents. Used in order to
        display list in logging tab. Returns [] when not connected. """
        try:
            if self.scf.is_link_open():
                return self.scf.cf.log.toc.toc
        except AttributeError:
            pass  # self.scf is None when not connected
        return []

    def get_param_toc(self):
        """ Retrieve entire param table of contents. Used in order to
        display params. Returns {} when not connected. """
        try:
            if self.scf.is_link_open():
                return self.scf.cf.param.toc.toc
        except AttributeError:
            pass  # self.scf is None when not connected
        return {}

    @staticmethod
    def _queue_logfile_response(command, text, outputQueue: Queue):
        """Put a RESPLOGFILE reply carrying ``text`` (UTF-8) on ``outputQueue``."""
        data = bytearray(text.encode('utf-8'))
        outputQueue.put({
            "msg_type": (MessageTypeID.RESPLOGFILE_ID),
            "msg_id": command['msg_id'],
            "data_len": len(data),
            "data": data,
        })

    def _log_header(self):
        """Build the '_<header_id>:,time,<var>,<var>,...' header line."""
        names = "".join(
            variable.name + ","
            for config in self.logging_configs
            for variable in config.variables
        )
        return "_" + str(self.logfile_handler.header_id) + ":,time," + names

    def GetLogFile(self, command, outputQueue: Queue):
        """Answer a log-file request; byte 0 of ``command['data']`` selects it.

        0: data log file name, 1: param TOC file, 2: logging TOC file,
        3: log header line, 4: test stand connection state (always "false").
        Every reply text is prefixed with "_".

        :raises ValueError: unknown request id.
        """
        print("Getting LogFile...")
        request_id = command['data'][0]
        if request_id == 0:
            text = "_" + self.logfile_handler.data_log_name
        elif request_id == 1:
            text = "_" + self.logfile_handler.CopyTocToFile(self.get_param_toc(), True)
        elif request_id == 2:
            text = "_" + self.logfile_handler.CopyTocToFile(self.get_logging_toc(), False)
        elif request_id == 3:
            text = self._log_header()
        elif request_id == 4:
            text = "_false"
        else:
            raise ValueError("unknown log file request id: {}".format(request_id))
        self._queue_logfile_response(command, text, outputQueue)

    def LogBlockCommand(self, command):
        """Manage logging blocks; byte 0 of ``command['data']`` selects the action.

        0: delete all blocks, 1: delete then reload active blocks,
        2: reload active blocks, 3: delete the block at index data[1],
        4/5: currently only echo the id, 8: start blocks and the writer
        thread, 9: stop the writer thread and blocks. Other ids are ignored.
        """
        print("Log Block Command")
        action = command['data'][0]
        if action == 0:
            self.delete_log_blocks()
        elif action == 1:
            self.delete_log_blocks()
            self.logging_configs = self.logfile_handler.read_all_active_blocks(self.scf)
        elif action == 2:
            self.logging_configs = self.logfile_handler.read_all_active_blocks(self.scf)
        elif action == 3:
            # Delete by position; list.remove() would drop the first *equal* entry.
            del self.logging_configs[command['data'][1]]
        elif action in (4, 5):
            print(action)
        elif action == 8:
            self.enable_logging()
            self.start_logging()
        elif action == 9:
            self.stop_logging()
            self.disable_logging()

    def enable_logging(self):
        """ Begins logging all configured logging blocks. This is used from
        the controls tab when hitting begin logging. """
        for config in self.logging_configs:
            config.start()

    def disable_logging(self):
        """ Stops logging all configured logging blocks. This is used from
        the controls tab when hitting pause logging. """
        for config in self.logging_configs:
            config.stop()

    def delete_log_blocks(self):
        """Forget all configured logging blocks."""
        self.logging_configs = []

    def start_logging(self):
        """Start the background thread that writes queued log data to file.

        A writer that is still running is stopped and joined first, so that at
        most one thread consumes the logging queue.
        """
        previous = self.logging_thread
        if previous is not None and previous.is_alive():
            self.stop_thread = True
            previous.join()
        self.stop_thread = False
        self.logging_thread = Thread(target=self.continous_log)
        self.logging_thread.start()

    def stop_logging(self):
        """Ask the writer thread to flush what is queued and exit."""
        self.stop_thread = True

    def continous_log(self):
        """Writer-thread body: write queued data points in fixed-size batches.

        Sleeps while fewer than ``_LOG_BATCH_SIZE`` points are queued instead
        of busy-waiting, and never blocks on the queue, so ``stop_logging()``
        takes effect promptly. On stop, any leftover points are written as a
        final (possibly smaller) batch.
        Assumes this thread is the only consumer of the logging queue —
        otherwise the size checks below could race.
        """
        log_queue = self.logfile_handler.logging_queue
        while not self.stop_thread:
            if log_queue.qsize() < self._LOG_BATCH_SIZE:
                time.sleep(self._LOG_POLL_INTERVAL)
                continue
            batch = [log_queue.get() for _ in range(self._LOG_BATCH_SIZE)]
            self.logfile_handler.write_data_points(batch)

        remaining = [log_queue.get() for _ in range(log_queue.qsize())]
        if remaining:
            self.logfile_handler.write_data_points(remaining)