#!/usr/local/bin/python3.6

import sys
import random
from time import sleep
import struct

import serial

def create_msg(msg_type, msg_id, data):
    """Build one framed protocol message.

    Frame layout (multi-byte fields little-endian):
        0xBE | msg_type (2) | msg_id (2) | len(data) (2) | data | checksum (1)

    The checksum byte is the XOR of every byte before it, start byte included.

    :param msg_type: message type, must fit in an unsigned 16-bit field.
    :param msg_id: message id, must fit in an unsigned 16-bit field.
    :param data: payload bytes.
    :returns: the complete frame as ``bytes``.
    """
    frame = bytearray(b'\xBE')
    for field in (msg_type, msg_id, len(data)):
        frame += field.to_bytes(2, 'little')
    frame += data

    running_xor = 0
    for octet in frame:
        running_xor ^= octet
    frame.append(running_xor)
    return bytes(frame)

def create_test_packet(size=8):
    """Return a type-1, id-0 message whose payload is the byte ramp 0, 1, ..., 255, 0, ...

    :param size: payload length in bytes.
    """
    ramp = bytes(n & 0xFF for n in range(size))
    return create_msg(1, 0, ramp)

def read_packet(ser, raw=False):
    """Read one framed packet from *ser*.

    Reads a 7-byte header and takes the payload length from header bytes 5-6
    (little-endian, same position ``create_msg`` writes it). Then reads that
    many payload bytes and one trailing checksum byte. The checksum is not
    verified here. On a read timeout any of the pieces may come back short or
    empty.

    :param ser: object with a pyserial-style ``read(size=1)``.
    :param raw: when true, return header + payload + checksum instead of the payload.
    """
    head = ser.read(7)
    payload_len = int.from_bytes(head[5:7], byteorder='little')
    payload = ser.read(payload_len)
    trailer = ser.read()
    return head + payload + trailer if raw else payload

def parse_query_response(data):
    """Decode a comma-separated list of integers from a query reply payload.

    The final byte of *data* is dropped before decoding. It is presumably a
    terminator; confirm against the firmware. An empty result is treated as a
    timeout: it prints "Timed out" and returns ``(-1, -1)``.
    """
    text = data[:-1].decode('ascii')
    if not text:
        print("Timed out")
        return (-1, -1)
    return tuple(int(field) for field in text.split(','))

def query_received(ser):
    """Ask the device for its receive counters and return the parsed reply.

    Sends a type-2 message with an empty payload, waits 0.1 s, then reads and
    parses one reply packet. The result is whatever ``parse_query_response``
    yields: ``(-1, -1)`` when nothing came back.
    """
    ser.write(create_msg(2, 0, b''))
    ser.flush()
    # Give the device a moment to answer before blocking on the read.
    sleep(0.1)
    return parse_query_response(read_packet(ser))

def check_test(ser, n_sent, size_sent, old_status):
    """Compare fresh receive counters against the expected deltas.

    :param n_sent: number of valid messages sent since *old_status* was taken.
    :param size_sent: total payload bytes in those messages.
    :param old_status: ``(messages, data)`` counters from before the test.
    :returns: ``(ok, new_status)``. Each mismatching counter is printed.
    """
    new_n, new_size = query_received(ser)

    expected_n = old_status[0] + n_sent
    n_ok = expected_n == new_n
    if not n_ok:
        print(f"Failure: Expected {expected_n} messages, responded with {new_n}")

    expected_size = old_status[1] + size_sent
    size_ok = expected_size == new_size
    if not size_ok:
        print(f"Failure: Expected {expected_size} data, responded with {new_size}")

    return n_ok and size_ok, (new_n, new_size)

def test_checksum(ser):
    """Verify that a query reply carries a correct XOR checksum.

    Sends a type-2 query and reads the raw reply frame. Checks that the frame
    is complete: 7-byte header + declared payload length + 1 checksum byte.
    Checks that its last byte equals the XOR of every preceding byte, which is
    the same rule ``create_msg`` uses. Afterwards re-queries the receive
    counters so the caller gets a fresh status.

    :returns: ``(valid, status)``.
    """
    print("Checking if received checksum is valid")
    # Send query packet (flushed, like query_received does)
    ser.write(create_msg(2, 0, b''))
    ser.flush()
    raw_data = read_packet(ser, True)

    # A timed-out or truncated reply used to raise IndexError on raw_data[-1]
    # or silently check the wrong byte; report it as a failure instead.
    declared_len = int.from_bytes(raw_data[5:7], byteorder='little')
    if len(raw_data) != 7 + declared_len + 1:
        print("Failure: Incomplete response while checking checksum, got "
              + str(len(raw_data)) + " bytes")
        valid = False
    else:
        given_checksum = raw_data[-1]
        computed_checksum = 0
        for b in raw_data[:-1]:
            computed_checksum ^= b
        valid = computed_checksum == given_checksum
        if not valid:
            print("Failure: Received checksum " + str(given_checksum)
                  + ", expected " + str(computed_checksum))

    ret_status = query_received(ser)
    return valid, ret_status

def test_partial_packet(ser, cur_status, size=50, delay=0.2):
    """Send one test packet in two halves and check it is counted exactly once.

    :param cur_status: ``(messages, data)`` counters from before this test.
    :param size: payload length of the test packet.
    :param delay: seconds to pause between the halves and before checking.
    :returns: ``(ok, new_status)`` from ``check_test``.
    """
    print("Checking partial packet...")
    to_send = create_test_packet(size)
    cutpoint = len(to_send) // 2
    ser.write(to_send[:cutpoint])
    ser.flush()
    # Pause mid-packet so the second half arrives noticeably later than the first.
    sleep(delay)
    ser.write(to_send[cutpoint:])
    ser.flush()
    sleep(delay)
    return check_test(ser, 1, size, cur_status)

def test_blast(ser, cur_status, N=50, size=32):
    """Send *N* identical test packets back to back and check they are all counted.

    :param cur_status: ``(messages, data)`` counters from before this test.
    :param N: number of packets to send.
    :param size: payload length of each packet.
    :returns: ``(ok, new_status)`` from ``check_test``.
    """
    # The old message said "of size N", but N is a packet count.
    print("Checking data blast of " + str(N) + " packets with size " + str(size))
    to_send = create_test_packet(size)
    # One write of the concatenated packets puts the same bytes on the wire
    # as N separate writes.
    ser.write(to_send * N)
    ser.flush()
    sleep(0.5)
    return check_test(ser, N, N * size, cur_status)

def test_bad_checksum(ser, cur_status, size=30):
    """Send a test packet with a corrupted checksum and expect it to be ignored.

    The final (checksum) byte is incremented modulo 256. The check then
    expects the counters to be unchanged: zero messages and zero data.
    """
    print("Checking bad checksum...")
    good = create_test_packet(size)
    corrupted = good[:-1] + bytes([(good[-1] + 1) % 256])
    ser.write(corrupted)
    ser.flush()
    return check_test(ser, 0, 0, cur_status)

def test_get_set(ser, n_controllers=9, n_constants=4):
    """Round-trip random float constants through the set (7) and get (8) commands.

    For every (controller, constant) id pair:
    - send a set command carrying a random little-endian float32;
    - immediately send a get for the same pair;
    - check the reply echoes both ids and a value within 1e-5 of what was set.

    :param n_controllers: controller ids 0..n_controllers-1 are exercised.
    :param n_constants: constant ids 0..n_constants-1 are exercised per controller.
    :returns: ``(passed, status)``, where status is a fresh ``query_received`` result.
    """
    print("Checking Get/Set Commands...")
    passed = True
    for cntl_id in range(n_controllers):
        for const_id in range(n_constants):
            to_set = random.random()
            ids = bytes([cntl_id, const_id])
            # Set command payload: controller id, constant id, little-endian float32
            ser.write(create_msg(7, 0, ids + struct.pack("<f", to_set)))
            # Get command payload: just the two ids
            ser.write(create_msg(8, 0, ids))
            ser.flush()
            # Get the just-set value
            resp = read_packet(ser)
            if len(resp) < 6:
                # Timed out or truncated reply: indexing it would raise.
                print("Failed get/set test. Expected controller " + str(cntl_id)
                      + ", constant " + str(const_id) + ", value " + str(to_set))
                print("    Received truncated response of " + str(len(resp)) + " bytes")
                passed = False
                continue
            resp_cntl = resp[0]
            resp_cnst = resp[1]
            # Decode with the same explicit byte order used to encode the value;
            # native 'f' would be wrong on a big-endian host.
            resp_val = struct.unpack('<f', resp[2:6])[0]

            if resp_cntl != cntl_id or resp_cnst != const_id or abs(resp_val - to_set) > 1e-5:
                print("Failed get/set test. Expected controller " + str(cntl_id)
                      + ", constant " + str(const_id) + ", value " + str(to_set))
                print("    Received " + str(resp_cntl) + ", constant " + str(resp_cnst)
                      + ", value " + str(resp_val))
                passed = False

    ret_status = query_received(ser)
    return passed, ret_status

if __name__ == '__main__':
    # Optional first argument overrides the default serial port.
    port = sys.argv[1] if len(sys.argv) > 1 else '/dev/ttyUSB0'
    with serial.Serial(port, 921600, timeout=5) as ser:
        # Discard anything already buffered before the tests start.
        ser.reset_input_buffer()
        while ser.in_waiting != 0:
            ser.read(ser.in_waiting)
        status = query_received(ser)
        if status == (-1, -1):
            print("Device did not answer the initial query; aborting")
            sys.exit(1)

        # Each test's pass/fail was previously overwritten and never reported.
        results = []
        passed, status = test_partial_packet(ser, status)
        results.append(("partial packet", passed))
        passed, status = test_blast(ser, status)
        results.append(("blast 50x32", passed))
        passed, status = test_blast(ser, status, 150, 80)
        results.append(("blast 150x80", passed))
        passed, status = test_bad_checksum(ser, status)
        results.append(("bad checksum", passed))
        passed, status = test_checksum(ser)
        results.append(("response checksum", passed))
        passed, status = test_get_set(ser)
        results.append(("get/set", passed))

    failed = [name for name, ok in results if not ok]
    if failed:
        print("FAILED: " + ", ".join(failed))
        sys.exit(1)
    print("All tests passed")