#include "test.h"


#include "computation_graph.h"
#include "graph_blocks/node_add.h"
#include "graph_blocks/node_mult.h"
#include "graph_blocks/node_constant.h"
#include "graph_blocks/node_gain.h"
#include "graph_blocks/node_accumulator.h"

#define GRAPH_TEST_EPS 0.00001

/*
 * Approximate floating-point comparison shared by the tests in this file.
 * Returns 0 when |val1 - val2| is strictly below GRAPH_TEST_EPS and -1
 * otherwise, so the result can be returned directly as a test status
 * (0 = pass).
 */
static int nequal(double val1, double val2) {
    const double diff = fabs(val1 - val2);
    return (diff < GRAPH_TEST_EPS) ? 0 : -1;
}

/* Builds 3 + 4 from two constant nodes feeding an add node; the sum must be 7. */
int graph_test_one_add() {
    struct computation_graph *g = create_graph();

    int sum_node = graph_add_node_add(g, "Add");
    int three = graph_add_node_const(g, "3");
    graph_set_param_val(g, three, CONST_SET, 3);
    int four = graph_add_node_const(g, "4");
    graph_set_param_val(g, four, CONST_SET, 4);

    graph_set_source(g, sum_node, ADD_SUMMAND1, three, CONST_VAL);
    graph_set_source(g, sum_node, ADD_SUMMAND2, four, CONST_VAL);

    int targets[] = {sum_node};
    graph_compute_nodes(g, targets, 1);

    return nequal(graph_get_output(g, sum_node, ADD_SUM), 7);
}


/* Checks that computing a node inside a two-node cycle terminates. */
int graph_test_circular_runs() {
    struct computation_graph *g = create_graph();
    int first = graph_add_node_gain(g, "gain1");
    int second = graph_add_node_gain(g, "gain2");

    /* Wire the gains into a loop: first -> second -> first. */
    graph_set_source(g, second, GAIN_INPUT, first, GAIN_RESULT);
    graph_set_source(g, first, GAIN_INPUT, second, GAIN_RESULT);

    int targets[] = {second};
    graph_compute_nodes(g, targets, 1);

    /* Reaching this line means the compute terminated. The output value of a
     * cyclic graph is undefined, so it is deliberately not checked. */
    return 0;
}

/*
 * Wires two accumulators into a cycle without computing anything. The test
 * passes as long as both graph_set_source calls return. Presumably,
 * graph_set_source propagates a reset through the cycle, and that
 * propagation must not loop forever.
 */
int graph_test_circular_resets() {
    struct computation_graph *g = create_graph();
    int left = graph_add_node_accum(g, "accumulator1");
    int right = graph_add_node_accum(g, "accumulator2");
    graph_set_source(g, right, ACCUM_IN, left, ACCUMULATED);
    graph_set_source(g, left, ACCUM_IN, right, ACCUMULATED);
    return 0;
}

/*
 * Tests the accumulator block, and with it state handling and resets.
 *
 * Step 1: feed the accumulator 3, 8 and -2 over three computes. Its output
 *         must be the running total, 9.
 * Step 2: connect a unity gain to the accumulator's output. Setting a new
 *         source resets the accumulator, so the next compute must yield only
 *         the current input, -2.
 *
 * Returns 0 on success, -1 if step 1 fails, -2 if step 2 fails.
 */
int graph_test_accumulator() {
    struct computation_graph *graph = create_graph();
    int cblock = graph_add_node_const(graph, "const");
    int acum_b = graph_add_node_accum(graph, "accumulator");
    graph_set_source(graph, acum_b, ACCUM_IN, cblock, CONST_VAL);

    int to_compute_for[1] = {acum_b};
    const double inputs[] = {3, 8, -2};
    for (size_t i = 0; i < sizeof inputs / sizeof inputs[0]; i++) {
        graph_set_param_val(graph, cblock, CONST_SET, inputs[i]);
        graph_compute_nodes(graph, to_compute_for, 1);
    }

    double result = graph_get_output(graph, acum_b, ACCUMULATED);
    if (nequal(result, 9)) {
        printf("graph_test_accumulator failed on step 1, equals %f\n", result);
        return -1;
    }

    // Test reset on source set: connecting gain_b should reset acum_b.
    int gain_b = graph_add_node_gain(graph, "Gain");
    graph_set_param_val(graph, gain_b, GAIN_GAIN, 1);
    graph_set_source(graph, gain_b, GAIN_INPUT, acum_b, ACCUMULATED);
    to_compute_for[0] = gain_b;
    graph_compute_nodes(graph, to_compute_for, 1);
    result = graph_get_output(graph, gain_b, GAIN_RESULT);
    if (nequal(result, -2)) {
        // Report the observed value, as step 1 does, to aid debugging.
        printf("graph_test_accumulator failed on step 2, equals %f\n", result);
        return -2;
    }
    return 0;
}

/*
 * Checks that a node executes at most once per graph_compute_nodes call,
 * even when its output feeds several inputs. An accumulator fed the constant
 * 2 drives both summands of an add node. If the accumulator ran more than
 * once, the sum would not be 4.
 */
int graph_test_single_run() {
    struct computation_graph *g = create_graph();
    int accum = graph_add_node_accum(g, "accumulator");
    int adder = graph_add_node_add(g, "Add");
    int two = graph_add_node_const(g, "const");
    graph_set_param_val(g, two, CONST_SET, 2);

    graph_set_source(g, accum, ACCUM_IN, two, CONST_VAL);
    /* Fan the accumulator's single output out to both summands. */
    graph_set_source(g, adder, ADD_SUMMAND1, accum, ACCUMULATED);
    graph_set_source(g, adder, ADD_SUMMAND2, accum, ACCUMULATED);

    int targets[] = {adder};
    graph_compute_nodes(g, targets, 1);
    double sum = graph_get_output(g, adder, ADD_SUM);
    return nequal(sum, 4);
}

/*
 * Checks that attaching a second consumer to an already-connected node does
 * not reset it.
 *
 * An accumulator fed the constant 5 is computed once through gain1, leaving
 * its state at 5. A second unity gain is then attached and computed. The
 * result must be 5, not 10: the accumulator's inputs did not change, so it
 * must not run again.
 */
int graph_test_reset_rules() {
    struct computation_graph *g = create_graph();
    int five = graph_add_node_const(g, "5");
    graph_set_param_val(g, five, CONST_SET, 5);
    int accum = graph_add_node_accum(g, "accumulator");
    int first_gain = graph_add_node_gain(g, "gain1");
    graph_set_param_val(g, first_gain, GAIN_GAIN, 1);

    graph_set_source(g, first_gain, GAIN_INPUT, accum, ACCUMULATED);
    graph_set_source(g, accum, ACCUM_IN, five, CONST_VAL);
    int targets[1] = {first_gain};
    graph_compute_nodes(g, targets, 1);
    /* The accumulator now holds 5. */

    int second_gain = graph_add_node_gain(g, "gain2");
    graph_set_param_val(g, second_gain, GAIN_GAIN, 1);
    /* Connecting the second gain must not reset the accumulator. */
    graph_set_source(g, second_gain, GAIN_INPUT, accum, ACCUMULATED);
    targets[0] = second_gain;
    graph_compute_nodes(g, targets, 1);

    return nequal(graph_get_output(g, second_gain, GAIN_RESULT), 5);
}

/* Checks that computing a node whose input is its own output terminates. */
int graph_test_self_loop() {
    struct computation_graph *g = create_graph();
    int loop_gain = graph_add_node_gain(g, "gain1");
    graph_set_source(g, loop_gain, GAIN_INPUT, loop_gain, GAIN_RESULT);

    int targets[] = {loop_gain};
    graph_compute_nodes(g, targets, 1);

    /* Termination is the only success criterion. */
    return 0;
}

/*
 * Checks that a node only recomputes when its inputs change. Computing an
 * accumulator twice with the same constant input 3 must leave it at 3, not 6.
 */
int graph_test_update_rules() {
    struct computation_graph *g = create_graph();
    int src = graph_add_node_const(g, "const");
    int accum = graph_add_node_accum(g, "accumulator");
    graph_set_source(g, accum, ACCUM_IN, src, CONST_VAL);
    graph_set_param_val(g, src, CONST_SET, 3);

    int targets[1] = {accum};
    for (int pass = 0; pass < 2; pass++) {
        graph_compute_nodes(g, targets, 1);
    }

    return nequal(graph_get_output(g, accum, ACCUMULATED), 3);
}

/*
C1 --->| accum_b1 --->|
                      | Add --->
C2 --->| accum_b2 --->|

Checks that a parameter change only re-runs the nodes downstream of it.
The first compute, with C1 = 5 and C2 = 2, must give 7. C1 is then changed
to 1 and the graph is recomputed. accum_b1 runs again (5 + 1 = 6), while
accum_b2, whose input did not change, keeps 2, so the sum must be 8.
Returns 0 on success and 1 on failure.
*/
int graph_test_update_propagation() {
    struct computation_graph *g = create_graph();
    int c1 = graph_add_node_const(g, "const1");
    int c2 = graph_add_node_const(g, "const2");
    int acc1 = graph_add_node_accum(g, "accumulator1");
    int acc2 = graph_add_node_accum(g, "accumulator2");
    int adder = graph_add_node_add(g, "add");
    graph_set_source(g, acc1, ACCUM_IN, c1, CONST_VAL);
    graph_set_source(g, acc2, ACCUM_IN, c2, CONST_VAL);
    graph_set_source(g, adder, ADD_SUMMAND1, acc1, ACCUMULATED);
    graph_set_source(g, adder, ADD_SUMMAND2, acc2, ACCUMULATED);
    graph_set_param_val(g, c1, CONST_SET, 5);
    graph_set_param_val(g, c2, CONST_SET, 2);

    int targets[] = {adder};

    graph_compute_nodes(g, targets, 1);
    double first_sum = graph_get_output(g, adder, ADD_SUM);

    graph_set_param_val(g, c1, CONST_SET, 1);
    graph_compute_nodes(g, targets, 1);
    double second_sum = graph_get_output(g, adder, ADD_SUM);

    int first_ok = (nequal(first_sum, 7) == 0);
    int second_ok = (nequal(second_sum, 8) == 0);
    return !(first_ok && second_ok);
}

/*
 * Regression test for a subtle edge case in update tracking.
 *
 * A node that is not on the computation path is not computed, even if it
 * has been marked "updated". After a compute, nodes have their "updated"
 * flag cleared. An earlier implementation cleared the flag on every node,
 * including nodes that were never processed. It now clears the flag only on
 * nodes that were actually processed.
 *
 * The old behavior caused a bug. A node whose output fed two consumers was
 * not marked "updated" by the reset propagation when it was later connected
 * to the computation path. Its original "updated" flag had already been
 * cleared, so its value would never be set.
 *
 * Setup: d_block (a constant) feeds only gain2_block while gain_block (x2)
 * is computed from const_b. d_block's value is set while it is off the
 * computation path. It is then wired into gain_block, and the result must be
 * 2 * d_value.
 */
int graph_test_update_disconnected() {
    const double d_value = 1.2345;

    struct computation_graph *graph = create_graph();
    int d_block = graph_add_node_const(graph, "const1");
    int gain_block = graph_add_node_gain(graph, "gain");
    int gain2_block = graph_add_node_gain(graph, "gain2");
    int const_b = graph_add_node_const(graph, "const2");
    graph_set_source(graph, gain_block, GAIN_INPUT, const_b, CONST_VAL);
    // Give d_block another consumer so that connecting it to gain_block later
    // does not mark it "updated" (the case under test).
    graph_set_source(graph, gain2_block, GAIN_INPUT, d_block, CONST_VAL);
    graph_set_param_val(graph, gain_block, GAIN_GAIN, 2);
    // Set while d_block is off the computation path; the value must still
    // reach gain_block once d_block is connected.
    graph_set_param_val(graph, d_block, CONST_SET, d_value);

    int to_compute_for[] = {gain_block};
    graph_compute_nodes(graph, to_compute_for, 1);
    graph_set_source(graph, gain_block, GAIN_INPUT, d_block, CONST_VAL);
    graph_compute_nodes(graph, to_compute_for, 1);

    double set_val = graph_get_output(graph, gain_block, GAIN_RESULT);
    return nequal(set_val, 2 * d_value);
}

/* Runs every graph test through the test harness, then prints the summary. */
int main(void) {
    test(graph_test_one_add, "Test adding 2 numbers");
    test(graph_test_circular_runs, "Test computing cycles");
    test(graph_test_circular_resets, "Test resetting cycles");
    test(graph_test_accumulator, "Test accumulator (state)");
    test(graph_test_single_run, "Test that blocks only get executed once");
    test(graph_test_reset_rules, "Tests that already connected blocks don't get reset");
    test(graph_test_self_loop, "Tests that a self-loop computation terminates");
    test(graph_test_update_rules, "Tests that nodes only update when their inputs change");
    test(graph_test_update_propagation, "Tests that updates propagate only to their children");
    test(graph_test_update_disconnected, "Tests that nodes get executed when updated, even if disconnected");
    test_summary();
    // NOTE(review): if test_summary() reports a failure count, consider
    // returning it here so CI sees failing runs.
    return 0;
}