"""Class to draw a plot of a 2D rectilinear grid frame and its i-j triad."""

# 1. Standard Python modules
import math

# 2. Third party modules
from matplotlib.backends.backend_qt5agg import FigureCanvas
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle
from matplotlib.ticker import AutoLocator

# 3. Aquaveo modules

# 4. Local modules


class TwoDRectGridPlot:
    """Class to draw a plot of a 2D rectilinear grid frame and its i-j triad."""
    def __init__(self, cogrid, grid_name):
        """Constructor.

        Args:
            cogrid (CoGrid): The 2D rectilinear constrained CoGrid. May be None, in which case an empty plot is drawn.
            grid_name (str): The name of the CGrid, used as the plot title. It isn't stored on the CoGrid, so the
                caller has to get it from the data_object.
        """
        self._cogrid = cogrid
        self._grid_name = grid_name
        self._figure = None
        self._canvas = None
        self._ax = None
        # Grid geometry, filled in from the CoGrid by _get_grid_info(). The defaults remain when there is no CoGrid.
        self._origin_x = 0.0
        self._origin_y = 0.0
        self._angle = 0.0  # Degrees; passed to math.radians() and to Rectangle(angle=...)
        self._i_width = 1.0
        self._j_height = 1.0
        # Number of cells along each axis (number of locations minus one). Initialized here so the attributes
        # always exist, even when there is no CoGrid.
        self._num_i = 0
        self._num_j = 0

    def _get_grid_info(self):
        """Initialize the variables we need for plotting from the CoGrid.

        Does nothing when there is no CoGrid, leaving the defaults set in the constructor.
        """
        if self._cogrid is None:
            return  # If no CoGrid, we'll just make an empty plot
        origin = self._cogrid.origin
        self._origin_x = origin[0]
        self._origin_y = origin[1]
        self._angle = self._cogrid.angle
        # We don't want to wrap the C++ arrays multiple times, so get everything we need here.
        locations_x = self._cogrid.locations_x
        locations_y = self._cogrid.locations_y
        self._i_width = locations_x[-1]  # Given as a distance from the origin, so last is total width.
        self._j_height = locations_y[-1]  # Given as a distance from the origin, so last is total height.
        self._num_i = len(locations_x) - 1
        self._num_j = len(locations_y) - 1

    def _grid_to_world(self, i_dist, j_dist):
        """Convert a point in grid space to plot (world) coordinates.

        Grid space is measured from the grid origin along the grid's i and j axes. The point is rotated
        counter-clockwise by the grid angle about the origin, the same way the boundary Rectangle is rotated.

        Args:
            i_dist (float): Distance from the origin along the grid's i-axis
            j_dist (float): Distance from the origin along the grid's j-axis

        Returns:
            tuple[float, float]: The (x, y) plot coordinates of the point
        """
        theta = math.radians(self._angle)
        cos_t = math.cos(theta)
        sin_t = math.sin(theta)
        x = self._origin_x + cos_t * i_dist - sin_t * j_dist
        y = self._origin_y + sin_t * i_dist + cos_t * j_dist
        return x, y

    def _initialize_plot(self):
        """Initializes grid info and creates the matplotlib figure and its Qt canvas widget."""
        self._get_grid_info()
        self._figure = Figure()
        self._figure.set_tight_layout(True)  # Frames the plots
        self._canvas = FigureCanvas(self._figure)
        self._canvas.setMinimumWidth(100)  # So user can't resize it to nothing

    def _draw_plot(self):
        """Draw the grid preview plot onto a new subplot of the figure created by _initialize_plot()."""
        self._ax = self._figure.add_subplot(111)
        self._ax.set_title(self._grid_name)
        self._ax.set_aspect('equal', adjustable='box')
        self._ax.xaxis.set_major_locator(AutoLocator())

        if self._cogrid is None:
            return  # Just leave the plot empty with default axes if no CGrid
        self._draw_grid_frame()
        self._draw_triad()
        self._label_sides()

    def _draw_grid_frame(self):
        """Draw the origin and border of the CGrid."""
        # Plot a point at the origin
        self._ax.plot(self._origin_x, self._origin_y, color='lightskyblue', marker='o', markersize=3)
        # Add grid boundary rectangle to the plot, rotated about the origin by the grid angle (degrees)
        rect = Rectangle(
            (self._origin_x, self._origin_y),
            self._i_width,
            self._j_height,
            edgecolor='lightskyblue',
            fill=False,
            lw=2,
            angle=self._angle
        )
        self._ax.add_patch(rect)

    def _draw_triad(self):
        """Add the i-j triad to the plot at the origin.

        Draws one arrow along each grid axis. Each label ('I' or 'J') sits at the arrow's tail, and the arrow points
        away from the origin.
        """
        # Use 1/4 of the average of the total i width and total j height as the arrow length. Start each arrow at a
        # point that is 5% of that average along its axis so our labels don't overlap.
        av_total_length = (self._i_width + self._j_height) / 2
        arrow_buffer = av_total_length * 0.05
        arrow_tip = arrow_buffer + av_total_length * 0.25

        # Tail and head of each arrow in grid space, converted to plot coordinates
        i_start = self._grid_to_world(arrow_buffer, 0.0)
        i_end = self._grid_to_world(arrow_tip, 0.0)
        j_start = self._grid_to_world(0.0, arrow_buffer)
        j_end = self._grid_to_world(0.0, arrow_tip)

        # Plot the arrows
        for label, start, end in (('I', i_start, i_end), ('J', j_start, j_end)):
            self._ax.annotate(
                label,
                xy=end,
                xycoords='data',
                xytext=start,
                textcoords='data',
                arrowprops=dict(facecolor='black', arrowstyle='simple'),
                horizontalalignment='center',
                verticalalignment='center_baseline',
                zorder=3.0
            )

    def _label_sides(self):
        """Add a label at the midpoint of each side of the plotted CGrid.

        Side 1 runs along the j-axis from the origin, and side 2 runs along the i-axis from the origin. Side 3 is
        opposite side 1, at the far end of the i-axis. Side 4 is opposite side 2, at the far end of the j-axis.
        """
        half_i = self._i_width / 2
        half_j = self._j_height / 2
        side_midpoints = (
            ('Side 1', self._grid_to_world(0.0, half_j)),
            ('Side 2', self._grid_to_world(half_i, 0.0)),
            ('Side 3', self._grid_to_world(self._i_width, half_j)),
            ('Side 4', self._grid_to_world(half_i, self._j_height)),
        )
        for label, (x, y) in side_midpoints:
            self._ax.text(x, y, label)

    def generate_grid_preview_plot(self, layout):
        """Generate the grid preview plot.

        Args:
            layout (QBoxLayout): The layout to append the plot to
        """
        self._initialize_plot()
        self._draw_plot()
        layout.addWidget(self._canvas)
