"""The profile plot tab in the summary dialog."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from itertools import cycle

# 2. Third party modules
from matplotlib.backends.backend_qt5agg import FigureCanvas, NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure
from matplotlib.patches import Polygon
from PySide2.QtCore import QSize, Qt
from PySide2.QtWidgets import (QCheckBox, QHBoxLayout, QLabel, QListWidget, QListWidgetItem, QPushButton, QVBoxLayout)

# 3. Aquaveo modules
import xms.api._xmsapi.dmi as xmd
from xms.api.tree import tree_util
from xms.data_objects.parameters import FilterLocation
from xms.guipy.dialogs.treeitem_selector import TreeItemSelectorDlg

# 4. Local modules
from xms.bridge.gui.summary_dlg.extracted_sim_data import ExtractedSimData


class ProfileTab:
    """The profile plot tab in the summary dialog.

    The same class drives two tabs. The profile tab plots along the first arc of a user-selected coverage.
    The cross-section tab plots along a caller-supplied cross section line.
    """
    def __init__(self, parent, widgets, xms_data, cross_section_line, plot_sim_data):
        """Initializes the class.

        Args:
            parent (:obj:`QWidget`): Parent dialog
            widgets (:obj:`dict`): Dictionary of widgets
            xms_data (:obj:`XmsData`): Object for retrieving data from XMS. May be falsy, in which case no project
                tree or structure component is available.
            cross_section_line (:obj:`list`): list of points for the cross section line. An empty list (or None)
                means this is the profile tab rather than the cross-section tab.
            plot_sim_data (:obj:`PlotSimData`): Object for managing data for profile and cross-section plots
        """
        if cross_section_line is None:  # treat a missing line like an empty one instead of crashing in len()
            cross_section_line = []
        self._parent = parent
        self._parent_widgets = widgets
        self._xms_data = xms_data
        self._is_cross_section = len(cross_section_line) > 0
        self._cross_section_line = cross_section_line
        self._pe_tree = None if not xms_data else xms_data.project_tree
        self._comp = None if not xms_data else xms_data.structure_component
        self._sim_data = plot_sim_data
        self._extracted_sim_data = ExtractedSimData(self._sim_data, self._comp, self._is_cross_section)
        self._polyline = []  # (x, y) points the data is extracted along
        self._widgets = {}
        self._selected_sim = None  # name of the simulation whose time steps are listed
        self._updating_times = False  # flag for when we clear and repopulate the times list
        self._figure = None  # matplotlib.figure Figure
        self._canvas = None  # matplotlib.backends.backend_qt5agg FigureCanvas
        self._ax = None  # matplotlib Axes
        self._ax2 = None  # second (twin) axis for the velocity plot
        self._toolbar = None  # matplotlib navigation toolbar
        self._add_widgets()
        self._get_profile_polyline()

    def _add_widgets(self):
        """Set up the UI inside the scroll area layout of the profile or cross-section tab."""
        vlayout = self._parent_widgets['tab2_scroll_area_vert_layout']
        if self._is_cross_section:
            vlayout = self._parent_widgets['tab3_scroll_area_vert_layout']
        self._widgets['main_vert_layout'] = vlayout
        self._setup_left_and_right_layouts()
        self._setup_simulations_list()
        self._setup_plot()

    def _setup_left_and_right_layouts(self) -> None:
        """Sets up the coverage picker (profile tab only), the velocity toggle, and the left/right layouts."""
        # button to pick a coverage (the cross-section tab gets its line from the caller instead)
        if not self._is_cross_section:
            self._widgets['cov_horiz_layout'] = QHBoxLayout()
            self._widgets['main_vert_layout'].addLayout(self._widgets['cov_horiz_layout'])
            self._widgets['btn_select_coverage'] = QPushButton('Select Coverage...')
            self._widgets['btn_select_coverage'].clicked.connect(self._on_btn_select_coverage)
            self._widgets['cov_horiz_layout'].addWidget(self._widgets['btn_select_coverage'])
            self._widgets['label_coverage'] = QLabel('select profile line coverage')
            self._widgets['cov_horiz_layout'].addWidget(self._widgets['label_coverage'])
            self._widgets['cov_horiz_layout'].addStretch()
            self._updated_profile_coverage_label()
        self._widgets['tog_show_velocity'] = QCheckBox('Plot velocity magnitude')
        self._widgets['tog_show_velocity'].stateChanged.connect(self._on_tog_show_velocity)
        self._widgets['main_vert_layout'].addWidget(self._widgets['tog_show_velocity'])
        # horizontal layout
        self._widgets['main_horiz_layout'] = QHBoxLayout()
        self._widgets['main_vert_layout'].addLayout(self._widgets['main_horiz_layout'])
        # Two vertical layouts: the left one holds the simulation and time step lists, the right one the plot.
        self._widgets['left_vert_layout'] = QVBoxLayout()
        self._widgets['main_horiz_layout'].addLayout(self._widgets['left_vert_layout'])
        self._widgets['right_vert_layout'] = QVBoxLayout()
        self._widgets['main_horiz_layout'].addLayout(self._widgets['right_vert_layout'])

    def _setup_simulations_list(self) -> None:
        """Sets up the checkable list of simulations and the list of time steps for the selected simulation."""
        self._widgets['simulation_label'] = QLabel('Simulations:')
        self._widgets['left_vert_layout'].addWidget(self._widgets['simulation_label'])
        self._widgets['simulation_list'] = QListWidget()
        self._widgets['simulation_list'].setMaximumWidth(200)
        self._widgets['simulation_list'].itemChanged.connect(self._on_list_state_change)
        self._widgets['simulation_list'].itemSelectionChanged.connect(self._on_list_sel_change)
        for sim in self._sim_data.sim_names:
            item = QListWidgetItem(sim)
            item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
            item.setCheckState(Qt.Unchecked)
            self._widgets['simulation_list'].addItem(item)
        self._widgets['left_vert_layout'].addWidget(self._widgets['simulation_list'])

        self._widgets['times_label'] = QLabel('Selected simulation time steps:')
        self._widgets['left_vert_layout'].addWidget(self._widgets['times_label'])
        self._widgets['times_list'] = QListWidget()
        self._widgets['times_list'].setMaximumWidth(200)
        self._widgets['times_list'].itemSelectionChanged.connect(self._on_time_list_selection_change)
        self._widgets['left_vert_layout'].addWidget(self._widgets['times_list'])

    def _setup_plot(self):
        """Set up the matplotlib figure, canvas, navigation toolbar and primary axes."""
        self._figure = Figure()
        self._figure.set_layout_engine(layout='tight')
        self._canvas = FigureCanvas(self._figure)
        self._canvas.sizeHint = lambda: QSize(150, 150)  # small size hint so the canvas can be shrunk vertically
        self._canvas.setMinimumWidth(150)  # So user can't resize it to nothing
        self._toolbar = NavigationToolbar(self._canvas, self._parent)
        # Remove the "Configure subplots" button; subplot spacing is left to the tight layout engine set above.
        for x in self._toolbar.actions():
            if x.text() == 'Subplots':
                self._toolbar.removeAction(x)
        self._widgets['right_vert_layout'].addWidget(self._toolbar)
        self._widgets['right_vert_layout'].addWidget(self._canvas)
        self._ax = self._figure.add_subplot(111)

    def _on_list_state_change(self):
        """Called when an item in the simulation list changes (e.g. its check box); replots the checked sims."""
        self._update_plots()

    def _on_tog_show_velocity(self):
        """Called when the user toggles the show velocity checkbox."""
        self._update_plots()

    def _on_list_sel_change(self):
        """Repopulates the time step list for the newly selected simulation.

        The time step previously chosen for that simulation is re-selected if it appears in the list.
        """
        sel_items = self._widgets['simulation_list'].selectedItems()
        if not sel_items:
            return
        # Suppress _on_time_list_selection_change while the list is rebuilt; clearing and selecting items can emit
        # itemSelectionChanged.
        self._updating_times = True
        try:
            self._widgets['times_list'].clear()
            self._selected_sim = sel_items[0].text()
            times = self._sim_data.sim_data[self._selected_sim].get('times', [])
            sel_time = self._extracted_sim_data.selected_time[self._selected_sim]
            # update the time list
            for time in times:
                item = QListWidgetItem(time)
                self._widgets['times_list'].addItem(item)
                if time == sel_time:
                    item.setSelected(True)
        finally:
            # Always re-enable the time list handler, even if one of the lookups above fails.
            self._updating_times = False

    def _on_time_list_selection_change(self):
        """Stores the chosen time step for the selected simulation and replots."""
        if self._updating_times or not self._selected_sim:
            return
        sel_items = self._widgets['times_list'].selectedItems()
        if not sel_items:  # selection was cleared (e.g. the user deselected the time step); keep the stored one
            return
        self._extracted_sim_data.selected_time[self._selected_sim] = sel_items[0].text()
        self._update_plots()

    def _checked_sims(self):
        """Get the checked simulations.

        Returns:
            (:obj:`list[str]`): Text of every checked item in the simulation list, in list order.
        """
        sim_list = self._widgets['simulation_list']
        items = (sim_list.item(i) for i in range(sim_list.count()))
        return [item.text() for item in items if item.checkState() == Qt.Checked]

    def _reset_plot(self, redraw=True):
        """Clears the primary axes and removes the secondary (velocity) axes, if any.

        Args:
            redraw (:obj:`bool`): If True, redraw the canvas immediately. Pass False when the caller will draw
                after repopulating the plot, to avoid a redundant draw.
        """
        self._ax.clear()
        if self._ax2:
            self._ax2.remove()
            self._ax2 = None
        if redraw:
            self._canvas.draw()

    def _set_plot_axis_labels(self):
        """Sets the plot axis labels, creating a twin y-axis for velocity when that option is checked."""
        self._ax.set_xlabel('Distance along profile line')
        self._ax.set_ylabel('Elevation')
        if self._widgets['tog_show_velocity'].isChecked():
            self._ax2 = self._ax.twinx()
            self._ax2.set_ylabel('Velocity Magnitude', color='red')
            self._ax2.tick_params(axis='y', labelcolor='red')

    def _plot_sim_data(self):
        """Plot the ground elevation, WSE and (optionally) velocity magnitude of each checked simulation.

        Each simulation is given the next line style from a repeating cycle. The ground elevation is plotted only
        once, from the first simulation that provides it.
        """
        # get all checked simulations
        elev_plotted = False
        checked_sims = self._checked_sims()
        extractor = self._extracted_sim_data
        styles_list = ['solid', 'dotted', 'dashed', 'dashdot']
        styles_cycle = cycle(styles_list)
        for sim in checked_sims:
            style = next(styles_cycle)
            extractor.extract_data_from_sim(sim)
            if sim not in extractor.plot_data:
                continue
            elev = extractor.plot_data[sim].get('elevations', None)
            if elev is not None and not elev_plotted:
                self._ax.plot(elev[0], elev[1], label='Ground Elevation', color='sienna')
                elev_plotted = True
            wse = extractor.plot_data[sim].get('wse', None)
            if wse is not None:
                self._ax.plot(wse[0], wse[1], label=f'{sim} WSE', color='C0', linestyle=style)
            if self._ax2:  # only present when "Plot velocity magnitude" is checked
                vel_mag = extractor.plot_data[sim].get('velocity_mag', None)
                if vel_mag is not None:
                    self._ax2.plot(
                        vel_mag[0], vel_mag[1], label=f'{sim} Velocity Magnitude', color='red', linestyle=style
                    )

    def _plot_structure(self):
        """Draw the structure polygon (and its piers, if any) as gray patches on the primary axes."""
        extractor = self._extracted_sim_data
        if not extractor.structure_polygon:
            return
        self._ax.add_patch(Polygon(extractor.structure_polygon, facecolor='gray'))
        # Piers are only drawn together with the structure polygon.
        for pier in extractor.structure_piers or []:
            self._ax.add_patch(Polygon(pier, facecolor='gray'))

    def _update_plots(self):
        """Rebuilds the plot from the current UI state and draws the canvas exactly once."""
        self._reset_plot(redraw=False)
        self._set_plot_axis_labels()
        if self._polyline:  # must have a profile line to plot any data
            self._plot_sim_data()
            self._plot_structure()
        # Draw even with no profile line so the cleared axes and their labels are shown.
        self._canvas.draw()

    def _on_btn_select_coverage(self):
        """Signal for when the user clicks the Select Coverage button."""
        if not self._comp:  # nowhere to read or store the coverage selection
            return
        # Display a tree item selector dialog.
        selector_dlg = TreeItemSelectorDlg(
            title='Select Profile Line Coverage',
            target_type=xmd.CoverageItem,
            pe_tree=self._pe_tree,
            previous_selection=self._comp.data.data_dict['summary_profile_coverage_uuid'],
            parent=self._parent,
            allow_multi_select=False
        )

        if selector_dlg.exec():
            selected_uuid = selector_dlg.get_selected_item_uuid()
            selected_uuid = '' if selected_uuid is None else selected_uuid
            self._comp.data.data_dict['summary_profile_coverage_uuid'] = selected_uuid
            self._updated_profile_coverage_label()
            self._get_profile_polyline()
            self._update_plots()

    def _updated_profile_coverage_label(self):
        """Updates the coverage label with the project tree path of the selected profile coverage."""
        if not self._comp:
            return
        uuid = self._comp.data.data_dict['summary_profile_coverage_uuid']
        coverage_path = tree_util.build_tree_path(self._pe_tree, uuid)
        # Fall back to the prompt text when no path could be built (e.g. no coverage selected yet).
        self._widgets['label_coverage'].setText(coverage_path or 'select profile line coverage')

    def _get_profile_polyline(self):
        """Set the extraction polyline.

        For the cross-section tab this is the cross section line. Otherwise it is the points of the first arc in
        the selected profile coverage, or empty if there is no such coverage or arc.
        """
        self._polyline = []
        if self._is_cross_section:
            self._polyline = self._cross_section_line
        elif self._xms_data and self._comp:
            cov_uuid = self._comp.data.data_dict['summary_profile_coverage_uuid']
            coverage = self._xms_data.coverage_from_uuid(cov_uuid)
            if coverage is not None and len(coverage.arcs) > 0:
                locs = coverage.arcs[0].get_points(FilterLocation.PT_LOC_ALL)
                self._polyline = [(p.x, p.y) for p in locs]
        self._extracted_sim_data.set_extraction_polyline(self._polyline)
