#include "node_pid.h"
#include <stdlib.h>
#include <math.h>

// Tolerance below which Ki is treated as zero (integral term disabled) in pid_computation().
// Declared const: it is only ever read, never written.
// NOTE(review): this name shadows the FLT_EPSILON macro from <float.h>; including that header
// in this file would break this declaration. Consider renaming (e.g. PID_KI_EPSILON).
static const double FLT_EPSILON = 0.0001;


/* Per-node controller memory carried from one pid_computation() call to the next. */
struct pid_node_state {
    /* Error from the previous step; used to form the change in error for the D term. */
    double prev_error;
    /* Running sum of errors; scaled by Ki and dt to form the I term. */
    double acc_error;
    /* D term produced on the previous step; fed back into the derivative low-pass filter. */
    double last_filtered;
};

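// The generic PID diagram. pid_computation() takes the PID inputs (CUR_POINT and SETPOINT) and calculates
// the "pid_correction" output from them. Note: the D term shown below is additionally low-pass filtered
// (weighted by the "alpha" parameter) before it is summed.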
//
//              +  ---    error    ------------------		P		+  ---			  ----------------------------
// setpoint ---> / sum \ --------->| Kp * error	    |--------------->/ sum \ -------->| output: "pid_correction" |
//				 \     /	|	   ------------------				 \	   /		  ----------------------------
//				   ---		|										   ---					||
//                -	^       |	   									+ ^  ^ +				||
//					|		|	   -------------------------------	  |	 |			------- \/------------
//					|		|----->| Ki * accumulated error * dt |----+	 |			|					 |
//					|		|	   -------------------------------	I	 |			|		SYSTEM		 |
//					|		|											 |			|					 |
//					|		|											 |			--------||------------
//					|		|											 |					||
//					|		|      ----------------------------------	 |					||
//					|		|----->| Kd * (error - last error) / dt |----+					||
//				    |			   ----------------------------------  D					||
//					|																		||
//					|															 -----------\/-----------
//					|____________________________________________________________| Sensor measurements: |
//																				 |	 "current point"	|
//																				 ------------------------
//
static void pid_computation(void *state, const double* params, const double *inputs, double *outputs) {
    struct pid_node_state* pid_state = (struct pid_node_state*)state;

    double P = 0.0, I = 0.0, D = 0.0;

    // calculate the current error
    double error = inputs[PID_SETPOINT] - inputs[PID_CUR_POINT];

    // Accumulate the error (if Ki is less than epsilon, rougly 0,
    // then reset the accumulated error for safety)
    if (fabs(params[PID_KI]) <= FLT_EPSILON) {
      pid_state->acc_error = 0;
    } else {
      pid_state->acc_error += error;
    }


    // Compute each term's contribution
    P = params[PID_KP] * error;
    I = params[PID_KI] * pid_state->acc_error * inputs[PID_DT];
    // Low-pass filter on derivative
    double change_in_error = error - pid_state->prev_error;
    double term1 = params[PID_ALPHA] * pid_state->last_filtered;
    double derivative = change_in_error / inputs[PID_DT];
    if (inputs[PID_DT] == 0) { // Divide by zero check
        derivative = 0;
    }
    double term2 = params[PID_KD] * (1.0f - params[PID_ALPHA]) * derivative;
    D = term1 + term2;
    pid_state->last_filtered = D; // Store filtered value for next filter iteration
    pid_state->prev_error = error; // Store the current error into the state

    outputs[PID_CORRECTION] = P + I + D; // Store the computed correction
}

// This function sets the accumulated error and previous error to 0
// to prevent previous errors from affecting output after a reset
static void reset_pid(void *state) {
    struct pid_node_state* pid_state = (struct pid_node_state*)state;
    pid_state->acc_error = 0;
    pid_state->prev_error = 0;
    pid_state->last_filtered = 0;
}


// Labels for the PID node's inputs, outputs and parameters.
// NOTE(review): each label's position is assumed to match the corresponding PID_* index
// used in pid_computation() — confirm against the index definitions in node_pid.h.
static const char* const in_names[] = {"Cur point", "Setpoint", "dt"};
static const char* const out_names[] = {"Correction"};
static const char* const param_names[] = {"Kp", "Ki", "Kd", "alpha"};

// Node type descriptor for the PID block. The counts are derived from the label arrays
// above, so adding or removing a label cannot leave a stale hard-coded count behind.
const struct graph_node_type node_pid_type = {
        .input_names = in_names,
        .output_names = out_names,
        .param_names = param_names,
        .n_inputs = sizeof in_names / sizeof in_names[0],
        .n_outputs = sizeof out_names / sizeof out_names[0],
        .n_params = sizeof param_names / sizeof param_names[0],
        .execute = pid_computation,
        .reset = reset_pid,
        .state_size = sizeof(struct pid_node_state),
        .type_id = BLOCK_PID
};

// Adds a PID node named `name` to `graph`, giving it freshly zeroed controller state.
//
// Returns -1 if the state cannot be allocated. Otherwise it returns whatever
// graph_add_node() returns (presumably the new node's id or an error code — confirm
// against the graph API).
// NOTE(review): it is not visible here whether graph_add_node() takes ownership of
// node_state when it fails; if it does not, the allocation leaks on that path.
int graph_add_node_pid(struct computation_graph *graph, const char* name) {
    // calloc (rather than malloc) zero-initialises prev_error, acc_error and
    // last_filtered. The first pid_computation() call therefore never reads
    // indeterminate history, whether or not reset_pid() runs first.
    // (All-bits-zero is 0.0 for IEEE 754 doubles.)
    struct pid_node_state *node_state = calloc(1, sizeof *node_state);
    if (!node_state) {
        return -1; // allocation failed
    }
    return graph_add_node(graph, name, &node_pid_type, node_state);
}