"""Module for class TidalMapper."""

__copyright__ = "(C) Copyright Aquaveo 2023"
__license__ = "All rights reserved"
__all__ = ['TidalMapper']

# 1. Standard Python modules
from logging import Logger
from pathlib import Path
from typing import Optional, Sequence

# 2. Third party modules
import numpy as np
import xarray as xr

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.constraint import read_grid_from_file
from xms.datasets.dataset_reader import DatasetReader
from xms.grid.ugrid.ugrid import UGrid
from xms.guipy.dialogs.feedback_thread import ExpectedError
from xms.interp.interpolate import InterpIdw, InterpLinear
from xms.tides.data.tidal_data import TidalData, USER_DEFINED_INDEX
from xms.tides.data.tidal_extractor import TidalExtractor

# 4. Local modules
from xms.schism.external.project import make_geographic


class TidalMapper:
    """
    Map tidal constituent data onto a set of ocean boundary nodes.

    After `map()` runs, `properties`, `elevation`, and `velocity` hold xarray Datasets with the results, and `source`
    holds the tidal source index read from the tidal data.
    """
    def __init__(
        self, query: Query, logger: Logger, tidal_data: str | Path | TidalData, ocean_node_ids: Sequence[int],
        domain: UGrid
    ):
        """
        Initialize the mapper.

        Args:
            query: Interprocess communication object.
            logger: Where to log messages to.
            tidal_data: A TidalData or a path to the mainfile used to construct one.
            ocean_node_ids: 1-based IDs of nodes to find the tidal data for. They are converted to 0-based point
                indexes before looking up locations in `domain`.
            domain: The domain that ocean_node_ids applies to.
        """
        self._query = query
        if isinstance(tidal_data, (str, Path)):
            tidal_data = TidalData(str(tidal_data))
        self._tidal_data = tidal_data
        self._node_ids = ocean_node_ids
        self._domain = domain
        self._log = logger

        self._dsets: dict[str, DatasetReader] = {}  # uuid -> dataset
        self._geoms: dict[str, UGrid] = {}  # uuid -> ugrid
        self._interpolators: dict[tuple[str, str], InterpLinear] = {}  # uuid, location -> interpolator

        self.elevation: Optional[xr.Dataset] = None
        self.velocity: Optional[xr.Dataset] = None
        self.properties: Optional[xr.Dataset] = None
        self.source: int = -1  # One of the constants from xms.tides.data.tidal_data. ADCIRC_INDEX and friends.

    def _node_indexes(self) -> list[int]:
        """
        Get the 0-based point indexes of the ocean nodes.

        Returns:
            One point index per entry in ocean_node_ids, in the same order.
        """
        return [node_id - 1 for node_id in self._node_ids]

    def _get_dset(self, uuid: str) -> DatasetReader:
        """
        Get a dataset, retrieving it through the query on first use and caching it afterwards.

        Args:
            uuid: UUID of the dataset to get.

        Returns:
            The dataset.
        """
        if uuid not in self._dsets:
            self._log.info(f'Retrieving dataset: {uuid}')
            dset: DatasetReader = self._query.item_with_uuid(uuid)
            self._dsets[uuid] = dset
        return self._dsets[uuid]

    def _get_geom(self, uuid: str) -> UGrid:
        """
        Get a geometry, reading its CoGrid file on first use and caching the UGrid afterwards.

        Args:
            uuid: UUID of the geometry to get.

        Returns:
            The geometry.
        """
        if uuid not in self._geoms:
            self._log.info(f'Retrieving geometry: {uuid}')
            geom = self._query.item_with_uuid(uuid)
            co_grid = read_grid_from_file(geom.cogrid_file)
            ugrid = co_grid.ugrid
            self._geoms[uuid] = ugrid
        return self._geoms[uuid]

    def _get_interpolator(self, uuid: str, location: str) -> InterpLinear:
        """
        Get a linear interpolator for a dataset, building and caching it on first use.

        Args:
            uuid: UUID of the geometry the dataset is on.
            location: Where on the geometry the data is. 'points' or 'cells'.

        Returns:
            A linear interpolator whose source points are the cell centroids (for 'cells') or the point locations
            (for anything else).
        """
        if (uuid, location) not in self._interpolators:
            self._log.info(f'Building interpolator for dataset {uuid} on {location}')
            ugrid = self._get_geom(uuid)
            if location == 'cells':
                # get_cell_centroid returns a pair; element 1 is the centroid location.
                pts = [ugrid.get_cell_centroid(cell)[1] for cell in range(ugrid.cell_count)]
            else:
                pts = ugrid.locations
            self._interpolators[(uuid, location)] = InterpLinear(points=pts)
        return self._interpolators[(uuid, location)]

    def map(self):
        """Map the tides, filling in `properties`, `source`, `elevation`, and `velocity`."""
        self._get_properties()

        self.source = self._tidal_data.info.attrs['source']

        if self.source == USER_DEFINED_INDEX:
            self._get_user_dataset_uuids()
            self._extract_user()
        else:
            self._extract_database()

    def _get_user_dataset_uuids(self):
        """
        Select the user-defined dataset UUID columns.

        Sets `elevation` to the Name/Phase/Amplitude columns. Sets `velocity` to the velocity columns of only the rows
        whose 'Force Velocity' is 1, with the 'Force Velocity' column itself dropped.
        """
        elevation_names = ['Name', 'Phase', 'Amplitude']
        self.elevation = self._tidal_data.user[elevation_names]

        velocity_names = [
            'Name', 'Force Velocity', 'Velocity Amplitude X', 'Velocity Amplitude Y', 'Velocity Phase X',
            'Velocity Phase Y'
        ]
        velocity = self._tidal_data.user[velocity_names]
        velocity = velocity.where(velocity['Force Velocity'] == 1, drop=True)
        velocity_names.remove('Force Velocity')
        self.velocity = velocity[velocity_names]

    def _get_properties(self):
        """
        Get tidal properties.

        These are global things, like the name, nodal factor, and equilibrium argument.
        """
        extractor = TidalExtractor(self._tidal_data)
        properties = extractor.get_constituent_properties()
        filtered = properties[['frequency', 'nodal_factor', 'equilibrium_argument']]
        renamed = filtered.rename({'con': 'name', 'nodal_factor': 'factor', 'equilibrium_argument': 'argument'})
        self.properties = renamed

    def _extract_user(self):
        """Extract user-specified tides by interpolating the user's datasets to the ocean node locations."""
        ocean_locs = list(self._domain.get_points_locations(self._node_indexes()))
        interps = {}  # IDW interpolator cache shared by the elevation and velocity extraction.
        self._extract_user_elevation(ocean_locs, interps)
        self._extract_user_velocity(ocean_locs, interps)

    def _extract_user_elevation(self, nodes: Sequence[tuple[float, float, float]], interps: dict):
        """
        Extract user-specified elevation tides into `elevation`.

        Args:
            nodes: Locations of the ocean nodes to get data for, in the same order as ocean_node_ids.
            interps: IDW interpolator cache passed through to linear_interp_with_idw_extrap.
        """
        all_amps = self._extract_data(self.elevation['Amplitude'], nodes, interps)
        all_phases = self._extract_data(self.elevation['Phase'], nodes, interps)

        if len(self.elevation['Name']) > 0:
            cons = self.elevation['Name'].values
            node_ids = self._node_ids if len(self._node_ids) > 0 else np.empty((0, ), dtype=int)
        else:
            # With no constituents, _extract_data returns a (0, 0) array, so the node coordinate must be empty as
            # well or xarray rejects the conflicting dimension sizes. Velocity extraction does the same thing.
            cons = np.empty((0, ), dtype=object)
            node_ids = np.empty((0, ), dtype=int)
        tidal_coords = {
            'name': cons,
            'node': node_ids,
        }
        tidal_data = {
            'amplitude': (('name', 'node'), all_amps),
            'phase': (('name', 'node'), all_phases),
        }
        self.elevation = xr.Dataset(data_vars=tidal_data, coords=tidal_coords).fillna(0.0)

    def _extract_user_velocity(self, nodes: Sequence[tuple[float, float, float]], interps: dict):
        """
        Extract user-specified velocity tides into `velocity`.

        Args:
            nodes: Locations of the ocean nodes to get data for, in the same order as ocean_node_ids.
            interps: IDW interpolator cache passed through to linear_interp_with_idw_extrap.
        """
        all_phase_x = self._extract_data(self.velocity['Velocity Phase X'], nodes, interps)
        all_phase_y = self._extract_data(self.velocity['Velocity Phase Y'], nodes, interps)
        all_amp_x = self._extract_data(self.velocity['Velocity Amplitude X'], nodes, interps)
        all_amp_y = self._extract_data(self.velocity['Velocity Amplitude Y'], nodes, interps)
        names = self.velocity['Name'].values

        # With no constituents the data arrays are (0, 0), so the node coordinate must be empty to match.
        tidal_coords = {'name': names, 'node': self._node_ids if len(names) > 0 else np.empty((0, ))}
        tidal_data = {
            'phase_x': (('name', 'node'), all_phase_x),
            'phase_y': (('name', 'node'), all_phase_y),
            'amp_x': (('name', 'node'), all_amp_x),
            'amp_y': (('name', 'node'), all_amp_y),
        }
        self.velocity = xr.Dataset(data_vars=tidal_data, coords=tidal_coords).fillna(0.0)

    def _extract_data(self, uuids: xr.DataArray, nodes: Sequence[tuple[float, float, float]], interps: dict):
        """
        Interpolate each dataset in `uuids` to the node locations.

        Args:
            uuids: Array of uuids of datasets required.
            nodes: Locations of the nodes to get data for.
            interps: IDW interpolator cache passed through to linear_interp_with_idw_extrap.

        Returns:
            One list of interpolated values per dataset, in the same order as `uuids`, or an empty (0, 0) array
            when `uuids` is empty.

        Raises:
            ExpectedError: If any UUID is empty.
        """
        dset_list = []
        for uuid in uuids.values:
            if not uuid:
                raise ExpectedError('Tidal dataset unspecified')
            reader = self._get_dset(uuid)
            interpolator = self._get_interpolator(reader.geom_uuid, reader.location)
            dset_list.append(linear_interp_with_idw_extrap(interpolator, nodes, reader, interps))
        return dset_list if dset_list else np.empty((0, 0))

    def _extract_database(self):
        """
        Extract tidal data from a database.

        Sets `elevation` from the database and `velocity` to an empty dataset.

        Raises:
            ExpectedError: If the domain cannot be converted to a geographic projection.
        """
        self._log.info('Extracting tidal data from database...')
        success, domain = make_geographic(self._query.display_projection, self._domain)
        if not success:
            raise ExpectedError('Unable to extract from tidal database in local projection')

        extractor = TidalExtractor(self._tidal_data)
        # TidalExtractor.get_amplitude_and_phase wants points in x,y order, but database extraction requires geographic
        # projection, which gives them to us in lat,lon order, so we have to reverse them.
        # Node IDs are 1-based, so convert them to 0-based point indexes, as in _extract_user.
        ocean_locs = [(loc[1], loc[0]) for loc in domain.ugrid.get_points_locations(self._node_indexes())]
        constituent_datasets = extractor.get_amplitude_and_phase(ocean_locs, self._node_ids)
        constituent_datasets = constituent_datasets.fillna(0)
        self.elevation = constituent_datasets.rename({'con': 'name', 'node_id': 'node'})

        # Database extraction provides no velocity forcing, so build an empty dataset with the expected layout.
        empty = np.empty((0, 0), dtype=np.float64)
        tidal_coords = {
            'name': np.array([], dtype=np.object_),
            'node': np.array([], dtype=np.int_),
        }
        tidal_data = {
            'phase_x': (('name', 'node'), empty),
            'phase_y': (('name', 'node'), empty),
            'amp_x': (('name', 'node'), empty),
            'amp_y': (('name', 'node'), empty),
        }
        self.velocity = xr.Dataset(data_vars=tidal_data, coords=tidal_coords)


def linear_interp_with_idw_extrap(
    linear_interper: InterpLinear, to_points: Sequence[tuple[float, float, float]], dset_reader: DatasetReader,
    idw_interpers: dict[tuple[str, str], InterpIdw]
):
    """Interpolate a dataset to to_points using linear interpolation and use IDW for extrapolated points.

    Args:
        linear_interper: Interpolator used for linear interpolation. Its source points must match the dataset's
            geometry and location.
        to_points: List of x,y,z coordinates to interpolate data to. Locations where you want the data for.
        dset_reader: The dataset with values to interpolate. The data you actually have.
        idw_interpers: Cache of IDW interpolators keyed by (geom UUID, dataset location). If linear interpolation
            results in extrapolated values, this dictionary is checked for an existing IDW interpolator. If none is
            present, one is created from the linear interpolator's points and stored here.

    Returns:
        (:obj:`list`): interp_vals with extrapolated values replaced with IDW values
    """
    # Only the first entry of the dataset's values is interpolated (presumably the first time step).
    linear_interper.scalars = dset_reader.values[0]
    interp_vals = list(linear_interper.interpolate_to_points(to_points))
    extrap_pt_idxs = linear_interper.extrapolation_point_indexes
    if len(extrap_pt_idxs) == 0:
        return interp_vals

    # The IDW interpolator reuses the linear interpolator's source points. Those differ between point data and cell
    # data on the same geometry, so the cache key must include the location as well as the geometry.
    key = (dset_reader.geom_uuid, dset_reader.location)
    idw_interp = idw_interpers.get(key)
    if idw_interp is None:
        idw_interp = InterpIdw(points=linear_interper.points, nodal_function='constant', number_nearest_points=2)
        idw_interp.set_search_options(2, False)
        idw_interpers[key] = idw_interp

    idw_interp.scalars = linear_interper.scalars
    idw_interp_vals = idw_interp.interpolate_to_points([to_points[idx] for idx in extrap_pt_idxs])
    for idx, value in zip(extrap_pt_idxs, idw_interp_vals):
        interp_vals[idx] = value
    return interp_vals
