"""FilterDatasetTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from typing import Any, Dict, Optional

# 2. Third party modules
import numpy
import numpy as np

# 3. Aquaveo modules
from xms.constraint.ugrid_activity import CellToPointActivityCalculator
from xms.tool_core import ALLOW_ONLY_SCALARS, Argument, IoDirection, Tool

# 4. Local modules
from xms.tool.utilities.time_units_converter import TimeUnitsConverter


class FilterDatasetTool(Tool):
    """Tool to filter dataset values based on conditionals.

    Each value of a scalar input dataset is tested against an "if" condition, optionally combined with a second
    condition through 'And'/'Or'. Values that match are replaced according to the "Assign on true" setting and
    values that do not match according to the "Assign on false" setting. The result is written, one time step at a
    time, to a new output dataset.
    """
    # Positions of the tool arguments in the list built by initial_arguments().
    ARG_INPUT_DATASET = 0
    ARG_IF_CONDITION = 1
    ARG_IF_VALUE = 2
    ARG_BOOLEAN = 3
    ARG_CONDITION_2 = 4
    ARG_CONDITION_VALUE_2 = 5
    ARG_ASSIGNMENT_TYPE = 6
    ARG_SPECIFIED_VALUE = 7
    ARG_TIME_UNITS = 8
    ARG_DEFAULT_ASSIGNED_VALUE = 9
    ARG_DEFAULT_SPECIFIED_VALUE = 10
    ARG_OUTPUT_DATASET = 11

    # Conditions that test for null values and therefore take no comparison value.
    _NULL_CONDITIONS = ('Null', 'Not null')

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Filter Dataset')
        # Input dataset, set by validate_arguments() and read by run().
        self._input_dataset = None
        # Output dataset writer, created by _setup_output_dataset_builders().
        self._output_dataset = None

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments, ordered to match the ARG_* class constants.
        """
        comparison_choices = [Argument.NONE_SELECTED, '<', '<=', '>', '>=', 'Equal', 'Not equal', 'Null', 'Not null']
        arguments = [
            self.dataset_argument(name='input_dataset', description='Input dataset',
                                  filters=[ALLOW_ONLY_SCALARS]),
            self.string_argument(name='if_condition', description='If condition',
                                 choices=comparison_choices),
            self.float_argument(name='if_value', description='If value'),
            self.string_argument(name='boolean', description='Boolean',
                                 choices=[Argument.NONE_SELECTED, 'And', 'Or'], optional=True),
            self.string_argument(name='condition_2', description='Condition', choices=comparison_choices,
                                 optional=True),
            self.float_argument(name='condition_value_2', description='Condition value', optional=True),
            self.string_argument(name='assignment_type', description='Assign on true',
                                 choices=['Original', 'Specify', 'Null', 'True', 'False', 'Time']),
            self.float_argument(name='specified_value', description='Specified value', optional=True, value=0.0),
            self.string_argument(name='time_units', description='Time units',
                                 choices=['Seconds', 'Minutes', 'Hours', 'Days'], optional=True),
            self.string_argument(name='default_assigned_value', description='Assign on false',
                                 choices=['Original', 'Specify', 'Null', 'True', 'False']),
            self.float_argument(name='default_specified_value', description='Default specified value', optional=True,
                                value=0.0),
            self.dataset_argument(name='output_dataset', description='Output dataset', value="new dataset",
                                  io_direction=IoDirection.OUTPUT),
        ]

        # Apply the initial show/hide state so dependent arguments start out consistent.
        self.enable_arguments(arguments)
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Also stores the validated input dataset in ``self._input_dataset`` for use by ``run``.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary mapping argument names to error messages; empty when all arguments are valid.
        """
        errors = {}

        # Validate input dataset (errors is passed along so problems with it can be recorded)
        self._input_dataset = self._validate_input_dataset(arguments[self.ARG_INPUT_DATASET], errors)

        # If boolean is set then condition_2 is required, and condition_value_2 is required unless condition_2 is a
        # null test. This mirrors enable_arguments(), which hides condition_value_2 for null tests; requiring it
        # there would produce an error the user has no way to fix.
        if arguments[self.ARG_BOOLEAN].text_value != '':
            condition_2 = arguments[self.ARG_CONDITION_2]
            if condition_2.text_value == '':
                errors[condition_2.name] = 'Value must be specified.'
            condition_value_2 = arguments[self.ARG_CONDITION_VALUE_2]
            if condition_2.value not in self._NULL_CONDITIONS and condition_value_2.text_value == '':
                errors[condition_value_2.name] = 'Value must be specified.'

        # if assignment_type is 'Specify' then must have a specified_value
        if arguments[self.ARG_ASSIGNMENT_TYPE].value == 'Specify':
            if arguments[self.ARG_SPECIFIED_VALUE].text_value == '':
                errors[arguments[self.ARG_SPECIFIED_VALUE].name] = 'Value must be specified.'

        # if assignment_type is 'Time' then must have time_units
        if arguments[self.ARG_ASSIGNMENT_TYPE].value == 'Time':
            if arguments[self.ARG_TIME_UNITS].text_value == '':
                errors[arguments[self.ARG_TIME_UNITS].name] = 'Value must be specified.'

        # if default_assigned_value is 'Specify' then must have a default_specified_value
        if arguments[self.ARG_DEFAULT_ASSIGNED_VALUE].value == 'Specify':
            if arguments[self.ARG_DEFAULT_SPECIFIED_VALUE].text_value == '':
                errors[arguments[self.ARG_DEFAULT_SPECIFIED_VALUE].name] = 'Value must be specified.'

        return errors

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        Args:
            arguments(list): The tool arguments.
        """
        # don't show if value when set to null and not null
        needs_if_value = arguments[self.ARG_IF_CONDITION].value not in self._NULL_CONDITIONS
        arguments[self.ARG_IF_VALUE].show = needs_if_value

        # if boolean set then show condition_2, and condition_value_2 unless condition_2 is a null test
        has_boolean = arguments[self.ARG_BOOLEAN].text_value != ''
        arguments[self.ARG_CONDITION_2].show = has_boolean
        needs_condition_value = has_boolean and arguments[self.ARG_CONDITION_2].value not in self._NULL_CONDITIONS
        arguments[self.ARG_CONDITION_VALUE_2].show = needs_condition_value

        # if assignment_type is 'Specify' then must have a specified_value
        has_specified_value = arguments[self.ARG_ASSIGNMENT_TYPE].value == 'Specify'
        arguments[self.ARG_SPECIFIED_VALUE].show = has_specified_value

        # if assignment_type is 'Time' then must have time_units
        has_time_units = arguments[self.ARG_ASSIGNMENT_TYPE].value == 'Time'
        arguments[self.ARG_TIME_UNITS].show = has_time_units

        # if default_assigned_value is 'Specify' then must have a default_specified_value
        has_default_specified_value = arguments[self.ARG_DEFAULT_ASSIGNED_VALUE].value == 'Specify'
        arguments[self.ARG_DEFAULT_SPECIFIED_VALUE].show = has_default_specified_value

    def run(self, arguments):
        """Override to run the tool.

        Filters the input dataset one time step at a time with apply_filter() and writes the result to the output
        dataset. Relies on ``self._input_dataset`` having been set by validate_arguments().

        Args:
            arguments (list): The tool arguments.
        """
        filter_settings = {
            'if_condition': arguments[self.ARG_IF_CONDITION].value,
            'if_value': arguments[self.ARG_IF_VALUE].value,
            'boolean': arguments[self.ARG_BOOLEAN].value,
            'condition_2': arguments[self.ARG_CONDITION_2].value,
            'condition_value_2': arguments[self.ARG_CONDITION_VALUE_2].value,
            'assignment_type': arguments[self.ARG_ASSIGNMENT_TYPE].value,
            'specified_value': arguments[self.ARG_SPECIFIED_VALUE].value,
            'default_value': arguments[self.ARG_DEFAULT_ASSIGNED_VALUE].value,
            'default_specified_value': arguments[self.ARG_DEFAULT_SPECIFIED_VALUE].value,
        }
        using_time_output = filter_settings['assignment_type'] == 'Time'
        time_units = arguments[self.ARG_TIME_UNITS].value

        # Filter the dataset, time step at a time
        self._setup_output_dataset_builders(arguments)
        ugrid = self.get_input_dataset_grid(arguments[self.ARG_INPUT_DATASET].value).ugrid

        # Setup activity calculator
        # NOTE(review): relies on the truthiness of activity/values; fine for None, but a multi-element numpy
        # array would raise here. Presumably these are h5py datasets or None - confirm.
        if self._input_dataset.activity and self._input_dataset.values:
            if self._input_dataset.values.shape != self._input_dataset.activity.shape:
                # Nodal dataset values with cell activity
                self._input_dataset.activity_calculator = CellToPointActivityCalculator(ugrid)
                self._output_dataset.activity_calculator = CellToPointActivityCalculator(ugrid)

        # The time unit conversion is the same for every time step, so build the converter only once.
        converter = None
        if using_time_output and self._input_dataset.time_units != time_units:
            converter = TimeUnitsConverter(from_units=self._input_dataset.time_units, to_units=time_units)

        # If we replaced all null values to be not null, don't set an activity array. This depends only on the
        # settings, so decide it once.
        # NOTE(review): this ignores the second condition; with an 'And' combination not every null value is
        # necessarily replaced - verify this is intended.
        replaced_nulls = (
            filter_settings['if_condition'] == 'Null'
            and filter_settings['assignment_type'] == 'Specify'
            and not np.isnan(filter_settings['specified_value'])
        )

        time_count = len(self._input_dataset.times)
        for tsidx in range(time_count):
            self.logger.info(f'Processing time step {tsidx + 1} of {time_count}...')
            data, activity = self._input_dataset.timestep_with_activity(tsidx, nan_activity=True)
            raw_time = self._input_dataset.times[tsidx]

            if using_time_output:
                # 'Time' assignment writes the time step's time, in the requested units, into matching values.
                filter_settings['time_value'] = raw_time if converter is None else converter.convert_value(raw_time)

            data = apply_filter(data, filter_settings)

            if replaced_nulls:
                activity = None

            # The output time step keeps the input's time in the input's units.
            self._output_dataset.append_timestep(raw_time, data, activity)

        self._add_output_dataset()

    def _setup_output_dataset_builders(self, arguments):
        """Set up dataset builders for selected tool outputs.

        Args:
            arguments (list): The tool arguments.
        """
        # Create a place for the output dataset file
        null_value = self._input_dataset.null_value
        if null_value is None:
            # Fallback null value used when the input dataset does not define one.
            null_value = -9999999.0
        dataset_name = arguments[self.ARG_OUTPUT_DATASET].text_value
        self._output_dataset = self.get_output_dataset_writer(
            name=dataset_name,
            geom_uuid=self._input_dataset.geom_uuid,
            num_components=1,
            ref_time=self._input_dataset.ref_time,
            time_units=self._input_dataset.time_units,
            null_value=null_value,
            location=self._input_dataset.location
        )

    def _add_output_dataset(self):
        """Add datasets created by the tool to be sent back to XMS."""
        self.logger.info('Writing output dataset...')
        self._output_dataset.appending_finished()
        # Send the dataset back to XMS
        self.set_output_dataset(self._output_dataset)


def apply_filter(data: numpy.ndarray, filter_settings: Dict[str, Any]):
    """Return a filtered copy of the dataset values.

    Values satisfying the configured condition(s) are replaced per ``filter_settings['assignment_type']``; every
    other value is replaced per ``filter_settings['default_value']``. The caller's array is left untouched.

    Args:
        data (np.ndarray): The dataset values.
        filter_settings (Dict[str, Any]): Filter values to be applied. Keys read: 'if_condition', 'if_value',
            'boolean' (plus 'condition_2' and 'condition_value_2' when 'boolean' is truthy), 'assignment_type'
            (plus 'specified_value' or 'time_value' when needed) and 'default_value' (plus
            'default_specified_value' when needed).

    Returns:
        (np.ndarray): The filtered values.
    """
    filtered = numpy.copy(data)

    def _assign(mask, choice, specified_key, time_allowed):
        """Write the replacement named by ``choice`` into ``filtered`` wherever ``mask`` is True."""
        if choice == 'Specify':
            replacement = filter_settings[specified_key]
        elif choice == 'Null':
            replacement = np.nan
        elif choice == 'True':
            replacement = 1.0
        elif choice == 'False':
            replacement = 0.0
        elif time_allowed and choice == 'Time':
            replacement = filter_settings['time_value']
        else:
            return  # 'Original' (or anything unrecognized) keeps the existing values
        filtered[mask] = replacement

    # Masks are evaluated on the copy so non-ndarray input has already been converted to an array.
    matches = get_conditional_mask(filtered, filter_settings['if_condition'], filter_settings['if_value'])
    combine = filter_settings['boolean']
    if combine:
        second = get_conditional_mask(filtered, filter_settings['condition_2'], filter_settings['condition_value_2'])
        if combine == 'And':
            matches = numpy.logical_and(matches, second)
        elif combine == 'Or':
            matches = numpy.logical_or(matches, second)

    # Matching values may also be set to the time step's time; non-matching ('Assign on false') values may not.
    _assign(matches, filter_settings['assignment_type'], 'specified_value', True)
    _assign(numpy.logical_not(matches), filter_settings['default_value'], 'default_specified_value', False)
    return filtered


def get_conditional_mask(data: numpy.ndarray, condition: str, if_value: Optional[float]):
    """Build a boolean mask marking where ``data`` satisfies ``condition``.

    Args:
        data (numpy.ndarray): The numpy array.
        condition (str): One of '<', '<=', '>', '>=', 'Equal', 'Not equal', 'Null', 'Not null'.
        if_value (Optional[float]): The value to compare against; unused for 'Null' and 'Not null'.

    Returns:
        (numpy.ndarray): The mask.

    Raises:
        ValueError: If ``condition`` is not one of the supported values.
    """
    mask_builders = {
        '<': lambda values: values < if_value,
        '<=': lambda values: values <= if_value,
        '>': lambda values: values > if_value,
        '>=': lambda values: values >= if_value,
        'Equal': lambda values: values == if_value,
        'Not equal': lambda values: values != if_value,
        # Null values are NaN, and NaN never compares equal to anything, so nulls are detected with isnan.
        'Null': np.isnan,
        'Not null': lambda values: np.logical_not(np.isnan(values)),
    }
    build_mask = mask_builders.get(condition)
    if build_mask is None:
        raise ValueError('Invalid if condition.')
    return build_mask(data)

# def main():
#     """Main function, for testing."""
#     import os
#
#     from xms.guipy.dialogs.xms_parent_dlg import ensure_qapplication_exists
#     from xms.tool.utilities.test_utils import get_test_files_path
#     from xms.tool_gui.tool_dialog import ToolDialog
#
#     qapp = ensure_qapplication_exists()
#     tool = FilterDatasetTool()
#     tool.set_gui_data_folder(os.path.join(get_test_files_path(), 'filter_dataset_tool'))
#     tool.set_grid_uuid_for_testing('grid', '160c179f-2d0d-4376-873c-ecf64edc55ba')
#     arguments = tool.initial_arguments()
#     tool_dialog = ToolDialog(None, arguments, tool.name, tool=tool)
#     if tool_dialog.exec():
#         tool.run_tool(tool_dialog.tool_arguments)
#
#
# if __name__ == "__main__":
#     main()
