"""Retrieve data needed by the Manning's N dialog from XMS."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os

# 2. Third party modules
from PySide2.QtWidgets import QDialog
from shapely.geometry import LineString

# 3. Aquaveo modules
import xms.api._xmsapi.dmi as xmd
from xms.api.tree import tree_util
from xms.constraint import read_grid_from_file
from xms.data_objects.parameters import FilterLocation
from xms.datasets.dataset_reader import DatasetReader
from xms.extractor import UGrid2dDataExtractor
from xms.extractor.ugrid_2d_polyline_data_extractor import UGrid2dPolylineDataExtractor
from xms.gdal.rasters.raster_input import RasterInput
from xms.guipy.dialogs.dataset_selector import DatasetSelector
from xms.guipy.dialogs.message_box import message_with_ok, message_with_ok_cancel
from xms.guipy.dialogs.treeitem_selector import TreeItemSelectorDlg
from xms.snap import SnapExteriorArc

# 4. Local modules


class XmsGetter:
    """Class to get Manning n dialog data from XMS. Can be overridden for testing."""
    def __init__(
        self, query, cov_uuid, sel_arc_ids, elevations=None, stations=None, dataset_name=None, dataset_uuid=''
    ):
        """Initializes the class.

        Args:
            query (:obj:`Query`): XMS interprocess communication object
            cov_uuid (:obj:`str`): The coverage UUID
            sel_arc_ids (:obj:`list[int]`): The selected arc IDs
            elevations (:obj:`list[float]`): Cross section elevations
            stations (:obj:`list[float]`): Cross section stations
            dataset_name (:obj:`str`): Name of the selected Manning's N dataset
            dataset_uuid (:obj:`str`): UUID of the selected Manning's N dataset
        """
        self.query = query
        self.cov_uuid = cov_uuid
        self.sel_arc_ids = sel_arc_ids
        self.elevations = elevations
        self.stations = stations
        self.dataset_name = dataset_name
        self.dataset_uuid = dataset_uuid
        # No-data value; overwritten by the dataset's null value or the raster's nodata value when available.
        self.null_value = -999.0
        self.raster_input = None  # RasterInput when the selected item is not a DatasetReader, else None
        self.co_grid = None  # Grid the selected dataset lives on, else None
        self.ts_data = None  # Values of the dataset's last time step
        self.mesh_uuid = None  # UUID of the geometry the dataset lives on
        self._pe_tree = None  # Filtered project explorer tree, built lazily by the pe_tree property
        self.extractor = None  # Cached data extractor for the current grid/dataset
        self.snap = None  # Cached SnapExteriorArc for the current grid
        self.parent_dlg = None  # Parent for message boxes; when None, node-ID gaps raise instead of prompting
        self._scatter_selected = False  # True when the dataset's geometry node is a 2D scatter
        self.check_geometry_node_id_gaps = False  # When True, reject grids whose max node ID != point count

    @property
    def pe_tree(self):
        """Returns the project explorer tree, building it on first access.

        Returns:
            See description.
        """
        if self._pe_tree is None:
            self._setup_geometry_tree()
        return self._pe_tree

    @pe_tree.setter
    def pe_tree(self, value):
        """Sets the project explorer tree.

        Args:
            value: The project explorer tree.
        """
        self._pe_tree = value

    def select_dataset(self, parent=None, allow_rasters=False):
        """Get a scalar dataset. On success sets self.stations, self.elevations and self.dataset_name.

        When self.query is None the previously set values can be used.

        Args:
            parent (:obj:`QObject`): The parent object.
            allow_rasters (:obj:`bool`): Allow rasters.

        Returns:
            (:obj:`bool`): True on success
        """
        if not self.query:
            return bool(self.elevations and self.stations and self.dataset_name)

        self._scatter_selected = False
        dset_uuid = self._select_dataset_with_dialog(parent, allow_rasters)
        if not dset_uuid:
            return False
        self.dataset_uuid = dset_uuid
        return self.retrieve_data()

    def retrieve_data(self, parent=None):
        """Get data for the first selected arc. On success sets self.stations, self.elevations and self.dataset_name.

        Args:
            parent (:obj:`QObject`): The parent object (currently unused).

        Returns:
            (:obj:`bool`): True on success
        """
        if not self.sel_arc_ids:
            return False
        if not self._retrieve_xms_dataset_and_geometry(''):
            return False

        arc = self._find_arc(self.sel_arc_ids[0])
        if self.co_grid is None:
            # Not a grid-based dataset, so sample the raster along the arc.
            return self._retrieve_raster_data_for_arc(arc)
        # Meshes: snap the arc to the mesh exterior. Scatters: sample along the arc with the polyline extractor.
        use_poly_line_extractor = bool(self._scatter_selected)
        return self._retrieve_xms_data_for_arc(arc, not use_poly_line_extractor, use_poly_line_extractor)

    def retrieve_xms_data(self, geom_uuid, snap_to_exterior, use_poly_line_extractor):
        """Use xms.sapi to get data for every arc in the coverage.

        Args:
            geom_uuid (:obj:`str`): The uuid of the geometry.
            snap_to_exterior (:obj:`bool`): Whether to snap the arc points to the exterior of the mesh
            use_poly_line_extractor (:obj:`bool`): True if using the polyline extractor

        Returns:
            (:obj:`dict`): Arc ID -> (elevations, stations, dataset_name) for each arc that succeeded.
        """
        data = {}
        if not self._retrieve_xms_dataset_and_geometry(geom_uuid):
            return data
        coverage = self.query.item_with_uuid(self.cov_uuid)
        if coverage is None:
            return data
        for arc in coverage.arcs:
            if self._retrieve_xms_data_for_arc(arc, snap_to_exterior, use_poly_line_extractor):
                data[arc.id] = self.elevations, self.stations, self.dataset_name
        return data

    def retrieve_xms_data_for_selected_arc(self, geom_uuid):
        """Use xms.sapi to get data for the first selected arc, snapped to the mesh exterior.

        Args:
            geom_uuid (:obj:`str`): The uuid of the geometry.

        Returns:
            (:obj:`bool`): True on success
        """
        if not self.sel_arc_ids:
            return False
        if not self._retrieve_xms_dataset_and_geometry(geom_uuid):
            return False
        return self._retrieve_xms_data_for_arc(self._find_arc(self.sel_arc_ids[0]), True, False)

    def retrieve_xms_data_for_points(self, points, use_poly_line_extractor):
        """Get ts_data on the given points from the co_grid and set self.elevations and self.stations.

        Args:
            points (:obj:`list`): A list of points.
            use_poly_line_extractor (:obj:`bool`): True if using the polyline extractor

        Returns:
            (:obj:`bool`): True on success
        """
        if not self.co_grid or self.ts_data is None or len(self.ts_data) == 0:
            return False
        if points is None or len(points) == 0:
            return False  # Nothing to extract (e.g. snapping produced no points)

        self.elevations, locations = self._compute_elevations(points, use_poly_line_extractor)
        self.stations = self._compute_stations(locations)
        return True

    def _redistribute_vertices(self, geom, num_segments=100):
        """Resample points evenly along a linear geometry, including both end points.

        Args:
            geom (:obj:`BaseGeometry`): Linear shapely geometry (e.g. LineString) to be resampled
            num_segments (:obj:`int`): Number of equal-length pieces; num_segments + 1 points are returned

        Returns:
            (:obj:`list`): list of points
        """
        # Using range(num_segments + 1) so the normalized fraction reaches 1.0 and the arc's end is sampled.
        return [geom.interpolate(float(i) / num_segments, normalized=True) for i in range(num_segments + 1)]

    def retrieve_raster_data_for_points(self, points):
        """Get elevation data on the given points from the raster and set self.elevations and self.stations.

        The points are resampled along the polyline they define, and ``points`` is replaced in place with the
        resampled ``[x, y]`` coordinates.

        Args:
            points (:obj:`list`): A list of points (objects with ``x`` and ``y`` attributes).

        Returns:
            (:obj:`bool`): True on success
        """
        if self.raster_input is None:
            return False

        coords = [[pt.x, pt.y] for pt in points]
        if len(coords) < 2:
            return False  # A LineString needs at least two points

        resampled = self._redistribute_vertices(LineString(coords))

        points.clear()
        points.extend([pt.x, pt.y] for pt in resampled)

        self.elevations = self._get_elevations_raster(resampled)
        self.stations = self._compute_stations(points)
        if self.dataset_uuid and self.query:
            dataset = self.query.item_with_uuid(self.dataset_uuid)
            if dataset:
                self.dataset_name = os.path.basename(dataset)

        return bool(self.elevations and self.stations and self.dataset_name)

    def _retrieve_xms_dataset_and_geometry(self, geom_uuid):
        """Use xmsapi to get the dataset and geometry.

        Args:
            geom_uuid (:obj:`str`): The uuid of the mesh geometry or None if using the dataset's geometry.

        Returns:
            (:obj:`bool`): True on success
        """
        # Reset everything derived from a previous dataset/geometry so stale state cannot leak into this one.
        self.raster_input = None
        self.co_grid = None
        self.snap = None
        self.extractor = None
        self._scatter_selected = False

        dataset = self.query.item_with_uuid(self.dataset_uuid)
        if not dataset:
            return False
        is_dset = isinstance(dataset, DatasetReader)

        if geom_uuid:
            self.mesh_uuid = geom_uuid
        elif is_dset:
            self.mesh_uuid = dataset.geom_uuid

        if not is_dset:
            # Anything that is not a DatasetReader is treated as a raster.
            self.mesh_uuid = None
            self.raster_input = RasterInput(dataset)
            self.null_value = self.raster_input.nodata_value
            return True

        dset_node = tree_util.find_tree_node_by_uuid(self.pe_tree, self.dataset_uuid)
        if dset_node is not None:
            self.dataset_name = dset_node.name

        geom_node = tree_util.find_tree_node_by_uuid(self.pe_tree, self.mesh_uuid)
        if geom_node is not None:
            if geom_node.item_typename == 'TI_SCAT2D':
                self._scatter_selected = True
            if self.parent_dlg and geom_node.num_points > 500000:
                msg = (
                    'The mesh/scatter associated with the selected dataset has more than '
                    '500,000 points. This may take significant time to process and get cross '
                    'section elevations.\nIt is recommended that you trim or filter this data '
                    'to have fewer than 500,000 points.\n\n Continue?'
                )
                app_name = os.environ.get('XMS_PYTHON_APP_NAME')
                if not message_with_ok_cancel(self.parent_dlg, msg, app_name):
                    return False

        if dataset.null_value is not None:
            self.null_value = dataset.null_value
        self.ts_data = dataset.values[-1]  # Assuming the last time step
        self.co_grid = self._get_co_grid(self.mesh_uuid)
        return self.co_grid is not None

    def _get_co_grid(self, geom_uuid):
        """Get a CoGrid from XMS.

        Args:
            geom_uuid (:obj:`str`): UUID of the desired grid

        Returns:
            (:obj:`xms.constraint.Grid`): The requested grid, or None on failure

        Raises:
            RuntimeError: If node-ID gap checking is enabled, gaps are found and there is no parent dialog.
        """
        grid_data = self.query.item_with_uuid(geom_uuid)
        if grid_data is None:
            return None
        co_grid = read_grid_from_file(grid_data.cogrid_file)
        if self.check_geometry_node_id_gaps:
            max_node_id = grid_data._instance.GetMaxNodeId()
            if co_grid.ugrid.point_count != max_node_id:
                msg = 'Gaps detected in the node numbers of the mesh. Renumber the mesh nodes to use ' \
                      'the mesh for elevations.'
                if not self.parent_dlg:
                    raise RuntimeError(msg)
                message_with_ok(self.parent_dlg, msg, os.environ.get('XMS_PYTHON_APP_NAME'))
                return None
        return co_grid

    def _get_arc_locations(self, arc, snap_to_exterior):
        """Get the arc locations.

        Args:
            arc (:obj:`xms.data_objects.parameters.Arc`): The arc.
            snap_to_exterior (:obj:`bool`): Whether to snap the arc points to the exterior of the mesh.

        Returns:
            locations: A list of locations, either snapped to exterior of the mesh or along the given arc
        """
        if not arc:
            return []
        if not snap_to_exterior:
            return [(pt.x, pt.y, 0.0) for pt in arc.get_points(FilterLocation.PT_LOC_ALL)]
        if not self.co_grid:
            return []
        return self._compute_snapped_points(arc)

    def _retrieve_xms_data_for_arc(self, arc, snap_to_exterior, use_poly_line_extractor):
        """Use xms.sapi to get selected data.

        Args:
            arc (:obj:`xms.data_objects.parameters.Arc`): The arc.
            snap_to_exterior (:obj:`bool`): Whether to snap the arc points to the exterior of the mesh.
            use_poly_line_extractor (:obj:`bool`): True if using the polyline extractor

        Returns:
            (:obj:`bool`): True on success
        """
        if not arc:
            return False
        locations = self._get_arc_locations(arc, snap_to_exterior)
        return self.retrieve_xms_data_for_points(locations, use_poly_line_extractor)

    def _retrieve_raster_data_for_arc(self, arc):
        """Sample the raster along the arc.

        Args:
            arc (:obj:`xms.data_objects.parameters.Arc`): The arc.

        Returns:
            (:obj:`bool`): True on success
        """
        if not arc:
            return False
        # Copy to a list: retrieve_raster_data_for_points rewrites its argument in place with clear()/extend().
        arc_pts = list(arc.get_points(FilterLocation.PT_LOC_ALL))
        return self.retrieve_raster_data_for_points(arc_pts)

    def _select_dataset_with_dialog(self, parent, allow_rasters):
        """Let the user pick a dataset (or raster) from the project explorer tree.

        Args:
            parent (:obj:`QObject`): The parent object.
            allow_rasters (:obj:`bool`): Allow rasters.

        Returns:
            (:obj:`str`): The selected item's UUID, or '' if the dialog was cancelled.
        """
        tree_types = None
        target_type = xmd.DatasetItem
        if allow_rasters:
            tree_types = ['TI_SFUNC', 'TI_SCALAR_DSET', 'TI_IMAGE']
            target_type = type(None)

        dialog = TreeItemSelectorDlg(
            title='Select Dataset',
            target_type=target_type,
            pe_tree=self.pe_tree,
            previous_selection=self.dataset_uuid,
            show_root=True,
            parent=parent,
            selectable_xms_types=tree_types
        )
        if dialog.exec() == QDialog.DialogCode.Accepted:
            return dialog.get_selected_item_uuid()
        return ''

    def _setup_geometry_tree(self):
        """Get the geometry tree for use when selecting the dataset."""
        if not self.query:
            return
        # Work on a local and assign _pe_tree directly: going through the pe_tree property here would
        # re-enter this method (infinite recursion) whenever copy_project_tree() returns None.
        tree = self.query.copy_project_tree()
        if tree is not None:
            # Only keep the 2D mesh, 2D scatter and GIS roots.
            geom_root_types = ('TI_ROOT_2DMESH', 'TI_ROOT_2DSCAT', 'TI_ROOT_GIS')
            tree.children = [child for child in tree.children if child.item_typename in geom_root_types]
            tree_util.filter_project_explorer(tree, DatasetSelector.is_scalar_if_dset)
        self._pe_tree = tree

    @staticmethod
    def _compute_stations(locations):
        """Compute cumulative planar (X, Y) distances along the given locations.

        Args:
            locations (:obj:`list[tuple]`): The locations; only the first two components (x, y) are used.

        Returns:
            (:obj:`list[float]`): The computed stations, one per location (empty for empty input).
        """
        stations = []
        distance = 0.0
        previous = None
        for pt in locations:
            if previous is not None:
                distance += ((previous[0] - pt[0])**2 + (previous[1] - pt[1])**2)**0.5
            stations.append(distance)
            previous = pt
        return stations

    def _compute_elevations(self, locations, use_poly_line_extractor):
        """Compute dataset values at the given locations on the grid.

        Args:
            locations (:obj:`list[tuple]`): The X, Y locations of the points along the arc.
            use_poly_line_extractor (:obj:`bool`): True if using the polyline extractor

        Returns:
            (:obj:`tuple(list)`): Elevations for each extract location, and the extract locations.
        """
        # A cached extractor of the other kind exposes a different API, so it cannot be reused.
        if self.extractor and \
                isinstance(self.extractor, UGrid2dPolylineDataExtractor) != bool(use_poly_line_extractor):
            self.extractor = None

        if not self.extractor:
            ugrid = self.co_grid.ugrid
            point_activity = self.ts_data != self.null_value
            if use_poly_line_extractor:
                self.extractor = UGrid2dPolylineDataExtractor(ugrid=ugrid, scalar_location='points')
                self.extractor.no_data_value = self.null_value
                self.extractor.set_grid_scalars(self.ts_data, point_activity, 'points')
            else:
                self.extractor = UGrid2dDataExtractor(ugrid=ugrid)
                self.extractor.no_data_value = self.null_value
                self.extractor.set_grid_point_scalars(self.ts_data, point_activity, 'points')

        if use_poly_line_extractor:
            self.extractor.set_polyline(locations)  # set the line segment on the data extractor
            locations = self.extractor.extract_locations  # get locations where segment intersects ugrid
        else:
            self.extractor.extract_locations = locations
        elevations = list(self.extractor.extract_data())
        return elevations, locations

    def _get_elevations_raster(self, locations):
        """Compute elevations from raster for an arc.

        Args:
            locations (:obj:`list`): The points along the arc (objects with ``x`` and ``y`` attributes).

        Returns:
            (:obj:`list`): Elevations for each arc point; self.null_value where the raster returned nothing.
        """
        if self.raster_input is None:
            return []

        elevations = []
        for pt in locations:
            column, row = self.raster_input.coordinate_to_pixel(pt.x, pt.y)
            values = self.raster_input.get_raster_values(column, row, 1, 1)
            elevations.append(self.null_value if values is None else values[0][0])
        return elevations

    def _compute_snapped_points(self, arc):
        """Compute the snapped points for the arc along the grid.

        Args:
            arc (:obj:`xms.data_objects.parameters.Arc`): The arc.

        Returns:
            (:obj:`list[tuple]`): The points of the arc (as x,y,z tuples) snapped to the exterior of the grid.
        """
        if not self.snap:
            self.snap = SnapExteriorArc()
            self.snap.set_grid(grid=self.co_grid, target_cells=False)
        return self.snap.get_snapped_points(arc)['location']

    def _find_arc(self, arc_id):
        """Find arc with given ID in coverage from XMS.

        Args:
            arc_id: The arc ID.

        Returns:
            (:obj:`xms.data_objects.parameters.Arc`): The arc, or None if it (or the coverage) is not found.
        """
        coverage = self.query.item_with_uuid(self.cov_uuid)
        if coverage is None:
            return None
        return next((arc for arc in coverage.arcs if arc.id == arc_id), None)
