"""The mapped boundary condition viewer dialog."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
import webbrowser

# 2. Third party modules
import folium
from PySide2.QtCore import Qt, QUrl
from PySide2.QtWebEngineWidgets import QWebEngineView
from PySide2.QtWidgets import (
    QDialogButtonBox, QHBoxLayout, QVBoxLayout
)

# 3. Aquaveo modules
from xms.api.dmi import XmsEnvironment as XmEnv
from xms.gdal.utilities import gdal_utils as gu
from xms.guipy.dialogs.xms_parent_dlg import XmsDlg

# 4. Local modules
from xms.adcirc.data import bc_data as bcd
from xms.adcirc.gui.line_map_viewer_dialog import get_line_extents
from xms.adcirc.mapping import mapping_util as map_util


class MappedBCViewerDlg(XmsDlg):
    """A dialog that draws mapped boundary condition arcs on a folium web map.

    Every mapped nodestring is converted to (latitude, longitude) and drawn as a polyline colored by its
    boundary condition type. The map is saved to a temporary HTML file and shown in a QWebEngineView.
    """
    def __init__(self, parent, bc_data, mapped_data, ugrid, xd):
        """Initializes the dialog, sets up the ui.

        Args:
            parent (:obj:`QWidget`): Parent dialog
            bc_data (:obj:`BcData`): Dataset containing the boundary conditions parameters
            mapped_data (:obj:`MappedBcData`): Dataset containing the mapped boundary conditions
            ugrid (UGrid): Ugrid used for mapping
            xd (XmsData): Object for communicating with XMS
        """
        super().__init__(parent, 'xms.adcirc.gui.mapped_bc_viewer_dlg')

        # Legacy matplotlib attributes; unused by this dialog but kept for any external readers.
        self.figure = None  # matplotlib.figure Figure
        self.canvas = None  # matplotlib.backends.backend_qt5agg FigureCanvas
        self.ax = None  # matplotlib Axes
        self.bc_data = bc_data
        self.mapped_data = mapped_data
        self.ugrid = ugrid

        self._xd = xd
        self._first_time = True
        self._sel_row = None

        # Polyline colors, indexed by boundary condition type.
        self._display_colors = ['gray', 'darkblue', 'darkgreen', 'green', 'cadetblue', 'darkpurple', 'lightblue',
                                'red', 'black', 'pink']
        temp_dir = XmEnv.xms_environ_process_temp_directory()
        # The process id is part of the name, presumably so separate processes do not share a file.
        self._url_file = os.path.join(temp_dir, f'adcirc_mapped_viewer_{os.getpid()}.html')
        # [[min_lat, min_lon], [max_lat, max_lon]] over all lines; starts inverted so the first line replaces it.
        self._global_extents = [[float('inf'), float('inf')], [float('-inf'), float('-inf')]]

        self._get_bc_locations()

        self.widgets = {}
        self._setup_ui()  # also sets the window title

    def _needs_axis_swap(self, wkt) -> bool:
        """Returns True if the axis order is Y,X (e.g., Northing,Easting).

        Only lines of ``wkt`` that start with ``AXIS`` (after stripping) are inspected, so this relies on
        multi-line (pretty-printed) WKT. The first two such lines are compared.

        Args:
            wkt (str): Well known text of the display projection.

        Returns:
            bool: True when the first axis is northing/latitude and the second is easting/longitude.
        """
        axis_lines = [line.strip() for line in wkt.splitlines() if line.strip().startswith("AXIS")]

        if len(axis_lines) >= 2:
            first_axis = axis_lines[0].lower()
            second_axis = axis_lines[1].lower()

            # EPSG axis order is often Y,X like ["northing", "easting"]
            if "northing" in first_axis or "latitude" in first_axis:
                if "easting" in second_axis or "longitude" in second_axis:
                    return True  # It's Y,X → needs swapping
        return False  # Already traditional (X,Y)

    def _get_lat_lon_converter(self):
        """Builds a function converting a UGrid point location to a (latitude, longitude) tuple.

        The coordinate transformation is only created for projected coordinate systems, so local and
        geographic systems never need a usable display WKT.

        Returns:
            callable: Takes a point indexable as ``pt[0]`` (x) and ``pt[1]`` (y) and returns (lat, lon).
        """
        coord_sys = self._xd.coordinate_system
        if coord_sys == 'GEOGRAPHIC':
            # Geographic x is longitude and y is latitude.
            return lambda pt: (pt[1], pt[0])

        if coord_sys in ['NONE', '']:  # Local space
            is_m = 'METER' in self._xd.horizontal_units
            converter = map_util.meters_to_decimal_degrees if is_m else map_util.feet_to_decimal_degrees
            # Put y first so the (lat, lon) order matches the geographic branch; the original code
            # emitted (x, y), which drew local-space geometry transposed.
            return lambda pt: (converter(pt[1], 0.0), converter(pt[0], 0.0))

        # Projected coordinate system: transform to EPSG:4326.
        swap_input_x_y = self._needs_axis_swap(self._xd.display_wkt)
        transform = gu.get_coordinate_transformation(self._xd.display_wkt, gu.wkt_from_epsg(4326))

        def _project(pt):
            """Transform one projected point; output components are used as (lat, lon)."""
            # NOTE(review): relies on the transform returning EPSG:4326 in authority (lat, lon) order.
            if swap_input_x_y:
                coords = transform.TransformPoint(pt[1], pt[0])
            else:
                coords = transform.TransformPoint(pt[0], pt[1])
            return coords[0], coords[1]

        return _project

    def _get_bc_locations(self):
        """Gets the (lat, lon) polyline and extents of every mapped boundary condition nodestring.

        Sets ``_num_rows``, ``_c_ids``, ``_bc_types``, ``_bc_locs`` and ``_bc_extents``, and grows
        ``_global_extents`` to cover every line.
        """
        to_lat_lon = self._get_lat_lon_converter()

        nodestrings = self.mapped_data.nodestrings
        self._num_rows = len(nodestrings.comp_id)
        self._c_ids = [int(nodestrings['comp_id'][idx].data.item()) for idx in range(self._num_rows)]
        self._bc_types = {c_id: int(self.bc_data.arcs['type'].loc[c_id].data.item()) for c_id in self._c_ids}
        self._bc_locs = []
        self._bc_extents = []

        # NOTE(review): update_viewer indexes _bc_locs by nodestring row. A levee contributes two lines
        # the first time its comp id is seen and later rows with that id are skipped, so this assumes
        # each levee spans exactly two consecutive rows — confirm.
        done_levee_c_id = set()
        for row, c_id in enumerate(self._c_ids):
            if c_id in done_levee_c_id:
                continue

            if self._bc_types[c_id] == bcd.LEVEE_INDEX:
                bc_nodes = [
                    self.mapped_data.levees['Node1 Id'].loc[c_id],
                    self.mapped_data.levees['Node2 Id'].loc[c_id],
                ]
                done_levee_c_id.add(c_id)
            else:
                start_idx = int(nodestrings.nodes_start_idx[row])
                node_count = int(nodestrings.node_count[row])
                bc_nodes = [self.mapped_data.nodes.id[start_idx:start_idx + node_count]]

            for nodestring in bc_nodes:
                # Node ids are shifted down by one to get UGrid point indices.
                locs = [to_lat_lon(self.ugrid.get_point_location(node - 1)) for node in nodestring]

                min_lat, min_lon, max_lat, max_lon = get_line_extents(locs)
                self._add_line_to_extents(min_lat, min_lon, max_lat, max_lon)

                self._bc_extents.append([[min_lat, min_lon], [max_lat, max_lon]])
                self._bc_locs.append(locs)

    def _add_line_to_extents(self, min_lat, min_lon, max_lat, max_lon):
        """Grow the global extents so they include the extents of one line.

        Args:
            min_lat (:obj:`float`): Minimum latitude of the line
            min_lon (:obj:`float`): Minimum longitude of the line
            max_lat (:obj:`float`): Maximum latitude of the line
            max_lon (:obj:`float`): Maximum longitude of the line
        """
        self._global_extents[0][0] = min(min_lat, self._global_extents[0][0])
        self._global_extents[0][1] = min(min_lon, self._global_extents[0][1])
        self._global_extents[1][0] = max(max_lat, self._global_extents[1][0])
        self._global_extents[1][1] = max(max_lon, self._global_extents[1][1])

    def _get_extents_of_set(self, list_extents):
        """Combine several flat extents into one bounding box. Does not touch the global extents.

        Args:
            list_extents (list): Non-empty list of ``[min_lat, min_lon, max_lat, max_lon]`` entries

        Returns:
            list: ``[[min_lat, min_lon], [max_lat, max_lon]]`` covering every entry

        Raises:
            IndexError: If ``list_extents`` is empty.
        """
        _ = list_extents[0]  # keep raising IndexError for an empty list
        min_lat = min(ext[0] for ext in list_extents)
        min_lon = min(ext[1] for ext in list_extents)
        max_lat = max(ext[2] for ext in list_extents)
        max_lon = max(ext[3] for ext in list_extents)
        return [[min_lat, min_lon], [max_lat, max_lon]]

    def update_viewer(self, sel_display_id):
        """Redraw the bc arcs, highlighting the selected one.

        Args:
            sel_display_id (int): The selected row in the spreadsheet, or None to zoom to all arcs
        """
        if not self._first_time and sel_display_id == self._sel_row:
            return  # Selection unchanged since the last render

        self._sel_row = sel_display_id

        folium_map = folium.Map(tiles='', control_scale=True)

        url = 'http://server.arcgisonline.com/ArcGIS/rest/services/World_Street_Map/MapServer/tile/{z}/{y}/{x}'
        tl = folium.raster_layers.TileLayer(tiles=url, name='ESRI Street Map', attr='ESRI')
        tl.add_to(folium_map)
        url = 'http://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}'
        tl = folium.raster_layers.TileLayer(tiles=url, name='ESRI World Imagery', attr='ESRI', opacity=0.2)
        tl.add_to(folium_map)

        for row in range(self._num_rows):
            line = self._bc_locs[row]
            if len(line) < 2:
                continue  # A polyline needs at least two points
            bc_type = self._bc_types[self._c_ids[row]]
            folium.vector_layers.PolyLine(locations=line, color=self._display_colors[bc_type],
                                          weight=5 if self._sel_row == row else 2,
                                          tooltip=folium.Tooltip(f'{row}', sticky=True)).add_to(folium_map)

        # With no lines the global extents are still +/-inf, which would produce invalid map bounds.
        if self._bc_extents:
            bounds = self._global_extents if sel_display_id is None else self._bc_extents[sel_display_id]
            folium_map.fit_bounds(bounds)
        folium_map.save(self._url_file)
        self._viewer.setUrl(QUrl.fromLocalFile(self._url_file))
        self._viewer.show()

        self._first_time = False

    def _setup_ui(self):
        """Sets up the dialog controls: the web view followed by the button box."""
        self.setWindowTitle('Boundary Conditions Viewer')
        self.setMinimumSize(300, 300)
        self.widgets['vert_layout'] = QVBoxLayout()
        self._viewer = QWebEngineView(self)
        self.widgets['plot_view'] = self._viewer
        self.widgets['vert_layout'].addWidget(self.widgets['plot_view'])

        # add the close and help buttons at the bottom of the dialog
        self._setup_ui_bottom_button_box()

        self.setLayout(self.widgets['vert_layout'])

    def _setup_ui_bottom_button_box(self):
        """Add the Close and Help buttons to the bottom of the dialog."""
        self.widgets['btn_horiz_layout'] = QHBoxLayout()
        self.widgets['btn_box'] = QDialogButtonBox()
        self.widgets['btn_box'].setOrientation(Qt.Horizontal)
        self.widgets['btn_box'].setStandardButtons(QDialogButtonBox.Close | QDialogButtonBox.Help)
        self.widgets['btn_box'].accepted.connect(self.accept)
        self.widgets['btn_box'].rejected.connect(self.reject)
        self.widgets['btn_box'].helpRequested.connect(self._help_requested)
        self.widgets['btn_horiz_layout'].addWidget(self.widgets['btn_box'])
        self.widgets['vert_layout'].addLayout(self.widgets['btn_horiz_layout'])

    def _help_requested(self):  # pragma: no cover
        """Called when the Help button is clicked; opens ``help_url`` if one has been set."""
        # help_url is not assigned in this class, so avoid an AttributeError when no page is configured.
        help_url = getattr(self, 'help_url', None)
        if help_url:
            webbrowser.open(help_url)
