"""Code to handle mapping from a coverage with a TIN."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math
from pathlib import Path

# 2. Third party modules
from rtree import index

# 3. Aquaveo modules
from xms.constraint import Grid, read_grid_from_file
from xms.coverage.xy.xy_series import XySeries
from xms.datasets.dataset_reader import DatasetReader
from xms.grid.ugrid import UGrid
from xms.interp.interpolate import InterpLinear

# 4. Local modules
from xms.mf6.components import dis_builder
from xms.mf6.data import time_util
from xms.mf6.data.base_file_data import BaseFileData
from xms.mf6.file_io import csv_dataset_reader

# Type aliases
Xyz = tuple[float, float, float]


class TinMapper:
    """Handles interpolating from a TIN to grid cells.

    The TIN geometry is read from an .xmc file, and its values from a file with the same path but a .csv suffix.
    Both are loaded lazily on first use. Results are cached per cell center, so cells sharing the same center are
    only interpolated once.
    """
    def __init__(self, tin_filepath: str, cogrid: Grid, ugrid: UGrid, package: BaseFileData):
        """Initializer.

        Args:
            tin_filepath: Filepath to .xmc file defining the TIN (as a UGrid).
            cogrid: The constrained grid.
            ugrid: The ugrid of the cogrid.
            package: The package being mapped to.
        """
        self._tin_filepath: str = tin_filepath
        self._cogrid = cogrid
        self._ugrid = ugrid
        self._package = package

        self._tin_ugrid = None  # The TIN as a UGrid; set by _init_interpolator when the .xmc file is read
        self._rtree_2d = None  # A 2d (x,y) rtree of TIN vertices, built on first use by _closest_vertex
        self._cell_centers_2d = dis_builder.get_cell_centers2d(self._cogrid, self._ugrid)
        self._cell_results: dict[Xyz, float | XySeries] = {}  # Cached results, keyed by cell center
        self._interpolator: InterpLinear | None = None  # Created lazily by _get_interpolator
        self._dataset: DatasetReader | None = None  # Created lazily by _get_dataset

    def interpolate(self, cellidx: int) -> float | XySeries:
        """Return the value (a constant, or a time series) at the cell center by interpolating from the TIN.

        Args:
            cellidx: 0-based cell index of a cell in the MODFLOW grid.

        Returns:
            See description.
        """
        cell_center = tuple(self._cell_centers_2d[cellidx])
        if cell_center not in self._cell_results:
            self._cell_results[cell_center] = self._interpolate(cell_center)
        return self._cell_results[cell_center]

    def _interpolate(self, cell_center: Xyz) -> float | XySeries:
        """Interpolate from the TIN to the cell center.

        Returns a float if the TIN dataset has at most 1 timestep, else an XySeries.

        Args:
            cell_center: x,y,z of the cell (with z always 0.0).

        Returns:
            See description.
        """
        # The interpolator must be created first: creating it is what loads self._tin_ugrid,
        # which _closest_vertex needs.
        interpolator = self._get_interpolator()
        dataset = self._get_dataset()
        if len(dataset.times) > 1:
            return self._interpolate_transient(cell_center, interpolator, dataset)
        return self._interpolate_steady_state(cell_center, interpolator, dataset)

    def _get_interpolator(self) -> InterpLinear:
        """Return the interpolator object for the tin, creating and storing it if it doesn't exist yet.

        Returns:
            See description.
        """
        if self._interpolator is None:
            self._interpolator = self._init_interpolator(self._tin_filepath)
        return self._interpolator

    def _get_dataset(self) -> DatasetReader:
        """Return the DatasetReader object for the tin, creating and storing it if it doesn't exist yet.

        Returns:
            See description.
        """
        if self._dataset is None:
            self._dataset = self._init_dataset(self._tin_filepath)
        return self._dataset

    def _interpolate_steady_state(self, cell_center: Xyz, interpolator: InterpLinear, dataset: DatasetReader) -> float:
        """Interpolate from the TIN to the cell center using the first timestep, and return a Python float.

        If the interpolated value is NaN (e.g. cell_center is outside the TIN), the value at the closest TIN vertex
        is returned instead.

        Args:
            cell_center: x,y,z of the cell (with z always 0.0).
            interpolator: The interpolator for the tin.
            dataset: The tin dataset.

        Returns:
            See description.
        """
        interpolator.scalars = dataset.values[0]
        value = interpolator.interpolate_to_point(cell_center)
        if math.isnan(value):
            closest_vertex = self._closest_vertex(cell_center)
            value = interpolator.scalars[closest_vertex].item()  # item() converts from np float32 to python float
        # Avoid numpy types, as in _interpolate_transient. They don't work when we save them to the sqlite database.
        return float(value)

    def _interpolate_transient(self, cell_center: Xyz, interpolator: InterpLinear, dataset: DatasetReader) -> XySeries:
        """Interpolate from the TIN to the cell center at every timestep, and return an XySeries.

        For any timestep where the interpolated value is NaN (e.g. cell_center is outside the TIN), the value at the
        closest TIN vertex is used instead.

        Args:
            cell_center: cell center x,y,z (z is always 0.0).
            interpolator: The interpolator for the tin.
            dataset: The tin dataset.

        Returns:
            See description.
        """
        dataset_times = time_util.dataset_times(dataset)

        # The closest vertex depends only on the cell center. Look it up at most once, and only if needed.
        closest_vertex = None

        # Get values at the cell center over time
        values = []
        for ts_idx in range(len(dataset_times)):
            interpolator.scalars = dataset.values[ts_idx]
            value = interpolator.interpolate_to_point(cell_center)
            if math.isnan(value):
                if closest_vertex is None:
                    closest_vertex = self._closest_vertex(cell_center)
                value = interpolator.scalars[closest_vertex].item()  # item() converts from np float32 to python float
            values.append(float(value))  # Avoid numpy types. They don't work when we save them to the sqlite database

        # Create and return xy series
        return XySeries(x=dataset_times, y=values)

    def _closest_vertex(self, cell_center: Xyz) -> int:
        """Return the index of the TIN vertex closest (in x,y) to the cell center.

        The 2D rtree of TIN vertices is built on first use from self._tin_ugrid. That attribute is only set once the
        interpolator has been created (see _init_interpolator).

        Args:
            cell_center: cell center x,y,z (z is always 0.0).

        Returns:
            See description.
        """
        if self._rtree_2d is None:  # Explicit None check: an rtree Index's truthiness may depend on its size
            self._rtree_2d = _build_2d_rtree(self._tin_ugrid.locations)
        return list(self._rtree_2d.nearest((cell_center[0], cell_center[1]), 1))[0]

    def _init_interpolator(self, tin_filepath: str) -> InterpLinear:
        """Read the TIN from the .xmc file and return a linear interpolator over its triangles.

        Also stores the TIN's UGrid in self._tin_ugrid.

        Args:
            tin_filepath: Filepath to the .xmc file defining the TIN.

        Returns:
            See description.

        Raises:
            ValueError: If the TIN contains cells that are not triangles.
        """
        tin_cogrid = read_grid_from_file(tin_filepath)
        # Explicit check rather than assert: asserts are stripped under `python -O`, and non-triangle cells would
        # otherwise silently produce a corrupt triangle list.
        if not tin_cogrid.check_all_cells_are_of_type(UGrid.cell_type_enum.TRIANGLE):
            raise ValueError(f'TIN "{tin_filepath}" must contain only triangles.')
        self._tin_ugrid = tin_cogrid.ugrid
        triangles = _get_triangles(self._tin_ugrid)
        return InterpLinear(points=self._tin_ugrid.locations, triangles=triangles)

    def _init_dataset(self, tin_filepath: str) -> DatasetReader:
        """Read and return the TIN's dataset.

        The dataset is saved as a .csv file whose path is identical to the .xmc file but with a .csv suffix.

        Args:
            tin_filepath: Filepath to the .xmc file defining the TIN.

        Returns:
            See description.
        """
        csv_filepath = Path(tin_filepath).with_suffix('.csv')
        return csv_dataset_reader.dataset_from_csv(csv_filepath)


def _get_triangles(ugrid) -> list[int]:
    """Return a list of points that define the triangles.

    Args:
        ugrid:

    Returns:
        See description.
    """
    cellstream = ugrid.cellstream
    # cellstream will be like 5, 3, 7, 1, 4, 5, 3, 4, 6, 7..., where 5 is the number of items in the cell stream for
    # the cell, and 3 is the cell type
    triangles = []
    i = 0
    while i < len(cellstream):
        triangles.append(cellstream[i + 2])
        triangles.append(cellstream[i + 3])
        triangles.append(cellstream[i + 4])
        i += 5
    return triangles


def _rtree_2d_insert_generator(locations):
    """This generator function is supposed to be a faster way to populate the rtree?

    https://rtree.readthedocs.io/en/latest/performance.html#use-stream-loading
    """
    for i, location in enumerate(locations):
        yield i, (location[0], location[1], location[0], location[1]), i


def _build_2d_rtree(locations):
    """Return a 2D rtree index over the x,y of each location, keyed by the location's index.

    Entries are supplied via a generator (stream loading), which the rtree performance docs describe as a faster way
    to populate an index.

    https://rtree.readthedocs.io/en/latest/performance.html#use-stream-loading
    """
    properties = index.Property()
    properties.dimension = 2
    entries = _rtree_2d_insert_generator(locations)
    return index.Index(entries, properties=properties)
