"""SmoothDatasetsByNeighborTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint.ugrid_activity import CellToPointActivityCalculator
from xms.tool_core import ALLOW_ONLY_POINT_MAPPED, ALLOW_ONLY_SCALARS, IoDirection, Tool

# 4. Local modules
from .. utilities.dataset_tool import set_builder_activity_flags, smooth_get_ugrid

# Positions of the tool arguments, matching the order of the list built in initial_arguments().
(
    ARG_INPUT_DSET,
    ARG_NUM_LEVELS,
    ARG_INTERP_OPTION,
    ARG_WEIGHT,
    ARG_INPUT_LND,
    ARG_OUTPUT_DATASET,
) = range(6)


class SmoothDatasetsByNeighborTool(Tool):
    """Tool to smooth datasets by averaging nodal neighbors.

    Every time step of a scalar, point-mapped input dataset is smoothed point by point using the point's adjacent
    points (one or two levels of adjacency), either as a weighted blend of the point's value and its neighbors'
    average, or as an inverse-distance-weighted (IDW) average of its neighbors. Inactive points, points whose value
    is NaN, and points flagged (value 1) in the optional locked nodes dataset keep their original value.
    """

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Smooth Datasets by Neighbor')
        self._dataset_reader = None  # Input dataset reader, assigned in validate_arguments()
        self._dataset_builder = None  # Output dataset writer, assigned in _setup_output_dataset_builders()
        self._default_dset = None  # Not referenced by the methods of this class
        self._lnd = None  # Optional locked nodes dataset reader, assigned in validate_arguments()

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.dataset_argument(name='input_dataset', description='Input dataset',
                                  filters=[ALLOW_ONLY_SCALARS, ALLOW_ONLY_POINT_MAPPED]),
            self.string_argument(name='number_of_levels', description='Number of levels', choices=['1', '2'],
                                 value='1'),
            self.string_argument(name='interpolation_method', description='Interpolation method',
                                 choices=['Average', 'IDW'], value='Average'),
            self.float_argument(name='Weight', description='Weight of nodal neighbors', value=0.5, min_value=0.0,
                                max_value=1.0),
            self.dataset_argument(name='locked_nodes_dataset', description='Locked nodes dataset (optional)',
                                  optional=True, filters=[ALLOW_ONLY_SCALARS, ALLOW_ONLY_POINT_MAPPED]),
            self.dataset_argument(name='default_name', description='Output dataset', value="new dataset",
                                  io_direction=IoDirection.OUTPUT),
        ]
        self.enable_arguments(arguments)
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Also stores the readers for the input dataset and the optional locked nodes dataset for use by run().

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}

        # Validate input datasets
        self._dataset_reader = self._validate_input_dataset(arguments[ARG_INPUT_DSET], errors)

        # Locked nodes dataset. Reset first so a dataset chosen during an earlier validation is not reused after
        # the user clears the argument.
        self._lnd = None
        if arguments[ARG_INPUT_LND].text_value:
            self._lnd = self._validate_input_dataset(arguments[ARG_INPUT_LND], errors)
            dsr = self._dataset_reader
            if self._lnd and dsr and self._lnd.num_values != dsr.num_values:
                msg = 'Number of values in locked nodes dataset must match number of values in input dataset.'
                errors[arguments[ARG_INPUT_LND].name] = msg

        return errors

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        The weight argument only applies to the 'Average' method, so it is hidden when 'IDW' is selected.

        Args:
            arguments(list): The tool arguments.
        """
        weight_disabled = arguments[ARG_INTERP_OPTION].text_value == 'IDW'
        arguments[ARG_WEIGHT].hide = weight_disabled  # hide Weight arg if IDW

    def run(self, arguments):
        """Override to run the tool.

        Smooths every time step of the input dataset and sends the result back as the output dataset.

        Args:
            arguments (list): The tool arguments.
        """
        dset_ugrid = smooth_get_ugrid(self, arguments[ARG_INPUT_DSET])

        # Extract the component, time step at a time
        self._setup_output_dataset_builders(arguments)
        time_count = len(self._dataset_reader.times)
        num_points = dset_ugrid.point_count
        num_levels = int(arguments[ARG_NUM_LEVELS].value)
        simple_average = arguments[ARG_INTERP_OPTION].text_value == 'Average'
        weight = float(arguments[ARG_WEIGHT].value)
        point_weight = 1.0 - weight
        locations = dset_ugrid.locations

        # Points whose locked nodes value is 1 are never smoothed. Only the first time step of the locked nodes
        # dataset is used.
        if self._lnd:
            lnd = self._lnd.values[0]
        else:
            lnd = np.zeros(num_points)

        self._setup_activity_calculators(dset_ugrid)

        for tsidx in range(time_count):
            self.logger.info(f'Processing time step {tsidx + 1} of {time_count}...')
            data, activity = self._dataset_reader.timestep_with_activity(tsidx)
            # Don't consider activity if it is cell-based, just copy the array to the output file.
            check_activity = activity is not None and len(activity) == len(locations)
            smoothed_data = data.copy()
            set_builder_activity_flags(activity, self._dataset_builder)
            self.logger.info(f'Processing point 1 of {num_points}...')
            for i in range(num_points):
                if ((i + 1) % 10000) == 0:
                    self.logger.info(f'Processing point {i+1} of {num_points}...')
                if (check_activity and not activity[i]) or np.isnan(data[i]) or lnd[i] == 1:
                    continue  # Point is inactive, has no value, or is locked
                neighbors = self._contributing_neighbors(
                    self._get_neighbors(dset_ugrid, i, num_levels), data, activity if check_activity else None
                )
                if not neighbors:
                    continue  # No usable neighbors: keep the original value
                if simple_average:
                    neighbor_average = sum(data[neighbor] for neighbor in neighbors) / len(neighbors)
                    smoothed_data[i] = data[i] * point_weight + neighbor_average * weight
                else:
                    smoothed_data[i] = self._idw_value(locations, i, neighbors, data)
            self._dataset_builder.append_timestep(self._dataset_reader.times[tsidx], smoothed_data, activity)

        self._add_output_datasets()

    def _setup_activity_calculators(self, ugrid):
        """Attach cell-to-point activity calculators when the input has nodal values with cell activity.

        A CellToPointActivityCalculator is assigned to both the input reader and the output builder when the
        shape of the reader's activity array differs from the shape of its values array.

        Args:
            ugrid: The UGrid the input dataset is on.
        """
        reader = self._dataset_reader
        # Compare against None instead of testing truthiness: the truth value of a multi-element numpy array is
        # ambiguous and raises ValueError.
        activity = reader.activity
        if activity is None:
            return
        values = reader.values
        if values is None:
            return
        if values.shape != activity.shape:
            # Nodal dataset values with cell activity
            reader.activity_calculator = CellToPointActivityCalculator(ugrid)
            self._dataset_builder.activity_calculator = CellToPointActivityCalculator(ugrid)

    @staticmethod
    def _get_neighbors(ugrid, point_idx, num_levels):
        """Get the indices of the points adjacent to a point.

        Args:
            ugrid: The UGrid the dataset is on.
            point_idx (int): Index of the point.
            num_levels (int): 1 for direct neighbors only; 2 to also include the neighbors of those neighbors.

        Returns:
            (set[int]): The neighbor indices. The point itself is always excluded; with two levels it would
            otherwise be picked up as a neighbor of its own neighbors.
        """
        neighbors = set(ugrid.get_point_adjacent_points(point_idx))
        if num_levels == 2:  # Get all the neighbors of our neighbors
            for neighbor in tuple(neighbors):
                neighbors.update(ugrid.get_point_adjacent_points(neighbor))
        neighbors.discard(point_idx)
        return neighbors

    @staticmethod
    def _contributing_neighbors(neighbors, data, activity):
        """Filter neighbors down to those that can contribute to the smoothed value.

        Args:
            neighbors (Iterable[int]): Candidate neighbor indices.
            data: Dataset values for the current time step, indexable by point index.
            activity: Point activity for the current time step, or None when activity should not be considered.

        Returns:
            (list[int]): Neighbors whose value is not NaN and, when activity is given, that are active.
        """
        return [
            neighbor for neighbor in neighbors
            if not np.isnan(data[neighbor]) and (activity is None or activity[neighbor])
        ]

    @staticmethod
    def _idw_value(locations, point_idx, neighbors, data):
        """Compute the inverse-distance-weighted average of a point's neighbors.

        Distances are measured using only the x and y coordinates.

        Args:
            locations: Point coordinates indexable by point index; components 0 and 1 are used as x and y.
            point_idx (int): Index of the point being smoothed.
            neighbors (list[int]): Indices of the contributing neighbors. Must not be empty.
            data: Dataset values for the current time step, indexable by point index.

        Returns:
            (float): The IDW value. If any neighbor coincides with the point in xy, its inverse distance would be
            infinite, so the average of the coincident neighbors' values is returned instead.
        """
        p0 = locations[point_idx]
        coincident = []
        inverse_distances = []
        scalars = []
        for neighbor in neighbors:
            p1 = locations[neighbor]
            distance = math.hypot(p1[0] - p0[0], p1[1] - p0[1])
            if distance == 0.0:
                coincident.append(data[neighbor])
            else:
                inverse_distances.append(1.0 / distance)
                scalars.append(data[neighbor])
        if coincident:
            return sum(coincident) / len(coincident)
        total = sum(inverse_distances)
        return sum(inverse / total * scalar for inverse, scalar in zip(inverse_distances, scalars))

    def _setup_output_dataset_builders(self, arguments):
        """Set up dataset builders for selected tool outputs.

        The output dataset takes its geometry, reference time, time units and null value from the input dataset
        and has a single component.

        Args:
            arguments (list): The tool arguments.
        """
        # Create a place for the output dataset file
        dataset_name = arguments[ARG_OUTPUT_DATASET].text_value
        self._dataset_builder = self.get_output_dataset_writer(
            name=dataset_name,
            geom_uuid=self._dataset_reader.geom_uuid,
            num_components=1,
            ref_time=self._dataset_reader.ref_time,
            time_units=self._dataset_reader.time_units,
            null_value=self._dataset_reader.null_value,
        )

    def _add_output_datasets(self):
        """Add datasets created by the tool to be sent back to XMS."""
        self.logger.info('Writing smoothed output dataset to XMDF file...')
        self._dataset_builder.appending_finished()
        # Send the dataset back to XMS
        self.set_output_dataset(self._dataset_builder)
