"""The SRH-2D plots dialog."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import webbrowser

# 2. Third party modules
from matplotlib.backends.backend_qt5agg import FigureCanvas, NavigationToolbar2QT as NavigationToolbar
import numpy as np
from PySide2.QtCore import Qt
from PySide2.QtWidgets import (
    QCheckBox, QDialogButtonBox, QGroupBox, QHBoxLayout, QLabel, QLineEdit, QListWidget, QListWidgetItem, QVBoxLayout,
    QWidget
)

# 3. Aquaveo modules
from xms.guipy.dialogs.xms_parent_dlg import XmsDlg
from xms.guipy.validators.qx_double_validator import QxDoubleValidator

# 4. Local modules
from xms.srh.gui.simulation_plots import SimulationPlots

# Plot type -> {'data': (data_min, data_max), 'specified': (specified_min, specified_max)}.
# 'data' holds the bounds computed from the plotted data; 'specified' holds the bounds shown in the y-axis
# edit fields (initially equal to 'data').
RangesDict = dict[str, dict[str, tuple[float, float]]]


class SimulationPlotsDialog(XmsDlg):
    """A dialog for showing SRH-2D output plots.

    The left side holds the scenario list (only when there is more than one scenario), the plot list, and the
    legend/axis options. The right side holds the matplotlib canvas and its navigation toolbar.
    """
    def __init__(self, win_cont: QWidget, inf_files: list[str], enable_sediment: bool):  # pragma: no cover
        """Initializes the dialog, sets up the ui.

        Args:
            win_cont: Parent dialog
            inf_files: Full paths to the inf files whose output is plotted.
            enable_sediment: True if sediment is enabled.
        """
        super().__init__(win_cont, 'xms.srh.gui.simulation_plots_dialog')
        self._enable_sediment = enable_sediment
        self._plot_list = [
            'Net_Q/INLET_Q', 'Mass Balance', 'Wet Elements', 'Monitor Points WSE', 'Monitor Points Z',
            'Monitor Lines Q', 'Monitor Lines QS'
        ]
        self._simulation_plots = SimulationPlots(inf_files)
        self._max_x = self._simulation_plots.get_max_time()
        self._y_ranges: RangesDict = {}  # Linear-scale y ranges, per plot type
        self._y_ranges_log: RangesDict = {}  # Log-scale y ranges, per plot type
        self._canvas = None  # matplotlib.backends.backend_qt5agg FigureCanvas
        self._toolbar = None  # matplotlib navigation toolbar

        self.help_url = 'https://www.xmswiki.com/wiki/SMS:SRH-2D_Plots'
        self.widgets = {}
        self.setWindowTitle('SRH-2D Simulation Plots')
        # The y ranges must exist before the UI is built because the y-axis edit fields are initialized from them.
        self._find_all_y_ranges()
        self._setup_ui()

        # Set up the initial state of the dialog (range edits disabled) and draw the first plot.
        self._on_specify_x_range(Qt.Unchecked)
        self._on_specify_y_range(Qt.Unchecked)

    def _setup_ui(self) -> None:
        """Sets up the dialog controls."""
        self.widgets['main_vert_layout'] = QVBoxLayout()
        self.setLayout(self.widgets['main_vert_layout'])

        self._setup_left_and_right_layouts()

        # left side of the dialog
        self._setup_scenario_list()
        self._setup_plot_list()
        self._setup_legend_checkbox()
        self._setup_x_axis_widgets()
        self._setup_y_axis_widgets()

        # right side of the dialog
        self._setup_plot()

        # bottom of the dialog
        self._setup_ui_bottom_button_box()

    def _find_all_y_ranges(self) -> None:
        """Finds the data y-axis bounds of every plot type, for both linear and log scale.

        Each plot is created once with no axis limits. The linear range is the axes' y-bound after plotting.
        The log range is the smallest and largest non-zero absolute y value of the plotted lines, because
        zero can't be shown on a log axis.
        """
        for plot_type in self._plot_list:
            self._create_plot(plot_type, None, None, None, None)
            bounds = self._simulation_plots.ax.get_ybound()
            self._y_ranges[plot_type] = {'data': bounds, 'specified': bounds}

            log_bounds = self._nonzero_abs_range(self._simulation_plots.ax.get_lines())
            self._y_ranges_log[plot_type] = {'data': log_bounds, 'specified': log_bounds}

    @staticmethod
    def _nonzero_abs_range(lines: list) -> tuple[float, float]:
        """Returns the smallest and largest non-zero, finite absolute y value over all lines.

        The lines may have different lengths (e.g. scenarios with different numbers of time steps). Their y data
        is therefore concatenated into one flat array rather than stacked into a 2-D array; stacking ragged
        sequences raises ValueError in recent NumPy versions.

        Args:
            lines: The matplotlib lines currently on the axes.

        Returns:
            (min, max) of the non-zero absolute y values, or (0, 1) if there are none.
        """
        arrays = [np.abs(np.asarray(line.get_ydata(), dtype=float)).ravel() for line in lines]
        if not arrays:
            return 0, 1
        values = np.concatenate(arrays)
        values = values[np.isfinite(values) & (values != 0.0)]
        if values.size == 0:
            return 0, 1  # Same fallback as when there is no data at all
        return float(values.min()), float(values.max())

    def _setup_plot(self) -> None:
        """Creates the plot canvas and its navigation toolbar on the right side of the dialog."""
        self._canvas = FigureCanvas(self._simulation_plots.figure)
        self._canvas.setMinimumWidth(300)  # So user can't resize it to nothing
        self._toolbar = NavigationToolbar(self._canvas, self)
        # Remove the "Configure subplots" button (action text 'Subplots'); it isn't useful in this dialog.
        for action in self._toolbar.actions():
            if action.text() == 'Subplots':
                self._toolbar.removeAction(action)
        self.widgets['right_vert_layout'].addWidget(self._toolbar)
        self.widgets['right_vert_layout'].addWidget(self._canvas)

    def _setup_x_axis_widgets(self) -> None:
        """Sets up the x-axis group: 'Specify range' checkbox plus minimum/maximum edit fields."""
        self.widgets['x_range_group'] = QGroupBox('X axis')
        self.widgets['x_range_group'].setMaximumWidth(200)
        self.widgets['left_vert_layout'].addWidget(self.widgets['x_range_group'])
        self.widgets['x_range_vlayout'] = QVBoxLayout()
        self.widgets['x_range_group'].setLayout(self.widgets['x_range_vlayout'])

        self.widgets['x_axis_specify_range'] = QCheckBox('Specify range')
        self.widgets['x_axis_specify_range'].stateChanged.connect(self._on_specify_x_range)
        self.widgets['x_range_vlayout'].addWidget(self.widgets['x_axis_specify_range'])

        # Both x edits are limited to [0, max simulation time].
        self.widgets['min_x_label'] = QLabel('Minimum:')
        self.widgets['x_range_vlayout'].addWidget(self.widgets['min_x_label'])
        self.widgets['min_x_edit'] = QLineEdit('0.0')
        self.widgets['min_x_edit'].setValidator(QxDoubleValidator(0.0, self._max_x, parent=self))
        self.widgets['min_x_edit'].editingFinished.connect(self._on_min_max_x_changed)
        self.widgets['x_range_vlayout'].addWidget(self.widgets['min_x_edit'])

        self.widgets['max_x_label'] = QLabel('Maximum:')
        self.widgets['x_range_vlayout'].addWidget(self.widgets['max_x_label'])
        self.widgets['max_x_edit'] = QLineEdit(str(self._max_x))
        self.widgets['max_x_edit'].setValidator(QxDoubleValidator(0.0, self._max_x, parent=self))
        self.widgets['max_x_edit'].editingFinished.connect(self._on_min_max_x_changed)
        self.widgets['x_range_vlayout'].addWidget(self.widgets['max_x_edit'])

    def _setup_y_axis_widgets(self) -> None:
        """Sets up the y-axis group: 'Specify range' checkbox, minimum/maximum edit fields and 'Log scale'."""
        self.widgets['y_range_group'] = QGroupBox('Y axis')
        self.widgets['y_range_group'].setMaximumWidth(200)
        self.widgets['left_vert_layout'].addWidget(self.widgets['y_range_group'])
        self.widgets['y_range_vlayout'] = QVBoxLayout()
        self.widgets['y_range_group'].setLayout(self.widgets['y_range_vlayout'])

        self.widgets['y_axis_specify_range'] = QCheckBox('Specify range')
        self.widgets['y_axis_specify_range'].stateChanged.connect(self._on_specify_y_range)
        self.widgets['y_range_vlayout'].addWidget(self.widgets['y_axis_specify_range'])

        # The min and max edits share one validator; its range is updated when the plot or scale changes.
        mn, mx = self._y_ranges[self._plot_list[0]]['specified']
        validator = QxDoubleValidator(mn, mx, parent=self)

        self.widgets['min_y_label'] = QLabel('Minimum:')
        self.widgets['y_range_vlayout'].addWidget(self.widgets['min_y_label'])
        self.widgets['min_y_edit'] = QLineEdit(validator.fixup(str(mn)))
        self.widgets['min_y_edit'].setValidator(validator)
        self.widgets['min_y_edit'].editingFinished.connect(self._on_min_max_y_changed)
        self.widgets['y_range_vlayout'].addWidget(self.widgets['min_y_edit'])

        self.widgets['max_y_label'] = QLabel('Maximum:')
        self.widgets['y_range_vlayout'].addWidget(self.widgets['max_y_label'])
        self.widgets['max_y_edit'] = QLineEdit(validator.fixup(str(mx)))
        self.widgets['max_y_edit'].setValidator(validator)
        self.widgets['max_y_edit'].editingFinished.connect(self._on_min_max_y_changed)
        self.widgets['y_range_vlayout'].addWidget(self.widgets['max_y_edit'])

        # Log checkbox
        self.widgets['y_axis_log_scale'] = QCheckBox('Log scale')
        self.widgets['y_axis_log_scale'].stateChanged.connect(self._on_log_scale)
        self.widgets['y_range_vlayout'].addWidget(self.widgets['y_axis_log_scale'])

        # Push the left-side controls to the top of the dialog.
        self.widgets['left_vert_layout'].addStretch()

    def _setup_legend_checkbox(self) -> None:
        """Sets up the 'Show legend' check box (checked by default)."""
        self.widgets['tog_show_legend'] = QCheckBox('Show legend')
        self.widgets['tog_show_legend'].setChecked(True)
        self.widgets['tog_show_legend'].toggled.connect(self._on_show_legend)
        self.widgets['left_vert_layout'].addWidget(self.widgets['tog_show_legend'])

    def _setup_plot_list(self) -> None:
        """Adds the plot type list to the dialog, with the first plot selected."""
        self.widgets['plot_type_label'] = QLabel('Plots:')
        self.widgets['left_vert_layout'].addWidget(self.widgets['plot_type_label'])
        self.widgets['plot_list'] = QListWidget()
        self.widgets['plot_list'].setMaximumWidth(200)
        self.widgets['plot_list'].addItems(self._plot_list)
        # Select before connecting so the initial selection doesn't trigger _on_plot_changed.
        self.widgets['plot_list'].setCurrentRow(0)
        self.widgets['plot_list'].currentItemChanged.connect(self._on_plot_changed)
        self.widgets['left_vert_layout'].addWidget(self.widgets['plot_list'])

    def _setup_scenario_list(self) -> None:
        """Adds a checkable scenario list to the left side of the dialog.

        The list is only created when there is more than one scenario. The first scenario starts checked.
        """
        scenario_names = list(self._simulation_plots.scenarios.keys())
        if len(scenario_names) <= 1:
            return

        self.widgets['scenario_label'] = QLabel('Scenarios:')
        self.widgets['left_vert_layout'].addWidget(self.widgets['scenario_label'])
        self.widgets['scenario_list'] = QListWidget()
        self.widgets['scenario_list'].setMaximumWidth(200)
        self.widgets['scenario_list'].itemChanged.connect(self._on_list_state_changed)
        self.widgets['left_vert_layout'].addWidget(self.widgets['scenario_list'])
        for i, name in enumerate(scenario_names):
            item = QListWidgetItem(name)
            item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
            item.setCheckState(Qt.Checked if i == 0 else Qt.Unchecked)
            self.widgets['scenario_list'].addItem(item)

    def _setup_left_and_right_layouts(self) -> None:
        """Sets up the left (options) and right (plot) layouts inside the main vertical layout."""
        self.widgets['main_horiz_layout'] = QHBoxLayout()
        self.widgets['main_vert_layout'].addLayout(self.widgets['main_horiz_layout'])
        # Left: lists and axis options. Right: the plot.
        self.widgets['left_vert_layout'] = QVBoxLayout()
        self.widgets['main_horiz_layout'].addLayout(self.widgets['left_vert_layout'], stretch=1)
        self.widgets['right_vert_layout'] = QVBoxLayout()
        self.widgets['main_horiz_layout'].addLayout(self.widgets['right_vert_layout'])

    def _setup_ui_bottom_button_box(self) -> None:
        """Adds the Close and Help buttons to the bottom of the dialog."""
        self.widgets['btn_horiz_layout'] = QHBoxLayout()
        self.widgets['btn_box'] = QDialogButtonBox()
        self.widgets['btn_box'].setOrientation(Qt.Horizontal)
        self.widgets['btn_box'].setStandardButtons(QDialogButtonBox.Close | QDialogButtonBox.Help)
        self.widgets['btn_box'].accepted.connect(self.accept)
        self.widgets['btn_box'].rejected.connect(self.reject)
        self.widgets['btn_box'].helpRequested.connect(self._help_requested)
        self.widgets['btn_horiz_layout'].addWidget(self.widgets['btn_box'])
        self.widgets['main_vert_layout'].addLayout(self.widgets['btn_horiz_layout'])

    def _help_requested(self) -> None:  # pragma: no cover
        """Called when the Help button is clicked."""
        webbrowser.open(self.help_url)

    def _read_min_max_edits(self, min_edit: str, max_edit: str) -> tuple[float, float]:
        """Parses the text of a min/max pair of edit fields as floats.

        Args:
            min_edit: Name of the minimum edit field in self.widgets.
            max_edit: Name of the maximum edit field in self.widgets.

        Returns:
            (mn, mx), or (None, None) if either field's text isn't a valid number (e.g. the field was cleared).
        """
        try:
            return float(self.widgets[min_edit].text()), float(self.widgets[max_edit].text())
        except ValueError:
            return None, None

    def _get_min_max(self, checkbox: str, min_edit: str, max_edit: str) -> tuple[float, float]:
        """Returns the min and max of an axis as a tuple.

        Both values are None if the 'Specify range' checkbox isn't checked, or if either edit field's text isn't
        a valid number.

        Args:
            checkbox: Name of the 'Specify range' checkbox in self.widgets.
            min_edit: Name of the minimum edit field in self.widgets.
            max_edit: Name of the maximum edit field in self.widgets.

        Returns:
            (tuple): tuple containing:
                mn (float): The minimum, or None if not specified.
                mx (float): The maximum, or None if not specified.
        """
        if not self.widgets[checkbox].isChecked():
            return None, None
        return self._read_min_max_edits(min_edit, max_edit)

    def _get_x_min_max(self) -> tuple[float, float]:
        """Returns the min and max x as a tuple.

        They will be None if the min and max isn't specified.

        Returns:
            (tuple): tuple containing:
                mn (float): The minimum x, or None if not specified.
                mx (float): The maximum x, or None if not specified.
        """
        return self._get_min_max('x_axis_specify_range', 'min_x_edit', 'max_x_edit')

    def _get_y_min_max(self) -> tuple[float, float]:
        """Returns the y min and max as a tuple.

        They will be None if the min and max isn't specified.

        Returns:
            (tuple): tuple containing:
                mn (float): The minimum y, or None if not specified.
                mx (float): The maximum y, or None if not specified.
        """
        return self._get_min_max('y_axis_specify_range', 'min_y_edit', 'max_y_edit')

    def _current_plot_type(self) -> str:
        """Returns the plot type selected in the plot list."""
        return self._plot_list[self.widgets['plot_list'].currentRow()]

    def _active_y_ranges(self) -> RangesDict:
        """Returns the y ranges dict for the current y scale (log or linear)."""
        return self._y_ranges_log if self.widgets['y_axis_log_scale'].isChecked() else self._y_ranges

    def _set_widgets_enabled(self, names: list[str], enabled: bool) -> None:
        """Enables or disables the named widgets.

        Args:
            names: Keys into self.widgets.
            enabled: True to enable, False to disable.
        """
        for name in names:
            self.widgets[name].setEnabled(enabled)

    def _on_list_state_changed(self) -> None:
        """Update the list of scenarios to plot based on the check state of the items in the list box."""
        scenario_list = self.widgets['scenario_list']
        self._simulation_plots.scenarios_to_plot = [
            scenario_list.item(i).text()
            for i in range(scenario_list.count())
            if scenario_list.item(i).checkState() == Qt.Checked
        ]
        self._create_and_draw_plot(self._current_plot_type())

    def _on_show_legend(self) -> None:
        """Called when the show legend toggle is checked or unchecked."""
        self._simulation_plots.show_legend = self.widgets['tog_show_legend'].isChecked()
        self._create_and_draw_plot(self._current_plot_type())

    def _on_specify_x_range(self, state: int) -> None:
        """Called when the specify x range checkbox is checked.

        Args:
            state: State of the checkbox.
        """
        self._set_widgets_enabled(['min_x_label', 'min_x_edit', 'max_x_label', 'max_x_edit'], state == Qt.Checked)
        self._create_and_draw_plot(self._current_plot_type())

    def _on_specify_y_range(self, state: int) -> None:
        """Called when the specify y range checkbox is checked.

        Args:
            state: State of the checkbox.
        """
        self._set_widgets_enabled(['min_y_label', 'min_y_edit', 'max_y_label', 'max_y_edit'], state == Qt.Checked)
        self._create_and_draw_plot(self._current_plot_type())

    def _on_log_scale(self, checked: int) -> None:
        """Called when the log scale checkbox is clicked.

        Args:
            checked: Qt.CheckState (Qt.Unchecked = 0, Qt.PartiallyChecked = 1, Qt.Checked = 2)
        """
        plot_type = self._current_plot_type()
        self._save_and_set_min_max_y_for_plot(plot_type, plot_type, log_changed=True)
        self._create_and_draw_plot(plot_type)

    def _save_and_set_min_max_y_for_plot(self, old_plot_type: str, new_plot_type: str, log_changed: bool) -> None:
        """Saves the y min/max for the old plot and sets the new y min/max for the plot.

        Args:
            old_plot_type: The old plot type, or None if there was no previous plot (nothing is saved then).
            new_plot_type: The new plot type.
            log_changed: True if the log scale checkbox state changed.
        """
        log_checked = self.widgets['y_axis_log_scale'].checkState() == Qt.Checked
        new_ranges = self._y_ranges_log if log_checked else self._y_ranges
        if log_changed:
            # The edits still hold values for the scale that was active before the toggle.
            old_ranges = self._y_ranges if log_checked else self._y_ranges_log
        else:
            old_ranges = new_ranges
        self._save_and_set_min_max_y(old_plot_type, new_plot_type, old_ranges, new_ranges)

    def _save_and_set_min_max_y(
        self, old_plot_type: str, new_plot_type: str, old_ranges_dict: RangesDict, new_ranges_dict: RangesDict
    ) -> None:
        """
        Save the current y min/max and set the new y min/max.

        Args:
            old_plot_type: The old plot type, or None to skip saving.
            new_plot_type: The new plot type.
            old_ranges_dict: The dict to save the current ranges to.
            new_ranges_dict: The dict to get the current ranges from.
        """
        # Save current min/max (skipped if the edits don't hold valid numbers, keeping the last saved range).
        mn, mx = self._read_min_max_edits('min_y_edit', 'max_y_edit')
        if old_plot_type is not None and mn is not None:
            old_ranges_dict[old_plot_type]['specified'] = (mn, mx)

        # Set ranges
        # NOTE(review): if QxDoubleValidator inherits QDoubleValidator.setRange, its default 'decimals' argument
        # may reset the allowed decimals — confirm.
        data_min, data_max = new_ranges_dict[new_plot_type]['data']
        self.widgets['min_y_edit'].validator().setRange(data_min, data_max)
        self.widgets['max_y_edit'].validator().setRange(data_min, data_max)

        # Set min/max
        validator = self.widgets['min_y_edit'].validator()
        spec_min, spec_max = new_ranges_dict[new_plot_type]['specified']
        self.widgets['min_y_edit'].setText(validator.fixup(str(spec_min)))
        self.widgets['max_y_edit'].setText(validator.fixup(str(spec_max)))

    def _set_edit_field_without_signalling(self, widget_name: str, text: str) -> None:
        """Set the text of a QLineEdit without emitting the editingFinished signal.

        Args:
            widget_name: Name of the widget.
            text: The text to set.
        """
        self.widgets[widget_name].blockSignals(True)
        self.widgets[widget_name].setText(text)
        self.widgets[widget_name].blockSignals(False)

    def _on_min_max_x_changed(self) -> None:
        """Called when the x min or max has changed."""
        # make sure the min/max are not inverted
        mn, mx = self._get_x_min_max()
        if mn is not None and mx is not None and mn > mx:
            mn, mx = mx, mn
            self._set_edit_field_without_signalling('min_x_edit', str(mn))
            self._set_edit_field_without_signalling('max_x_edit', str(mx))

        self._create_and_draw_plot(self._current_plot_type())

    def _on_min_max_y_changed(self) -> None:
        """Called when the y min or max has changed.

        Swaps inverted values and stores the range as the 'specified' range for the current plot and scale.
        """
        plot_type = self._current_plot_type()
        mn, mx = self._get_y_min_max()
        if mn is not None and mx is not None:
            # make sure the min/max are not inverted
            if mn > mx:
                mn, mx = mx, mn
                self._set_edit_field_without_signalling('min_y_edit', str(mn))
                self._set_edit_field_without_signalling('max_y_edit', str(mx))
            # Store in the dict for the active scale so log values never overwrite the linear range.
            self._active_y_ranges()[plot_type]['specified'] = (mn, mx)

        self._create_and_draw_plot(plot_type)

    def _on_plot_changed(self, current: QListWidgetItem, previous: QListWidgetItem) -> None:
        """
        Called when the selected item in the list widget changes.

        Args:
            current: current list item (may be None if the selection was cleared)
            previous: previous list item (may be None if nothing was selected before)
        """
        idx = self.widgets['plot_list'].currentRow()
        if current is None or idx < 0 or idx >= len(self._plot_list):
            return

        new_plot_type = current.text()
        old_plot_type = previous.text() if previous is not None else None

        self._save_and_set_min_max_y_for_plot(old_plot_type, new_plot_type, log_changed=False)
        self._create_and_draw_plot(new_plot_type)

    def _create_and_draw_plot(self, plot_type: str) -> None:
        """Creates the plot using the current axis settings and redraws the canvas.

        Args:
            plot_type: The plot type.
        """
        min_x, max_x = self._get_x_min_max()
        min_y, max_y = self._get_y_min_max()
        self._create_plot(plot_type, min_x, max_x, min_y, max_y)
        self._canvas.draw()

    def _create_plot(self, plot_type: str, min_x: float, max_x: float, min_y: float, max_y: float) -> None:
        """Creates the plot based on the plot type.

        Args:
            plot_type: The plot type.
            min_x: The minimum x.
            max_x: The maximum x.
            min_y: The minimum y.
            max_y: The maximum y.
        """
        # This is called from _find_all_y_ranges before the widgets exist; default to linear scale then.
        log = self.widgets['y_axis_log_scale'].isChecked() if 'y_axis_log_scale' in self.widgets else False
        if plot_type == 'Net_Q/INLET_Q':
            self._simulation_plots.create_net_q_plot(min_x, max_x, min_y, max_y, log)
        elif plot_type == 'Mass Balance':
            self._simulation_plots.create_mass_balance_plot(min_x, max_x, min_y, max_y, log)
        elif plot_type == 'Wet Elements':
            self._simulation_plots.create_wet_elements_plot(min_x, max_x, min_y, max_y, log)
        elif plot_type == 'Monitor Points WSE':
            self._simulation_plots.create_monitor_point_plot(min_x, max_x, min_y, max_y, log, 'wse')
        elif plot_type == 'Monitor Points Z':
            self._simulation_plots.create_monitor_point_plot(min_x, max_x, min_y, max_y, log)
        elif plot_type == 'Monitor Lines Q':
            self._simulation_plots.create_monitor_line_q_plot(min_x, max_x, min_y, max_y, log)
        elif plot_type == 'Monitor Lines QS' and self._enable_sediment:
            self._simulation_plots.create_monitor_line_qs_plot(min_x, max_x, min_y, max_y, log)
        # NOTE(review): with sediment disabled, 'Monitor Lines QS' creates nothing, so the figure keeps the
        # previously created plot — confirm this is intended.
