#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <pthread.h>
#include <unistd.h>

#include "ccitt16.h"
#include "AddCongestion.h"

// global variables
// Receiver address; filled in by main() and passed to connect().
struct sockaddr_in remoteaddr;
// Stream (TCP) socket connected to the receiver; every packet is written here.
int clisock;
// Number of bytes allocated for text_buffer.
int TEXT_BUFFER_SIZE = 1200;
// Sequence number of the first packet. Packet N carries the two characters at
// text_buffer[(N - SEQUENCE_NUMBER_START) * 2] and the one after it.
int SEQUENCE_NUMBER_START = 1000;
// Text loaded from input.txt by main().
char * text_buffer;
// Number of bytes requested per read() of an ACK.
int BUFFER_SIZE = 2;
// Receiver IPv4 address string, taken from argv[1].
char * RECEIVER_IP;
// Bytes written per packet: 2 sequence-number bytes + 2 data bytes + 2 CRC bytes.
int PACKET_SIZE = 6;
// Slow-start threshold and congestion window, both counted in packets.
int ssthresh;
int cwnd;
// Oldest unacknowledged sequence number (updated from the latest new ACK).
unsigned short window_start;
// One past the highest sequence number sent so far, i.e. the next new packet.
unsigned short last_sent;
// Bit error rate handed to AddCongestion(); taken from argv[2] (default 0).
double p;
// Consecutive duplicate ACKs; the third one triggers a retransmission.
int dupAcks;
// New ACKs counted during congestion avoidance; cwnd grows by one after
// cwnd of them.
int nonDupAcks;
// Protects the congestion-window state shared between main, the per-packet
// timeout threads and the cwnd output thread.
pthread_mutex_t windowLock;
// output.txt: receives the BER header followed by "RTT cwnd" lines.
FILE * outputFile;
// Set to 1 by main once the transfer is finished, to stop output_cwnd().
// NOTE(review): written by main and read by the output thread without
// windowLock or an atomic type, which is formally a data race.
int done = 0;

// Seconds a timeout thread sleeps before checking whether its packet was ACKed.
int RTO_TIMEOUT = 3;
// Forward declaration: packet_timeout() calls this before its definition.
void packet_retransmit(unsigned short sequence_number);

// function that sends the packet of the given sequence_number
// Builds the frame for sequence number `pktnum` and writes it to clisock.
// Frame layout (PACKET_SIZE = 6 bytes are sent):
//   [0..1] pktnum, in host byte order
//   [2..3] the two characters of text_buffer this packet carries, starting at
//          offset (pktnum - SEQUENCE_NUMBER_START) * 2
//   [4..5] calculate_CCITT16() over bytes 0..3, with its two bytes swapped
//          relative to host order
// The frame is passed through AddCongestion() with the global error rate `p`
// before it is sent. Exits the process if write() fails.
void send_packet(unsigned short pktnum) {
  printf("Sending packet %hu\n", pktnum);
  const char * data_chars = text_buffer + (pktnum - SEQUENCE_NUMBER_START) * 2;
  // one spare byte beyond the 6-byte frame holds a terminating NUL
  unsigned char data_sequence[7];

  // memcpy rather than storing through (unsigned short *) / (short *) casts:
  // those stores were potentially misaligned and broke strict aliasing (UB)
  memcpy(data_sequence, &pktnum, sizeof pktnum);
  data_sequence[2] = (unsigned char) data_chars[0];
  data_sequence[3] = (unsigned char) data_chars[1];

  // calculate CRC over the sequence number and the two data bytes
  short crc_code = calculate_CCITT16(data_sequence, 4, GENERATE_CRC);
  memcpy(data_sequence + 4, &crc_code, sizeof crc_code);
  // null terminate the array
  // NOTE(review): presumably needed by AddCongestion(); only PACKET_SIZE bytes are sent
  data_sequence[6] = 0;

  // swap crc bytes to appease endianness
  unsigned char temp = data_sequence[4];
  data_sequence[4] = data_sequence[5];
  data_sequence[5] = temp;

  // introduce error
  AddCongestion(data_sequence, p);

  // write() on a stream socket may accept fewer bytes than requested, so keep
  // writing until the whole frame has been handed to the kernel
  size_t sent = 0;
  while (sent < (size_t) PACKET_SIZE) {
    ssize_t n = write(clisock, data_sequence + sent, (size_t) PACKET_SIZE - sent);
    if (n < 0) {
      perror("write() error:\n");
      exit(1);
    }
    sent += (size_t) n;
  }
}

// function run by threads after sending a packet - runs the Tahoe protocol
// Timer thread body started after every (re)transmission of a packet.
//   sequence_number_pointer: heap-allocated unsigned short holding the
//     sequence number being timed. This thread takes ownership and frees it
//     on every path.
// Sleeps RTO_TIMEOUT seconds. If window_start has not advanced past the
// sequence number by then (read under windowLock), the packet is
// retransmitted via packet_retransmit().
// Always returns NULL; the threads are detached, so the value is unused.
void * packet_timeout(void * sequence_number_pointer) {
  // wait for a timeout
  sleep(RTO_TIMEOUT);

  // copy the sequence number out so the allocation can be released on every
  // path, including the retransmit path that previously leaked it
  unsigned short sequence_number = *(unsigned short *) sequence_number_pointer;
  free(sequence_number_pointer);

  // check to see if the ACK has arrived
  pthread_mutex_lock(&windowLock);
  int acked = window_start > sequence_number;
  pthread_mutex_unlock(&windowLock);

  // otherwise, retransmit the packet
  if (!acked) {
    packet_retransmit(sequence_number);
  }
  return NULL;
}

// function retransmits the packet of the given sequence_number, then updates cwnd and ssthresh
// Retransmits the packet with the given sequence number and reacts to the
// loss Tahoe-style.
// Called from packet_timeout() on a timeout and from main() on the third
// duplicate ACK.
//   1. Resends it with send_packet().
//   2. Starts a detached packet_timeout() thread to time the resend. Ownership
//      of the heap copy of the sequence number passes to that thread.
//   3. Under windowLock, sets ssthresh = cwnd / 2 and resets cwnd to 1.
// Exits the process if the timer cannot be allocated or started.
void packet_retransmit(unsigned short sequence_number) {
  send_packet(sequence_number);

  // after sending, create a new thread to wait RTO_TIMEOUT seconds for that packet
  unsigned short * seqNum = malloc(sizeof *seqNum);
  if (seqNum == NULL) {
    perror("malloc() error");
    exit(1);
  }
  *seqNum = sequence_number;

  pthread_t timeout_thread;
  int rc = pthread_create(&timeout_thread, NULL, packet_timeout, seqNum);
  if (rc != 0) {
    // pthread functions return the error code instead of setting errno; on
    // failure timeout_thread is indeterminate and must not be detached
    fprintf(stderr, "pthread_create() error: %s\n", strerror(rc));
    free(seqNum);
    exit(1);
  }
  pthread_detach(timeout_thread);

  // update ssthresh and cwnd
  // NOTE(review): with cwnd == 1 this sets ssthresh to 0; RFC 5681 floors
  // ssthresh at 2 segments — confirm which rule the assignment expects.
  pthread_mutex_lock(&windowLock);
  ssthresh = cwnd / 2;
  cwnd = 1;
  pthread_mutex_unlock(&windowLock);
}

// function run by an output_thread - it prints the value of cwnd every RTT
// Body of the cwnd-logging thread started by main().
// Once per second (one second is treated as one RTT) it appends
// "<rtt> <cwnd>" to outputFile, reading cwnd under windowLock. It stops when
// `done` becomes nonzero and then closes outputFile.
//   arg: unused (main passes NULL)
//   returns: NULL
// The parameter list matches the pthread start-routine type
// void *(*)(void *); the old "()" declarator did not, and means "(void)" in C23.
// NOTE(review): `done` is read here without windowLock or an atomic type while
// main writes it, which is formally a data race.
void * output_cwnd(void * arg) {
  (void) arg;
  int RTT = 0;
  // done flag set after transmission is complete
  while (!done) {
    pthread_mutex_lock(&windowLock);
    fprintf(outputFile, "%d %d\n", RTT, cwnd);
    pthread_mutex_unlock(&windowLock);
    RTT++;
    // length of RTT is 1 second
    sleep(1);
  }
  // fclose flushes buffered log lines, so a failure here means lost output
  if (fclose(outputFile) != 0) {
    perror("fclose() error");
  }
  return NULL;
}

int main(int argc, char * argv[]) {
  
  FILE * fp;
  int i, j;
  char * packet;
  unsigned short sequence_number = SEQUENCE_NUMBER_START;
  char data_chars[3];
  unsigned short ackNum;
  dupAcks = 0;
  nonDupAcks = 0;
  window_start = SEQUENCE_NUMBER_START;
  last_sent = SEQUENCE_NUMBER_START;
  ssthresh = 16;
  cwnd = 1;
  
  // create mutex for windows information variables
  if (pthread_mutex_init(&windowLock, NULL) != 0) {
    perror("pthread_mutex_init() error:\n");
    exit(1);
  }
  
  // open output file
  if ((outputFile = fopen("output.txt", "w")) < 0) {
    perror("fopen() error:\n");
    exit(1);
  }
  // thread will be used to output cwnd
  pthread_t output_thread;
  
  // allocate space for a buffer to hold info from input.txt
  text_buffer = malloc(TEXT_BUFFER_SIZE * sizeof(char));
  char * readBuffer = malloc(BUFFER_SIZE * sizeof(char));
  
  // open file for reading
  if ((fp = fopen("input.txt", "r")) < 0 ) {
    perror("fopen() error:\n");
    free(text_buffer);
    free(readBuffer);
    exit(1);
  }
  // grab text from file
  while (fscanf(fp, "%s", text_buffer) != EOF) {}
  
  fclose(fp);
  
  // create the receiver address
  remoteaddr.sin_family = PF_INET;
  // random port number
  remoteaddr.sin_port = htons(54397);

  // report if no addresses are given
  if (argc == 1) {
    printf("No IP address arguments given.\n");
    free(text_buffer);
    free(readBuffer);
    exit(1);
  }
  
  // grab ip address from arguments    
  RECEIVER_IP = argv[1];
  // grab p value from arguments
  if (argc == 3) {
    p = atof(argv[2]);
  } else {
    p = 0;
  }
  
  printf("%f\n", p);
  // print value of p to file
  fprintf(outputFile, "Value of p (BER): %f\nRTT cwnd\n", p);
    
  // create a socket
  if ((clisock = socket(PF_INET, SOCK_STREAM, 0)) < 0) {
    perror("socket() error:\n");
    free(text_buffer);
    free(readBuffer);
    exit(1);
  }

  // set the receiver address
  remoteaddr.sin_addr.s_addr = inet_addr(RECEIVER_IP);

  // connect to the receiver
  if ((connect(clisock, (struct sockaddr *) &remoteaddr, sizeof(remoteaddr))) < 0) {
    printf("connect() error for address %s:\n", RECEIVER_IP);
    perror("");
    free(text_buffer);
    free(readBuffer);
    exit(1);
  }
  
  // start output thread
  pthread_create(&output_thread, NULL, &output_cwnd, NULL);
  
  // loop until the last packet is sent
  while(last_sent < SEQUENCE_NUMBER_START + (strlen(text_buffer)/2)) {
    
    pthread_mutex_lock(&windowLock);
    // send everything allowed in the window
    for (sequence_number = last_sent; sequence_number < window_start + cwnd; sequence_number++) {
      send_packet(sequence_number);
      
      // after sending, create a new thread to wait 3 seconds for that packet
      pthread_t timeout_thread;
      unsigned short * seqNum = malloc(sizeof(unsigned short));
      * seqNum = sequence_number;
      pthread_create(&timeout_thread, NULL, &packet_timeout, (void *) seqNum);
      pthread_detach(timeout_thread);
      
      last_sent++;
    }
    pthread_mutex_unlock(&windowLock);    

    // wait for an ACK
    if ((read(clisock, readBuffer, BUFFER_SIZE)) < 0) {
      perror("read() error");
      free(text_buffer);
      free(readBuffer);
      exit(1);
    }
    
    // save the ACK
    ackNum = *((unsigned short *) readBuffer);
    printf("ACK Received: %hu\n", *((unsigned short *) readBuffer));
    
    pthread_mutex_lock(&windowLock);
    // check for duplicate ACKs
    if (ackNum == window_start) {
      dupAcks++;
      nonDupAcks = 0;
    }
    // non-dup ack
    else {
      dupAcks = 0;
      window_start = ackNum;
      // SS phase
      if (cwnd < ssthresh) {
        cwnd++;
      }
      // CA phase
      else {
        nonDupAcks++;
        if (nonDupAcks == cwnd) {
          cwnd++;
          nonDupAcks = 0;
        }
      }
    }
    pthread_mutex_unlock(&windowLock);
    // if it is a 3rd dup ACK, retransmit
    if (dupAcks == 3) {
      unsigned short seqNum_dup = ackNum;
      packet_retransmit(seqNum_dup);
    }
    
  }
  
  // after all packets are initially sent, wait for receiver to ACK everything
  while(ackNum < last_sent) {
    // wait for an ACK
    if ((read(clisock, readBuffer, BUFFER_SIZE)) < 0) {
      perror("read() error");
      free(text_buffer);
      free(readBuffer);
      exit(1);
    }
    // save the ACK
    ackNum = *((unsigned short *) readBuffer);
    printf("ACK Received: %hu\n", *((unsigned short *) readBuffer));
  }
  
  // communication between receiver and sender ends here
  
  // close connection
  if ((close(clisock)) < 0) {
    perror("close() error:\n");
    free(text_buffer);
    free(readBuffer);
    exit(1);
  }
  
  free(text_buffer);
  free(readBuffer);
  
  // alert output file to end
  done = 1;
  
  // wait for output thread to close
  pthread_join(output_thread, NULL);
  
  pthread_mutex_destroy(&windowLock);
  
  return 0;
}