"""CompareDatasetsTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
from typing import List

# 2. Third party modules
import numpy

# 3. Aquaveo modules
from xms.constraint.ugrid_activity import values_with_nans
from xms.core.filesystem import filesystem
from xms.datasets.dataset_reader import DatasetReader
from xms.datasets.dataset_writer import DatasetWriter
from xms.grid.ugrid import UGrid
from xms.tool_core import Argument, IoDirection

# 4. Local modules
from xms.tool.datasets import interpolate_to_ugrid as itu


class CompareDatasetsTool(itu.InterpolateToUGridTool):
    """Tool to compare two datasets.

    The output dataset is dataset 1 minus dataset 2, computed one time step at a time. When the two
    datasets are on different geometries, the parent interpolation tool is run first with dataset 2
    as its input and the geometry of dataset 1 as its target.
    """
    # Argument indices. The two dataset arguments reuse slots of the parent interpolation tool; the
    # remaining arguments are appended after the parent's arguments.
    ARG_DATASET_1 = itu.ARG_INPUT_DATASET
    ARG_DATASET_2 = itu.ARG_INPUT_TARGET_UGRID
    ARG_INACTIVE_OPTION = itu.ARG_END
    ARG_INACTIVE_VALUE_1 = itu.ARG_END + 1
    ARG_INACTIVE_VALUE_2 = itu.ARG_END + 2
    ARG_DATASET_OUT = itu.ARG_END + 3

    # Choices offered by the inactive values option.
    _OPTION_USE_SPECIFIED = 'Use specified value for inactive values'
    _OPTION_PROPAGATE_INACTIVE = 'Inactive values result in an inactive value'

    def __init__(self):
        """Initializes the class."""
        super().__init__()
        self.name = 'Compare Datasets'

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = super().initial_arguments()

        # Relabel the parent's input dataset argument as dataset 1.
        arguments[self.ARG_DATASET_1].name = 'dataset_1'
        arguments[self.ARG_DATASET_1].description = 'Dataset 1'

        # Replace the argument in the parent's target grid slot with a second dataset argument.
        # NOTE(review): this argument is registered under the name 'dataset'; confirm that
        # 'dataset_2' was not intended.
        arguments[self.ARG_DATASET_2] = self.dataset_argument(name='dataset', description='Dataset 2')

        # Append the arguments that are specific to the comparison.
        inactive_choices = [self._OPTION_USE_SPECIFIED, self._OPTION_PROPAGATE_INACTIVE]
        arguments.append(self.string_argument(name='inactive_option',
                                              description='Inactive values option',
                                              choices=inactive_choices))
        arguments.append(self.float_argument(name='inactive_value_1', description='Specified value dataset 1'))
        arguments.append(self.float_argument(name='inactive_value_2', description='Specified value dataset 2'))
        arguments.append(self.dataset_argument(name='dataset_out',
                                               description='Output dataset name',
                                               io_direction=IoDirection.OUTPUT))

        self.enable_arguments(arguments)
        return arguments

    def needs_interp(self, arguments):
        """Determine whether dataset 2 must be interpolated before it can be compared to dataset 1.

        Args:
            arguments(list): The tool arguments.

        Returns:
            (bool): True if both datasets are selected and their geometries have different UUIDs.
        """
        dataset_1_value = arguments[self.ARG_DATASET_1].value
        dataset_2_value = arguments[self.ARG_DATASET_2].value
        if dataset_1_value is None or dataset_2_value is None:
            # Nothing to compare until both datasets have been selected.
            return False

        geometry_1 = self.get_input_dataset_grid(dataset_1_value)
        geometry_2 = self.get_input_dataset_grid(dataset_2_value)
        return geometry_1.uuid != geometry_2.uuid

    def _convert_args_for_interp(self, arguments):
        """Build the argument list expected by the parent interpolation tool.

        Args:
            arguments(list): The tool arguments.

        Returns:
            (list): arguments for interp tool
        """
        # Shallow copy: the list is new, but the argument objects are shared with the caller's list.
        interp_args = arguments.copy()

        # dataset 1 (interp tool) needs to be dataset 2 (compare tool)
        interp_args[itu.ARG_INPUT_DATASET] = arguments[self.ARG_DATASET_2]

        # target grid (interp tool) needs to match geometry of dataset 1 (compare tool)
        target_grid_arg = self.grid_argument(name='target_grid', description='Target grid')
        dataset_1 = self.get_input_dataset(arguments[self.ARG_DATASET_1].text_value)
        target_grid_arg.value = self.get_grid_name_from_uuid(dataset_1.geom_uuid)
        interp_args[itu.ARG_INPUT_TARGET_UGRID] = target_grid_arg

        # Set the name of the interpolated dataset. Because the argument object is shared, this
        # also changes the value seen through the caller's list.
        interp_args[itu.ARG_OUTPUT_DATASET_NAME].value = 'tmp_interp_for_tool'

        return interp_args

    def enable_arguments(self, arguments: List[Argument]):
        """Called to show/hide arguments, change argument values and add new arguments.

        Args:
            arguments(list): The tool arguments.
        """
        # Nothing to do until every argument this method indexes has been created. The highest
        # index used below is ARG_INACTIVE_VALUE_2, so the list must be longer than that.
        if len(arguments) <= self.ARG_INACTIVE_VALUE_2:
            return

        arguments[itu.ARG_INPUT_DS_TIMESTEP].hide = True

        # The interpolation options are only shown when interpolation will actually be done.
        interp_needed = self.needs_interp(arguments)
        for index in range(itu.ARG_INPUT_TARGET_UGRID + 1, itu.ARG_END):
            arguments[index].hide = not interp_needed
        if interp_needed:
            # Let the parent tool manage the interpolation arguments.
            super().enable_arguments(arguments)

            # The interpolated dataset is an intermediate result, so keep its options hidden.
            arguments[itu.ARG_OUTPUT_DATASET_NAME].hide = True
            arguments[itu.ARG_INPUT_DATASET_LOCATION].hide = True

        # The specified values only apply to the "use specified value" option.
        hide_values = arguments[self.ARG_INACTIVE_OPTION].text_value != self._OPTION_USE_SPECIFIED
        arguments[self.ARG_INACTIVE_VALUE_1].hide = hide_values
        arguments[self.ARG_INACTIVE_VALUE_2].hide = hide_values

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}

        # Validate input datasets
        dataset1 = self._validate_input_dataset(arguments[self.ARG_DATASET_1], errors)
        dataset2 = self._validate_input_dataset(arguments[self.ARG_DATASET_2], errors)
        if not dataset1 or not dataset2:
            return errors
        self._validate_dataset_compatibility(dataset1, dataset2, arguments, errors)

        if self.needs_interp(arguments):
            # change arguments to match what the interpolator tool wants
            interp_args = self._convert_args_for_interp(arguments)
            # Errors from the interpolation tool override compare errors with the same key.
            errors.update(super().validate_arguments(interp_args))

        return errors

    def copy_output_dataset(self) -> None:
        """Override so that the parent interpolation tool does not set the output dataset."""

    def _validate_dataset_compatibility(self, dataset1, dataset2, arguments, errors):
        """Check that datasets have the same number of time steps.

        Args:
            dataset1 (xms.datasets.dset_reader.DatasetReader): The first dataset.
            dataset2 (xms.datasets.dset_reader.DatasetReader): The second dataset.
            arguments (list): The tool arguments.
            errors (dict): Dictionary of errors keyed by argument name. Gets modified if there are errors.
        """
        if len(dataset1.times) != len(dataset2.times):
            # The error is reported against the second dataset argument.
            dataset_2_arg = arguments[self.ARG_DATASET_2]
            errors[dataset_2_arg.name] = 'Datasets must have the same number of time steps.'

    def run(self, arguments):
        """Override to run the tool.

        Reads both datasets (interpolating dataset 2 first when needed), computes dataset 1 minus
        dataset 2 for every time step of dataset 1, and sends the resulting dataset back to XMS.

        Args:
            arguments (list): The tool arguments.
        """
        # Read the first dataset and its grid
        dataset_1_text = arguments[self.ARG_DATASET_1].text_value
        dataset_1 = self.get_input_dataset(dataset_1_text)
        ugrid_1 = self.get_input_dataset_grid(dataset_1_text).ugrid

        # do interpolation first if necessary for dataset_2
        if self.needs_interp(arguments):
            self.logger.info('Doing interpolation...')
            interp_args = self._convert_args_for_interp(arguments)

            # running the interpolation will create the dataset "tmp_interp_for_tool"
            super().run(interp_args)
            dataset_2 = DatasetReader(h5_filename=self._output_ds.h5_filename, dset_name=self._output_ds.name)
        else:
            dataset_2 = self.get_input_dataset(arguments[self.ARG_DATASET_2].text_value)

        # An activity array is written when either dataset has activity or a null value.
        has_activity = any(
            dataset.activity is not None or dataset.null_value is not None for dataset in (dataset_1, dataset_2)
        )

        # Give a message if the timesteps are not the same
        if any(time_1 != time_2 for time_1, time_2 in zip(dataset_1.times, dataset_2.times)):
            self.logger.info('WARNING: The timesteps for both datasets do not match.')

        # Get the options for inactive values
        use_specified = arguments[self.ARG_INACTIVE_OPTION].text_value == self._OPTION_USE_SPECIFIED
        specified_value_1 = arguments[self.ARG_INACTIVE_VALUE_1].value if use_specified else None
        specified_value_2 = arguments[self.ARG_INACTIVE_VALUE_2].value if use_specified else None

        # Create a place for the output dataset file
        temp_folder = filesystem.temp_filename()
        os.mkdir(temp_folder)
        dataset_name = os.path.splitext(os.path.basename(arguments[self.ARG_DATASET_OUT].text_value))[0]
        output_file = os.path.join(temp_folder, dataset_name)

        # The output dataset takes its geometry and metadata from dataset 1.
        builder = DatasetWriter(
            h5_filename=output_file,
            name=dataset_name,
            geom_uuid=dataset_1.geom_uuid,
            num_components=dataset_1.num_components,
            ref_time=dataset_1.ref_time,
            time_units=dataset_1.time_units,
            null_value=dataset_1.null_value,
            location=dataset_1.location
        )

        # Compute the difference, time step at a time
        minimums = []
        maximums = []
        for time_index, time in enumerate(dataset_1.times):
            # Both datasets are read against the grid of dataset 1.
            values_1 = self._get_values(ugrid_1, dataset_1, time_index)
            values_2 = self._get_values(ugrid_1, dataset_2, time_index)
            diff = self._compute_difference(values_1, values_2, use_specified, specified_value_1,
                                            specified_value_2)

            # Build the activity before the NaN markers are replaced below.
            activity = None
            if has_activity:
                activity = numpy.ones(shape=diff.shape, dtype='u1')
                activity[numpy.isnan(diff)] = 0

            # Append the stuff for this time step. The minimum and maximum ignore inactive (NaN)
            # entries, which are then written as -999.0.
            minimums.append(numpy.nanmin(diff))
            maximums.append(numpy.nanmax(diff))
            diff[numpy.isnan(diff)] = -999.0
            builder.append_timestep(time, diff, activity)

        # Write the file
        builder.timestep_mins = minimums
        builder.timestep_maxs = maximums
        builder.appending_finished()

        # Send the dataset back to XMS
        self.set_output_dataset(builder)

    @staticmethod
    def _compute_difference(values_1, values_2, use_specified, specified_value_1, specified_value_2):
        """Compute values_1 - values_2, optionally filling entries where only one side is inactive.

        Inactive entries are expected to be NaN, so they produce NaN in the difference unless a
        specified value is applied.

        Args:
            values_1: Values of dataset 1 with inactive entries as NaN.
            values_2: Values of dataset 2 with inactive entries as NaN.
            use_specified (bool): True to apply the specified values where exactly one side is inactive.
            specified_value_1: Value assigned where only dataset 2 is inactive.
            specified_value_2: Value assigned where only dataset 1 is inactive.

        Returns:
            The difference as a new array.
        """
        diff = values_1 - values_2
        if use_specified:
            inactive_1 = numpy.isnan(values_1)
            inactive_2 = numpy.isnan(values_2)
            one_inactive = numpy.logical_xor(inactive_1, inactive_2)

            # when only dataset 1 is active set value to specified value 1
            diff[numpy.logical_and(inactive_2, one_inactive)] = specified_value_1

            # when only dataset 2 is active set value to specified value 2
            diff[numpy.logical_and(inactive_1, one_inactive)] = specified_value_2
        return diff

    @staticmethod
    def _get_values(ugrid: UGrid, dataset: DatasetReader, time_index: int):
        """Get dataset values for a timestep with inactive values as NAN.

        Args:
            ugrid: The dataset's unstructured grid.
            dataset: The dataset.
            time_index: The time index.

        Returns:
            The dataset values.
        """
        raw_values = dataset.values[time_index]
        # Datasets without an activity array rely on the null value alone to mark inactive entries.
        activity = None if dataset.activity is None else dataset.activity[time_index]
        return values_with_nans(ugrid, raw_values, activity, dataset.null_value)
