"""SmoothDatasetsTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint.ugrid_activity import CellToPointActivityCalculator
from xms.gdal.utilities import gdal_utils as gu
from xms.mesher.meshing import mesh_utils
from xms.tool_core import ALLOW_ONLY_POINT_MAPPED, ALLOW_ONLY_SCALARS, IoDirection, Tool

# 4. Local modules
from .. utilities.dataset_tool import set_builder_activity_flags, smooth_get_ugrid

# Positions of the tool arguments within the list built by SmoothDatasetsTool.initial_arguments().
(
    ARG_INPUT_DSET,  # 0: input_dataset
    ARG_ANCHOR,  # 1: anchor_type_option
    ARG_OPTION,  # 2: smooth_by_option
    ARG_AREA_CHANGE,  # 3: area_change_limit
    ARG_MIN_CELL_AREA,  # 4: minimum_cell_size
    ARG_MAX_SLOPE,  # 5: maximum_slope_value
    ARG_INPUT_LND,  # 6: locked_nodes_dataset
    ARG_OUTPUT_DATASET,  # 7: default_name (output dataset)
) = range(8)


class SmoothDatasetsTool(Tool):
    """Tool to smooth datasets by limiting slope or element area change.

    Every time step of the input dataset is smoothed with either ``mesh_utils.smooth_elev_by_slope_ugrid``
    ('Maximum slope' option) or ``mesh_utils.smooth_size_function_ugrid`` ('Elemental area change' option).
    Where the optional locked nodes dataset has a value of 1, the input value is kept instead of the smoothed
    value.
    """

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Smooth Datasets')
        self._dataset_reader = None  # Reader for the input dataset, assigned in validate_arguments()
        self._dataset_builder = None  # Writer for the output dataset, assigned in _setup_output_dataset_builders()
        self._lnd = None  # Reader for the optional locked nodes dataset, assigned in validate_arguments()

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.dataset_argument(name='input_dataset', description='Input elevation dataset',
                                  filters=[ALLOW_ONLY_SCALARS, ALLOW_ONLY_POINT_MAPPED]),
            self.string_argument(name='anchor_type_option', description='Anchor',
                                 choices=['Minimum value', 'Maximum value'], value='Minimum value'),
            self.string_argument(name='smooth_by_option', description='Smoothing option',
                                 choices=['Elemental area change', 'Maximum slope'], value='Maximum slope'),
            self.float_argument(name='area_change_limit', description='Smoothing area change limit', value=0.5,
                                min_value=0.0001),
            self.float_argument(name='minimum_cell_size', description='Smoothing minimum cell size', value=1.0,
                                min_value=0.0001),
            self.float_argument(name='maximum_slope_value', description='Smoothing maximum slope', value=0.5,
                                min_value=0.0001),
            self.dataset_argument(name='locked_nodes_dataset', description='Locked nodes dataset (optional)',
                                  optional=True, filters=[ALLOW_ONLY_SCALARS, ALLOW_ONLY_POINT_MAPPED]),
            self.dataset_argument(name='default_name', description='Output dataset', value="new dataset",
                                  io_direction=IoDirection.OUTPUT),
        ]

        # Apply the initial show/hide state that matches the default smoothing option
        self.enable_arguments(arguments)
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Stores the reader for the input dataset and, when one is specified, the reader for the locked nodes
        dataset so that run() can use them.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}

        # Validate input datasets
        self._dataset_reader = self._validate_input_dataset(arguments[ARG_INPUT_DSET], errors)

        # Locked nodes dataset. Reset it first: validation can run more than once on the same tool instance, and
        # a dataset chosen in an earlier pass must not linger after the optional argument has been cleared.
        self._lnd = None
        lnd_argument = arguments[ARG_INPUT_LND]
        if lnd_argument.text_value:
            self._lnd = self._validate_input_dataset(lnd_argument, errors)
            if self._lnd and self._dataset_reader:
                if self._lnd.num_values != self._dataset_reader.num_values:
                    errors[lnd_argument.name] = \
                        'Number of values in locked nodes dataset must match number of values in input dataset.'
        return errors

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        Args:
            arguments(list): The tool arguments.
        """
        # Enable/disable float edit fields based on the smooth_by_option
        slope_enabled = arguments[ARG_OPTION].text_value == 'Maximum slope'
        arguments[ARG_AREA_CHANGE].hide = slope_enabled  # area_change_limit
        arguments[ARG_MIN_CELL_AREA].hide = slope_enabled  # minimum_cell_size
        arguments[ARG_MAX_SLOPE].hide = not slope_enabled  # maximum_slope_value

    def run(self, arguments):
        """Override to run the tool.

        Smooths each time step of the input dataset, restores the input values at locked nodes, and appends the
        result to the output dataset. Relies on validate_arguments() having stored the dataset readers.

        Args:
            arguments (list): The tool arguments.
        """
        dset_ugrid = smooth_get_ugrid(self, arguments[ARG_INPUT_DSET])

        # Build the mask of locked nodes once from the first time step of the optional locked nodes dataset
        locked_mask = None
        if self._lnd:
            locked_mask = self._lnd.values[0].astype('i4') == 1

        # Determine once whether the current projection is geographic
        is_geographic = False
        if gu.valid_wkt(self.default_wkt):
            sr = gu.wkt_to_sr(self.default_wkt)
            is_geographic = sr is not None and bool(sr.IsGeographic())

        # Extract the component, time step at a time
        self._setup_output_dataset_builders(arguments)

        # Set up activity calculator
        if self._dataset_reader.activity and self._dataset_reader.values:
            if self._dataset_reader.values.shape != self._dataset_reader.activity.shape:
                # Nodal dataset values with cell activity
                self._dataset_reader.activity_calculator = CellToPointActivityCalculator(dset_ugrid)
                self._dataset_builder.activity_calculator = CellToPointActivityCalculator(dset_ugrid)

        # The smoothing options do not change between time steps, so read them once
        anchor_type = 'min' if arguments[ARG_ANCHOR].text_value == 'Minimum value' else 'max'
        smooth_by_slope = arguments[ARG_OPTION].text_value == 'Maximum slope'
        max_slope = arguments[ARG_MAX_SLOPE].value
        area_change = arguments[ARG_AREA_CHANGE].value
        min_cell_area = arguments[ARG_MIN_CELL_AREA].value

        time_count = len(self._dataset_reader.times)
        for tsidx in range(time_count):
            self.logger.info(f'Processing time step {tsidx + 1} of {time_count}...')
            data, activity = self._dataset_reader.timestep_with_activity(tsidx)
            set_builder_activity_flags(activity, self._dataset_builder)

            if smooth_by_slope:
                smoothed_data = mesh_utils.smooth_elev_by_slope_ugrid(
                    dset_ugrid, data, max_slope, anchor_type, []
                )
            else:  # Elemental area change
                smoothed_data = mesh_utils.smooth_size_function_ugrid(
                    dset_ugrid, data, area_change, min_cell_area, anchor_type, []
                )
            if locked_mask is not None:
                # Locked nodes keep their input values
                smoothed_data = np.array(smoothed_data)
                smoothed_data[locked_mask] = data[locked_mask]
            self._dataset_builder.append_timestep(self._dataset_reader.times[tsidx], smoothed_data, activity)
            if is_geographic and self._dataset_reader.mins[tsidx] > 1.0:
                msg = 'Dataset values may have incorrect units. The current horizontal units are ' \
                      'degrees (geographic) and all dataset values are greater than 1.0. Consider ' \
                      'converting the dataset values to decimal degrees.'
                self.logger.warning(msg)

        self._add_output_datasets()

    def _setup_output_dataset_builders(self, arguments):
        """Set up dataset builders for selected tool outputs.

        The output dataset takes its geometry UUID, reference time, time units and null value from the input
        dataset reader.

        Args:
            arguments (list): The tool arguments.
        """
        # Create a place for the output dataset file
        dataset_name = arguments[ARG_OUTPUT_DATASET].text_value
        self._dataset_builder = self.get_output_dataset_writer(
            name=dataset_name,
            geom_uuid=self._dataset_reader.geom_uuid,
            num_components=1,
            ref_time=self._dataset_reader.ref_time,
            time_units=self._dataset_reader.time_units,
            null_value=self._dataset_reader.null_value,
        )

    def _add_output_datasets(self):
        """Add datasets created by the tool to be sent back to XMS."""
        self.logger.info('Writing smoothed output dataset to XMDF file...')
        self._dataset_builder.appending_finished()
        # Send the dataset back to XMS
        self.set_output_dataset(self._dataset_builder)


# def main():
#     """Main function, for testing."""
#     from xms.tool_gui.tool_dialog import ToolDialog
#     from xms.guipy.dialogs.xms_parent_dlg import ensure_qapplication_exists
#
#     qapp = ensure_qapplication_exists()
#     tool = SmoothDatasetsTool()
#     # slope # tool.set_grid_uuid_for_testing('smooth_elev', '12728793-53f1-41e3-b0d0-fd83a3d0ec09')
#     tool.set_grid_uuid_for_testing('smooth_area', '287f2236-1f8f-41b3-b720-3bf44f789825')
#     arguments = tool.initial_arguments()
#
#     arguments[ARG_INPUT_DSET].value = 'smooth_area'  # input_dataset
#     arguments[ARG_ANCHOR].value = 'Minimum value'  # anchor_type_option ('Minimum value' or 'Maximum value')
#     arguments[ARG_OPTION].value = 'Elemental area change'  # smooth_by_option (or 'Maximum slope')
#     # arguments[ARG_AREA_CHANGE].value = 0.5  # area_change_limit
#     arguments[ARG_MIN_CELL_AREA].value = 10.0  # minimum_cell_size
#     # arguments[ARG_MAX_SLOPE].value = 0.05  # maximum_slope_value
#     arguments[ARG_OUTPUT_DATASET].value = 'smoothed_area'  # default_name
#     tool_dialog = ToolDialog(None, arguments, tool.name, tool=tool)
#     if tool_dialog.exec():
#         tool.run_tool(tool_dialog.tool_arguments)
#
#
# if __name__ == "__main__":
#     main()
