import socket
import time
import uart_stress_tests

# Address and port of the TCP endpoint this script connects to.
TCP_IP = "192.168.1.1"
TCP_PORT = 8080

# Minimum spacing between consecutive sends, in seconds (compared against
# time.perf_counter() differences in the send loop).
send_delay = 0.001
# Number of test packets sent during one run.
TEST_ITERS = 100000
# Progress is printed every REPORT_STEP sends (roughly ten reports per run).
# Clamped to 1 so that `i % REPORT_STEP` cannot divide by zero when
# TEST_ITERS is set below 10.
REPORT_STEP = max(1, TEST_ITERS // 10)
# Size handed to uart_stress_tests.create_test_packet(); also used as the
# per-message byte count when totalling the expected payload.
payload_size = 28

# Bytes already received from the socket but not yet handed out by read_n().
received = bytearray()


def read_n(n_to_read):
    """Return exactly ``n_to_read`` bytes read from the module-level ``sock``.

    Data is pulled from the socket in chunks of up to 1024 bytes and kept in
    the module-level ``received`` buffer; any surplus stays buffered for the
    next call.

    Args:
        n_to_read: Number of bytes to return. ``0`` returns ``b''`` without
            touching the socket.

    Returns:
        A ``bytes`` object of length ``n_to_read``.

    Raises:
        SystemExit: If the peer closes the connection before enough bytes
            have arrived ("Socket broken" is printed first).
        OSError: Any non-retryable socket error is propagated instead of
            being silently retried forever.
    """
    while len(received) < n_to_read:
        try:
            just_received = sock.recv(1024)
        except (socket.timeout, BlockingIOError, InterruptedError):
            # Nothing available yet: keep polling. Only these "try again"
            # conditions are retried; a bare `except` here would also swallow
            # SystemExit/KeyboardInterrupt and make the loop unstoppable.
            continue
        if not just_received:
            # recv() returning no bytes means the peer closed the connection.
            print("Socket broken")
            raise SystemExit(1)
        received.extend(just_received)
    to_ret = bytes(received[:n_to_read])
    # Drop the consumed prefix in place so the remainder stays buffered.
    del received[:n_to_read]
    return to_ret

def read_packet(raw=False):
    """Read one framed packet using read_n().

    The frame is read as a 7-byte header, then a payload whose length is the
    little-endian integer in header bytes 5-6, then a single trailing byte.
    The trailing byte (named "checksum" in the original code) is consumed but
    not verified here.

    Args:
        raw: When true, return the whole frame (header + payload + trailing
            byte); otherwise return only the payload.

    Returns:
        The payload bytes, or the complete frame when ``raw`` is true.
    """
    hdr = read_n(7)
    payload_len = int.from_bytes(hdr[5:7], byteorder='little')
    payload = read_n(payload_len)
    trailer = read_n(1)
    if not raw:
        return payload
    return hdr + payload + trailer

def get_quad_status():
    """Query the device for its status and return the parsed response.

    Builds a query with ``uart_stress_tests.create_msg(2, 0, b'')``, sends it
    over the module-level ``sock``, reads one packet back and hands its
    payload to ``uart_stress_tests.parse_query_response``.

    Returns:
        Whatever ``parse_query_response`` returns. The caller in this file
        indexes it as ``[0]`` (message count) and ``[1]`` (payload bytes) --
        NOTE(review): confirm against uart_stress_tests.
    """
    print("Getting quad status...")
    query_msg = uart_stress_tests.create_msg(2, 0, b'')
    # sendall() keeps writing until the whole message is out; plain send()
    # may stop after a partial write.
    sock.sendall(query_msg)
    resp = read_packet()
    status = uart_stress_tests.parse_query_response(resp)
    # Must be printed before returning; it used to sit after the `return`
    # and was never reached.
    print("Got quad status")
    return status


if __name__ == '__main__':
    # Stress test: record the device's counters, send TEST_ITERS packets
    # spaced at least `send_delay` seconds apart, then read the counters
    # again and report how many messages/bytes went missing.
    print("Connecting...")
    # `sock` is intentionally a module-level name: read_n() and
    # get_quad_status() look it up as a global. The `with` block guarantees
    # the socket is closed on every exit path.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Bound the connection attempt to 2 seconds.
        sock.settimeout(2)
        try:
            sock.connect((TCP_IP, TCP_PORT))
        except OSError:
            # Covers refusal, unreachable host and timeout, without also
            # trapping KeyboardInterrupt the way a bare `except` would.
            print("Failed to connect")
            raise SystemExit(1)
        # Back to fully blocking I/O (no timeout) for the test itself.
        sock.setblocking(True)
        print("connected")

        initial_status = get_quad_status()

        message = uart_stress_tests.create_test_packet(payload_size)
        overall_start_time = time.perf_counter()
        for i in range(TEST_ITERS):
            if i % REPORT_STEP == 0:
                print("Sent {} messages".format(i))
            start_time = time.perf_counter()
            # sendall() so a partial write cannot truncate a packet and
            # corrupt the framing of everything after it.
            sock.sendall(message)
            # Pace the sends by busy-waiting rather than time.sleep();
            # presumably sleep() was too coarse for a ~1 ms interval
            # (NOTE(review): confirm).
            while time.perf_counter() - start_time <= send_delay:
                pass
        # Reporting
        avg_time = (time.perf_counter() - overall_start_time) / TEST_ITERS
        print("Average send time was {}".format(avg_time))

        after_status = get_quad_status()
        # Index 0 is used as the message counter and index 1 as the payload
        # byte counter of the status tuple.
        diff_messages = after_status[0] - initial_status[0]
        diff_payload = after_status[1] - initial_status[1]
        print("Sent {} messages, total payload of {}".format(TEST_ITERS, TEST_ITERS*payload_size))
        print("Recv {} messages, total payload of {}".format(diff_messages, diff_payload))
        print("Lost {} messages, {} bytes".format(TEST_ITERS - diff_messages, (TEST_ITERS*payload_size) - diff_payload))