"""DataHandler class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
from pathlib import Path
from typing import List, Optional
import uuid

# 2. Third party modules
from geopandas import GeoDataFrame
from shapely import LineString, Point, Polygon

# 3. Aquaveo modules
import xms.api._xmsapi.dmi as xmd
from xms.api.tree import tree_util
from xms.constraint import read_grid_from_file
from xms.core.filesystem import filesystem
from xms.data_objects.parameters import Coverage, FilterLocation, Projection, UGrid
from xms.datasets.dataset_reader import DatasetReader
from xms.gdal.rasters import raster_utils, RasterInput
from xms.tool_core import DataHandler  # noqa I100,I201
from xms.tool_core.coverage_argument import ALLOW_ONLY_COVERAGE_TYPE, ALLOW_ONLY_MODEL_NAME
from xms.tool_core.coverage_reader import CoverageReader
from xms.tool_core.coverage_writer import CoverageWriter
from xms.tool_core.dataset_filters import (
    ALLOW_ONLY_CELL_MAPPED, ALLOW_ONLY_POINT_MAPPED, ALLOW_ONLY_SCALARS, ALLOW_ONLY_TRANSIENT, ALLOW_ONLY_VECTORS
)

# 4. Local modules


class XmsDataHandler(DataHandler):
    """Class to get and send tool data for arguments back and forth to XMS.

    Input items (grids, datasets, rasters, coverages) are discovered lazily from the XMS project tree and cached
    by their tree path (the path below the project node). Output items are queued and pushed to XMS by
    send_output_to_xms().
    """
    def __init__(self, query, project_tree=None):
        """Construct an XmsDataHandler object.

        Args:
            query: The XMS query object used to look up items and to send output back to XMS.
            project_tree: The XMS project tree. If None, ``query.project_tree`` is used.
        """
        super().__init__()
        self._query = query
        # Lazily-populated caches, keyed by tree path (see _get_tree_path).
        self._coverages = None  # {tree_path: coverage tree item}
        self._grids = None  # {tree_path: UGrid tree item}
        self._datasets = None  # {tree_path: dataset tree item}
        self._rasters = None  # {tree_path: raster UUID}
        # Outputs queued for send_output_to_xms().
        self._new_coverages = []
        self._new_grids = []
        self._new_dsets = []
        self._new_rasters = []  # output raster file paths
        self._uses_file_system = False
        self.output_folder_paths = {}  # no longer used by xmstool >= 6.4.8

        # Cache input datasets by UUID
        self._input_dataset_files = {}  # {uuid: (filename, group_path)}

        if project_tree is None:
            self._project_tree = self._query.project_tree
        else:
            self._project_tree = project_tree

    def get_available_grids(self):
        """Get a list of the available grids.

        Returns:
            (list[str]): A list of the available grids (tree paths).
        """
        self._load_grids()
        return list(self._grids.keys())

    def get_available_datasets(self, filters=None):
        """Returns a list of the available datasets.

        Args:
            filters (Optional[list]): The filters to apply to the returned list.

        Returns:
            (list[str]): A list of the available datasets (tree paths).
        """
        if self._datasets is None:
            self._load_datasets()

        datasets = []
        for dset_name, tree_item in self._datasets.items():
            included = True
            if filters is not None:
                # Any failing filter excludes the dataset.
                if ALLOW_ONLY_TRANSIENT in filters and tree_item.num_times < 2:
                    included = False
                elif ALLOW_ONLY_SCALARS in filters and tree_item.num_components != 1:
                    included = False
                elif ALLOW_ONLY_VECTORS in filters and tree_item.num_components < 2:
                    included = False
                elif ALLOW_ONLY_CELL_MAPPED in filters and tree_item.data_location != 'CELL':
                    included = False
                elif ALLOW_ONLY_POINT_MAPPED in filters and tree_item.data_location != 'NODE':
                    included = False
            if included:
                datasets.append(dset_name)
        return datasets

    def get_available_rasters(self):
        """Get a list of the available rasters.

        Returns:
            (list[str]): A list of the available rasters (tree paths).
        """
        if self._rasters is None:
            self._load_rasters()
        return list(self._rasters.keys())

    def get_available_coverages(self, filters=None) -> List[str]:
        """Get a list of the available coverages.

        Args:
            filters (Optional[Dict[str, str]]): The filters to apply to the returned list.

        Returns:
            (List[str]): A sorted list of the available coverages (tree paths).
        """
        if self._coverages is None:
            self._load_coverages()

        coverages = []
        for tree_path, tree_item in self._coverages.items():
            included = True
            if filters is not None:
                if ALLOW_ONLY_MODEL_NAME in filters and tree_item.model_name != filters[ALLOW_ONLY_MODEL_NAME]:
                    included = False
                elif ALLOW_ONLY_COVERAGE_TYPE in filters and \
                        tree_item.coverage_type != filters[ALLOW_ONLY_COVERAGE_TYPE]:
                    included = False
            if included:
                coverages.append(tree_path)
        coverages.sort()
        return coverages

    def get_input_grid(self, tree_path):
        """Get a grid chosen from the list of available grids (self.available_grids).

        Args:
            tree_path (str): Tree item path of the grid to retrieve.

        Returns:
            (xms.constraint.Grid): The grid object.
        """
        self._load_grids()  # ensure grids are loaded
        file_name = self._get_grid_filename(tree_path)
        return read_grid_from_file(file_name)

    def get_grid_name_from_uuid(self, grid_uuid: str) -> Optional[str]:
        """Get an input grid name from a UUID.

        Args:
            grid_uuid: The grid UUID.

        Returns:
            The grid name or None.
        """
        return self._get_path_from_uuid(grid_uuid, xmd.UGridItem)

    def get_uuid_from_grid_name(self, grid_name: str) -> Optional[str]:
        """Get a UUID from an input grid name.

        Args:
            grid_name: The grid name.

        Returns:
            The grid UUID or None.
        """
        return self._get_uuid_from_path(grid_name, xmd.UGridItem)

    def set_output_grid(self, grid, grid_name, projection, force_ugrid=True) -> None:
        """Set output grid for argument.

        Args:
            grid (Grid): The output grid.
            grid_name (str): The name of the new grid. For testing the grid name is the output
                file path in the "files" folder.
            projection (ProjectionOrStr): The WKT or data_objects Projection of the
                grid. If None, the XMS display projection is used.
            force_ugrid (bool): Whether to save the output as an UGrid in SMS.
        """
        if projection is None:  # Use the display projection if not specified
            projection = self._query.display_projection
        elif isinstance(projection, str):  # Allow WKT or data_object
            projection = Projection(wkt=projection)
        temp_filename = filesystem.temp_filename()
        grid.write_to_file(temp_filename, True)
        # Keep the grid's own UUID when it has one; otherwise mint a new one.
        ug_uuid = grid.uuid if grid.uuid else str(uuid.uuid4())
        ugrid = UGrid(temp_filename, name=grid_name, uuid=ug_uuid, projection=projection)
        ugrid.force_ugrid = force_ugrid
        self._new_grids.append(ugrid)

    def _load_grids(self):
        """Populate self._grids ({tree path: UGrid tree item}) if not already loaded."""
        if self._grids is not None:
            return

        grid_items = tree_util.descendants_of_type(self._project_tree, xmd.UGridItem)
        self._grids = {_get_tree_path(grid): grid for grid in grid_items}

    def _get_grid_filename(self, grid_name, grid_uuid=None):
        """Returns the filename where the grid is saved.

        Args:
            grid_name (str): The grid name (tree path). Requires self._grids to be loaded when grid_uuid is None.
            grid_uuid (str): UUID of the grid. If provided grid_name is ignored.

        Returns:
            (str): The path to the CoGrid file, or '' if XMS has no item with the UUID.
        """
        ugrid_uuid = grid_uuid if grid_uuid is not None else self._grids[grid_name].uuid
        ugrid_item = self._query.item_with_uuid(ugrid_uuid)
        return '' if ugrid_item is None else ugrid_item.cogrid_file

    def get_input_raster(self, raster_name) -> RasterInput:
        """Get a raster chosen from the list of available rasters (self.available_rasters).

        Args:
            raster_name (str): The raster to retrieve.

        Returns:
            (RasterInput): The raster object.
        """
        file_name = self.get_input_raster_file(raster_name)
        return RasterInput(file_name)

    def get_input_raster_file(self, raster_name):
        """Get the file of a raster chosen from the list of available rasters (self.available_rasters).

        Args:
            raster_name (str): The raster to retrieve.

        Returns:
            (str): The raster filename
        """
        return self._get_raster_filename(raster_name)

    def get_raster_name_from_uuid(self, raster_uuid: str) -> Optional[str]:
        """Get an input raster name from a UUID.

        Args:
            raster_uuid: The raster UUID.

        Returns:
            The input raster name or None.
        """
        if self._rasters is None:
            self._load_rasters()
        for item_name, item_uuid in self._rasters.items():
            if item_uuid == raster_uuid:
                return item_name
        return None

    def get_uuid_from_raster_name(self, raster_name: str) -> Optional[str]:
        """Get a UUID from a raster name.

        Args:
            raster_name: The raster name.

        Returns:
            The UUID or None.
        """
        if self._rasters is None:
            self._load_rasters()
        return self._rasters.get(raster_name)

    def get_output_raster(self, raster_name, raster_extension: str = '.tif') -> str:
        """Get the path to write output raster to, and queue it for sending to XMS.

        Args:
            raster_name (str): The raster name.
            raster_extension: The file extension to add to the raster_name.

        Returns:
            (str): The raster file path (inside the XMS_PYTHON_APP_TEMP_DIRECTORY folder).
        """
        xms_temp_folder = os.environ.get('XMS_PYTHON_APP_TEMP_DIRECTORY', 'unknown')
        output_filename = f'{raster_name}{raster_extension}'
        output_path = os.path.join(xms_temp_folder, output_filename)
        if output_path not in self._new_rasters:
            self._new_rasters.append(output_path)
        return output_path

    def set_output_raster_file(self, raster_file, raster_name, raster_format: str = 'GTiff',
                               raster_extension: str = '.tif'):
        """Set output raster for argument.

        If raster_file was obtained from get_output_raster it is already queued; otherwise it is copied to the
        output location for raster_name.

        Args:
            raster_file (str): The output raster file.
            raster_name (str): The output raster name in the project explorer.
            raster_format: The short name of the raster format string.  See https://gdal.org/drivers/raster/index.html
                for a list of possible strings.
            raster_extension: The extension to add to the raster_name.
        """
        if raster_file not in self._new_rasters:
            copy_to_path = self.get_output_raster(raster_name, raster_extension)
            if copy_to_path != raster_file:
                raster_utils.copy_raster(raster_file, copy_to_path, raster_format)

    def _get_raster_filename(self, raster_name):
        """Returns the filename of the raster.

        Args:
            raster_name (str): The raster name (tree path).

        Returns:
            (str): The path to the raster file, as returned by the XMS query for the raster's UUID.
        """
        if self._rasters is None:
            self._load_rasters()
        raster_uuid = self._rasters[raster_name]
        return self._query.item_with_uuid(raster_uuid)

    def _load_rasters(self):
        """Populate self._rasters ({tree path: raster UUID}) with rasters XMS can resolve."""
        raster_items = tree_util.descendants_of_type(self._project_tree, xms_types=['TI_IMAGE'])
        self._rasters = {}
        for raster in raster_items:
            # Skip rasters XMS cannot resolve by UUID.
            if self._query.item_with_uuid(raster.uuid):
                self._rasters[_get_tree_path(raster)] = raster.uuid

    def get_input_dataset(self, dataset_name):
        """Get a dataset chosen from the list of available datasets (self.available_datasets).

        Args:
            dataset_name (str): The dataset to retrieve.

        Returns:
            (xms.datasets.dataset_reader.DatasetReader): The dataset object, or None if it could not be read.
        """
        try:
            return self._get_dataset(dataset_name)
        except Exception:
            # Deliberate best-effort: any lookup/read failure is reported as "no dataset".
            return None

    def get_input_dataset_grid(self, dataset_name):
        """Get the geometry of a dataset chosen from the list of available datasets.

        Available datasets come from DataHandler.available_datasets.

        Args:
            dataset_name (str): The dataset to retrieve.

        Returns:
            (xms.constraint.Grid): The dataset's constrained grid.
        """
        if self._datasets is None:
            self._load_datasets()
        dset_item = self._datasets[dataset_name]
        ugrid_item = tree_util.ancestor_of_type(dset_item, xmd.UGridItem)
        grid_filename = self._get_grid_filename(grid_name='', grid_uuid=ugrid_item.uuid)
        return read_grid_from_file(grid_filename)

    def get_dataset_name_from_uuid(self, dataset_uuid: str) -> Optional[str]:
        """Get a dataset name from a UUID.

        Args:
            dataset_uuid: The dataset UUID.

        Returns:
            The dataset name or None.
        """
        return self._get_path_from_uuid(dataset_uuid, xmd.DatasetItem)

    def get_uuid_from_dataset_name(self, dataset_name: str) -> Optional[str]:
        """Get a UUID from a dataset name.

        Args:
            dataset_name: The dataset name.

        Returns:
            The dataset UUID or None.
        """
        return self._get_uuid_from_path(dataset_name, xmd.DatasetItem)

    def set_output_dataset(self, dataset):
        """Set output dataset for argument.

        Args:
            dataset (xms.datasets.dataset_writer.DatasetWriter): The output dataset.
        """
        self._new_dsets.append(dataset)

    def _get_dataset(self, dataset_name):
        """Get a dataset reader for the named dataset.

        Args:
            dataset_name (str): The dataset name.

        Returns:
            (xms.datasets.dataset_reader.DatasetReader): The dataset.

        Raises:
            KeyError: If the dataset is unknown or XMS cannot resolve its UUID.
        """
        if self._datasets is None:
            self._load_datasets()
        dataset_uuid = self._datasets[dataset_name].uuid
        if dataset_uuid and dataset_uuid not in self._input_dataset_files:
            # Cache the dataset's filename and H5 group path to avoid redundant XMS queries.
            dataset_item = self._query.item_with_uuid(dataset_uuid)
            if dataset_item is not None:
                self._input_dataset_files[dataset_uuid] = (dataset_item.h5_filename, dataset_item.group_path)
        filename, group_path = self._input_dataset_files[dataset_uuid]
        return DatasetReader(h5_filename=filename, group_path=group_path)

    def _load_datasets(self):
        """Populate self._datasets ({tree path: dataset tree item}).

        When running in GMS (XMS_PYTHON_APP_NAME == 'GMS'), only datasets that live under a UGrid are included.
        """
        app_name = os.environ.get('XMS_PYTHON_APP_NAME')
        dataset_items = tree_util.descendants_of_type(self._project_tree, xmd.DatasetItem)
        self._datasets = {}
        for dataset in dataset_items:
            path = _get_tree_path(dataset)
            if app_name != 'GMS' or tree_util.ancestor_of_type(dataset, xmd.UGridItem):
                self._datasets[path] = dataset

    def get_input_coverage(self, coverage_name: str) -> GeoDataFrame | None:
        """Get a coverage chosen from the list of available coverages (self.available_coverages).

        Args:
            coverage_name (str): The coverage to retrieve.

        Returns:
            (GeoDataFrame): The coverage converted to a GeoDataFrame, or None if the coverage is unknown.
        """
        if self._coverages is None:
            self._load_coverages()
        if coverage_name not in self._coverages:
            return None
        coverage_uuid = self._coverages[coverage_name].uuid
        coverage = self._query.item_with_uuid(coverage_uuid)
        return convert_to_geodataframe(coverage, self._query.display_projection.well_known_text)

    def get_input_coverage_file(self, coverage_name: str) -> str:
        """Get a coverage file chosen from the list of available coverages (self.available_coverages).

        Args:
            coverage_name: The coverage to retrieve.

        Returns:
            The coverage file path, or '' if the coverage is unknown.
        """
        if self._coverages is None:
            self._load_coverages()
        tree_item = self._coverages.get(coverage_name)
        if tree_item is None:
            return ''
        # Use the XMS Coverage object directly: a converted GeoDataFrame has no filename_and_group.
        coverage = self._query.item_with_uuid(tree_item.uuid)
        if coverage is None:
            return ''
        file_path, _group_path = coverage.filename_and_group
        return file_path

    def get_coverage_name_from_uuid(self, coverage_uuid: str) -> Optional[str]:
        """Get an input coverage name from a UUID.

        Args:
            coverage_uuid: The coverage UUID.

        Returns:
            The coverage name or None.
        """
        return self._get_path_from_uuid(coverage_uuid, xmd.CoverageItem)

    def get_uuid_from_coverage_name(self, coverage_name: str) -> Optional[str]:
        """Get a UUID from an input coverage name.

        Args:
            coverage_name: The coverage name.

        Returns:
            The UUID or None.
        """
        return self._get_uuid_from_path(coverage_name, xmd.CoverageItem)

    def set_output_coverage(self, coverage: GeoDataFrame, coverage_name: str) -> None:
        """Set output coverage for argument.

        The GeoDataFrame is written to a temporary H5 file which is queued for sending to XMS.

        Args:
            coverage (GeoDataFrame): The output coverage.
            coverage_name (str): The coverage name. NOTE(review): currently not used by this implementation.
        """
        temp_filename = filesystem.temp_filename(suffix='.h5')
        writer = CoverageWriter(temp_filename)
        writer.write(coverage)
        self._new_coverages.append(Coverage(temp_filename))

    def _load_coverages(self):
        """Populate self._coverages ({tree path: coverage tree item})."""
        coverages = tree_util.descendants_of_type(self._project_tree, xmd.CoverageItem)
        self._coverages = {_get_tree_path(coverage): coverage for coverage in coverages}

    def send_output_to_xms(self):
        """Send queued output coverages, grids, datasets and rasters to XMS (no-op without a query)."""
        if self._query is None:
            return

        for coverage in self._new_coverages:
            self._query.add_coverage(coverage)
        for grid in self._new_grids:
            self._query.add_ugrid(grid)
        for dset in self._new_dsets:
            self._query.add_dataset(dset, folder_path=self.output_folder_paths.get(dset.uuid))
        for raster in self._new_rasters:
            self._query.add_raster(raster)

    def get_default_wkt(self):
        """Get the WKT of the default coordinate system.

        Returns:
            (str): The WKT of the default coordinate system.
        """
        return self._query.display_projection.well_known_text

    def get_vertical_datum(self):
        """Get the vertical datum of the default coordinate system.

        Returns:
            (str): The vertical datum of the default coordinate system.
        """
        return self._query.display_projection.vertical_datum

    def get_vertical_units(self):
        """Get the vertical units of the default coordinate system.

        Returns:
            (str): The vertical units of the default coordinate system.
        """
        return self._query.display_projection.vertical_units

    def get_output_files(self):
        """Get the output files to be sent (for testing only).

        Returns:
            (List[str]): A list of file paths.
        """
        files = [coverage.name for coverage in self._new_coverages]
        files.extend(grid.name for grid in self._new_grids)
        files.extend(dset.name for dset in self._new_dsets)
        files.extend(os.path.basename(raster) for raster in self._new_rasters)
        return files

    def _get_path_from_uuid(self, item_uuid: str, class_: type) -> Optional[str]:
        """Get path to tree item from a UUID.

        Args:
            item_uuid: The UUID of the tree item.
            class_: The exact type the tree item's data must have.

        Returns:
            The path to the item or None if not found (or of a different type).
        """
        path = tree_util.build_tree_path(self._project_tree, item_uuid)
        if path:
            tree_item = tree_util.item_from_path(self._project_tree, path)
            if tree_item is not None and type(tree_item.data) is class_:
                return _get_tree_path(tree_item)
        return None

    def _get_uuid_from_path(self, item_path: str, class_: type) -> Optional[str]:
        """Get a UUID from the path to tree item.

        Args:
            item_path: The path to the tree item (below the project node).
            class_: The exact type the tree item's data must have.

        Returns:
            The item UUID or None if not found (or of a different type).
        """
        try:
            full_path = self._project_tree.name + '/' + item_path
            tree_item = tree_util.item_from_path(self._project_tree, full_path)
            if tree_item is not None and type(tree_item.data) is class_:
                return tree_item.uuid
            return None
        except TypeError:
            return None


def _get_tree_path(tree_item):
    """Return the '/'-separated path of a tree node, excluding the root (project) node.

    Args:
        tree_item (xms.guipy.tree.tree_node.TreeNode): Tree node to build the path for.

    Returns:
        (str): Path in the XMS tree from just below the root down to tree_item.
    """
    # Walk from the node up to the root, collecting names leaf-first.
    names_leaf_first = []
    node = tree_item
    while node is not None:
        names_leaf_first.append(node.name)
        node = node.parent
    # The last collected name is the project root; drop it, then emit root-to-leaf order.
    return '/'.join(reversed(names_leaf_first[:-1]))


def convert_to_geodataframe(coverage: Coverage, default_wkt: str, atts: dict = None) -> GeoDataFrame | None:
    """Converts a coverage to a Pandas GeoDataFrame.

    Rows are ordered polygons, then arcs, then disjoint points, nodes (corner points) and vertices (mid points).
    The coverage file (if it exists) is copied to a temporary file whose path is stored in
    ``gdf.attrs['filename']``; otherwise the GeoDataFrame itself is written to that temporary file.

    Args:
        coverage (Coverage): The coverage.
        default_wkt (str): The default WKT for the coverage's projection, used when the coverage has none.
        atts (dict): The dict specifying the attributes for the coverage. If None, attributes are read from the
            coverage file when it exists, else an empty dict is used.

    Returns:
        (GeoDataFrame): The coverage converted to a GeoDataFrame, or None if coverage is None.
    """
    if coverage is None:
        return None

    def add_empty_list_rows(count):
        # Each column gets its own fresh empty lists so rows never share a mutable list across columns.
        for column in (polygon_arc_ids, polygon_arc_directions, interior_arc_ids, interior_arc_directions):
            column.extend([] for _ in range(count))

    # Work on a temporary copy of the coverage file.
    filename = coverage.filename_and_group[0]
    temp_file_path = filesystem.temp_filename()
    write_file = True
    if Path(filename).is_file():
        filesystem.copyfile(filename, temp_file_path)
        if atts is None:
            reader = CoverageReader(temp_file_path)
            atts = reader.read_attributes()
        write_file = False
    if atts is None:
        atts = dict()

    # Polygons
    polygons = coverage.polygons
    polygon_list = []
    poly_ids = []
    geometry_types = ['Polygon'] * len(polygons)
    polygon_arc_ids = []
    polygon_arc_directions = []
    interior_arc_ids = []
    interior_arc_directions = []
    start_node = [-1] * len(polygons)
    end_node = [-1] * len(polygons)
    for poly in polygons:
        # Create the outer ring of the polygon
        points = []
        poly_arc_ids = []
        hole_arc_ids = []
        polygon_arc_directions.append(poly.arc_directions.copy())
        for idx, poly_arc in enumerate(poly.arcs):
            arc_points = poly_arc.get_points(FilterLocation.PT_LOC_ALL)
            poly_arc_ids.append(poly_arc.id)
            if poly.arc_directions[idx]:
                arc_points.reverse()
            # Skip each subsequent arc's first point: it duplicates the previous arc's last point.
            for pt_idx, pt in enumerate(arc_points):
                if idx == 0 or pt_idx > 0:
                    points.append((pt.x, pt.y, pt.z))
        # Create any interior rings (if there)
        holes = []
        interior_arc_directions.append(poly.interior_arc_directions.copy())
        for hole_poly_idx, hole in enumerate(poly.interior_arcs):
            inner_ring = []
            hole_ids = []
            for hole_idx, poly_arc in enumerate(hole):
                arc_points = poly_arc.get_points(FilterLocation.PT_LOC_ALL)
                hole_ids.append(poly_arc.id)
                # Note: interior arcs are reversed when their direction flag is False (opposite of outer arcs).
                if not poly.interior_arc_directions[hole_poly_idx][hole_idx]:
                    arc_points.reverse()
                for pt_idx, pt in enumerate(arc_points):
                    if hole_idx == 0 or pt_idx > 0:
                        inner_ring.append((pt.x, pt.y, pt.z))
            hole_arc_ids.append(hole_ids)
            holes.append(inner_ring)
        polygon_arc_ids.append(poly_arc_ids)
        interior_arc_ids.append(hole_arc_ids)
        polygon_list.append(Polygon(points, holes))
        poly_ids.append(poly.id)

    # Arcs
    arc_list = []
    arc_ids = []
    arcs = coverage.arcs
    geometry_types.extend(['Arc'] * len(arcs))
    add_empty_list_rows(len(arcs))
    for arc in arcs:
        # Build a 3D line string: start node, interior vertices, end node.
        points = [(arc.start_node.x, arc.start_node.y, arc.start_node.z)]
        points.extend((vertex.x, vertex.y, vertex.z) for vertex in arc.vertices)
        points.append((arc.end_node.x, arc.end_node.y, arc.end_node.z))
        arc_list.append(LineString(points))
        arc_ids.append(arc.id)
        start_node.append(arc.start_node.id)
        end_node.append(arc.end_node.id)

    # Points: disjoint points, then nodes (corner points), then vertices (mid points).
    point_list = []
    point_ids = []
    point_categories = (
        (FilterLocation.PT_LOC_DISJOINT, 'Point'),
        (FilterLocation.PT_LOC_CORNER, 'Node'),
        (FilterLocation.PT_LOC_MID, 'Vertex'),
    )
    for location, geometry_type in point_categories:
        points = coverage.get_points(location)
        geometry_types.extend([geometry_type] * len(points))
        add_empty_list_rows(len(points))
        start_node.extend([-1] * len(points))
        end_node.extend([-1] * len(points))
        for point in points:
            point_list.append(Point(point.x, point.y, point.z))
            point_ids.append(point.id)

    json_strings = [''] * len(polygon_arc_ids)
    crs = ''
    if coverage.projection is not None:
        crs = coverage.projection.well_known_text
    if not crs:
        crs = default_wkt
    gdf = GeoDataFrame({'id': poly_ids + arc_ids + point_ids, 'geometry_types': geometry_types,
                        'geometry': polygon_list + arc_list + point_list, 'polygon_arc_ids': polygon_arc_ids,
                        'polygon_arc_directions': polygon_arc_directions, 'interior_arc_ids': interior_arc_ids,
                        'interior_arc_directions': interior_arc_directions, 'start_node': start_node,
                        'end_node': end_node, 'attributes': json_strings}, crs=crs if len(polygon_arc_ids) else None)
    gdf.attrs['name'] = coverage.name
    gdf.attrs['uuid'] = coverage.uuid
    gdf.attrs['attributes'] = atts
    if write_file:
        # No source file to copy; persist the converted coverage so attrs['filename'] points at real data.
        writer = CoverageWriter(temp_file_path)
        writer.write(gdf)
    gdf.attrs['filename'] = temp_file_path
    return gdf
