import numpy as np
import pyqtgraph as pg


class PlottingWindow(pg.PlotWidget):
    """
    Plotting window that uses a pyqtgraph plot widget to graph the logging
    variables selected in the logging selection menu from the available
    logging config.

    Each data axis is drawn as its own curve, backed by a fixed-size shift
    buffer holding the most recent samples for that axis.
    """

    def __init__(self, parent=None, num_axes=5, width=5, height=4, dpi=100,
                 buffer_size=1000, time_window=10.0):
        """
        Initialize plotting window, axes, labels, and curves.

        :param parent: Parent widget
        :param num_axes: Number of data axes (one curve per axis)
        :param width: Unused; accepted for backward compatibility
        :param height: Unused; accepted for backward compatibility
        :param dpi: Unused; accepted for backward compatibility
        :param buffer_size: Number of most recent samples kept per axis
            (must be >= 1)
        :param time_window: Width, in seconds, of the visible X range
            (must be > 0)
        :raises ValueError: If buffer_size < 1 or time_window <= 0
        """
        # width/height/dpi are matplotlib-style canvas options with no
        # PlotWidget equivalent. PlotWidget forwards unknown keyword
        # arguments to PlotItem, which hands them to PlotItem.plot() and so
        # adds a stray empty data item. Therefore they are not forwarded.
        super().__init__(parent=parent)

        if buffer_size < 1:
            raise ValueError(f"buffer_size must be >= 1, got {buffer_size}")
        if time_window <= 0:
            raise ValueError(f"time_window must be > 0, got {time_window}")

        self.setLabel('bottom', 'Time (s)')
        self.setLabel('left', 'Amplitude')
        # NOTE(review): setLabel('top', ...) shows the top axis with this
        # text; if a plot title was intended, setTitle() may fit better.
        self.setLabel('top', 'Logging Variables Plot')
        self._num_axes = num_axes
        self._buffer_size = buffer_size
        self._time_window = time_window

        self.current_iter = 0

        # Set the background color of the plot widget to white
        self.setBackground('w')

        # Always create the buffers (possibly with zero columns) so that
        # update_plot() fails with a clear IndexError rather than an
        # AttributeError when no axes were requested.
        n_axes = max(num_axes, 0)

        # One column per axis, buffer_size rows of (x, y) samples. Unfilled
        # slots stay at (0, 0), so a curve starts at the origin until its
        # buffer has filled.
        self.plot_data_ys = np.zeros((buffer_size, n_axes))
        self.plot_data_xs = np.zeros((buffer_size, n_axes))

        self.plot_curves = [
            self.plot(self.plot_data_xs[:, i], self.plot_data_ys[:, i],
                      pen=pg.mkPen(pg.intColor(i), width=2))
            for i in range(n_axes)
        ]

    def update_plot(self, input_data: float, input_timestamp: float,
                    input_axis: int):
        """
        Append a sample to one axis and scroll the visible X range.

        The axis' buffers are shifted one slot to the left, dropping the
        oldest sample, and the new sample is written to the last slot. The
        visible X range stays at [0, time_window] until the latest timestamp
        across all axes reaches time_window. After that it tracks the most
        recent time_window seconds.

        :param input_data: Sample value (Y)
        :param input_timestamp: Sample time in seconds (X)
        :param input_axis: Index of the axis/curve receiving the sample
        :raises IndexError: If input_axis does not name an existing axis
        """
        # Basic integer indexing yields views, so these writes update the
        # underlying buffers in place (numpy handles the overlapping copy).
        xs = self.plot_data_xs[:, input_axis]
        ys = self.plot_data_ys[:, input_axis]

        xs[:-1] = xs[1:]
        xs[-1] = input_timestamp
        ys[:-1] = ys[1:]
        ys[-1] = input_data

        self.plot_curves[input_axis].setData(xs, ys)

        # Latest timestamp across all axes (last row of the X buffer).
        latest = float(np.max(self.plot_data_xs[-1]))
        window = self._time_window

        # Don't scroll at first; show a fixed window until it has filled.
        # NOTE: padding=0 on setXRange may be wanted if the default
        # auto-padding is visually undesirable.
        if latest < window:
            self.setXRange(0, window)
        else:
            self.setXRange(latest - window, latest)