"""TransientDatasetStatisics tool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
from xms.datasets.dataset_io import DSET_NULL_VALUE
from xms.tool_core import dataset_argument, IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.datasets.transient_dataset_statistics_calculator import TransientDatasetStatisticCalculator

# Positions of the arguments within the list returned by
# TransientDatasetStatisticsCalculatorTool.initial_arguments().
# The hidden output dataset arguments (positions 7-11) have no constants here.
ARG_DATASET_CHOICE = 0  # 'dataset_selection': the input dataset
ARG_OUTPUT_DATASET = 1  # 'dataset_prefix': name prefix for the output datasets
ARG_OPERATION_MIN = 2  # 'minimum' toggle
ARG_OPERATION_MAX = 3  # 'maximum' toggle
ARG_OPERATION_STDEV = 4  # 'standard_deviation' toggle
ARG_OPERATION_MEAN = 5  # 'mean' toggle
ARG_OPERATION_MEDIAN = 6  # 'median' toggle


class TransientDatasetStatisticsCalculatorTool(Tool):
    """Tool that writes one output dataset per selected statistic of an input dataset."""

    def __init__(self):
        """Initialize the tool under the name 'Transient Dataset Statistics'."""
        super().__init__(name="Transient Dataset Statistics")

    def initial_arguments(self) -> list[dataset_argument]:
        """Build the default argument list for the tool.

        Must override.

        The list holds, in order: the input dataset, the output name prefix, one boolean toggle per statistic,
        and one output dataset argument per statistic (created with ``optional=True`` and ``hide=True``).

        Returns:
            (list): A list of the initial tool arguments.
        """
        # (name, description) pairs; their order must stay in sync with the ARG_OPERATION_* indices.
        toggle_specs = (
            ('minimum', 'Minimum'),
            ('maximum', 'Maximum'),
            ('standard_deviation', 'Standard Deviation'),
            ('mean', 'Mean'),
            ('median', 'Median'),
        )
        output_specs = (
            ('min_dataset', 'Minimum dataset name'),
            ('max_dataset', 'Maximum dataset name'),
            ('stdev_dataset', 'Standard deviation dataset name'),
            ('mean_dataset', 'Mean dataset name'),
            ('median_dataset', 'Median dataset name'),
        )

        arguments = [
            self.dataset_argument(name='dataset_selection', description='Dataset', io_direction=IoDirection.INPUT),
            self.string_argument(name='dataset_prefix', description='Dataset name prefix'),
        ]
        arguments.extend(self.bool_argument(name=name, description=label) for name, label in toggle_specs)
        arguments.extend(
            self.dataset_argument(name=name, description=label, io_direction=IoDirection.OUTPUT, optional=True,
                                  hide=True)
            for name, label in output_specs
        )
        return arguments

    def validate_arguments(self, arguments: list[dataset_argument]) -> dict:
        """Check that at least one statistic toggle is turned on.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments. Empty when at least one operation is selected;
            otherwise a single message keyed by the name of the minimum toggle argument.
        """
        toggle_indices = (ARG_OPERATION_MIN, ARG_OPERATION_MAX, ARG_OPERATION_STDEV, ARG_OPERATION_MEAN,
                          ARG_OPERATION_MEDIAN)
        toggle_values = [arguments[index].value for index in toggle_indices]
        if any(toggle_values):
            return {}
        return {arguments[ARG_OPERATION_MIN].name: 'At least one operation must be selected.'}

    def run(self, arguments: list[dataset_argument]) -> None:
        """Create an output dataset writer per selected operation, run the calculator, and register the outputs.

        Args:
            arguments (list): The tool arguments.
        """
        source_name = arguments[ARG_DATASET_CHOICE].text_value
        grid = self.get_input_dataset_grid(source_name)
        geom_uuid = grid.uuid

        candidates = (
            arguments[ARG_OPERATION_MIN],
            arguments[ARG_OPERATION_MAX],
            arguments[ARG_OPERATION_STDEV],
            arguments[ARG_OPERATION_MEAN],
            arguments[ARG_OPERATION_MEDIAN],
        )
        # An argument without a 'value' attribute is treated as selected (getattr default of True).
        selected_operations = [candidate for candidate in candidates if getattr(candidate, 'value', True)]
        output_name = arguments[ARG_OUTPUT_DATASET].text_value

        dataset = self.get_input_dataset(source_name)
        location = dataset.location
        ref_time = dataset.ref_time
        time_units = dataset.time_units

        output_datasets = []
        for operation in selected_operations:
            # Each output is named "<prefix>_<operation description>" and mirrors the input dataset's
            # location, reference time, and time units.
            writer = self.get_output_dataset_writer(
                name=f"{output_name}_{operation.description}",
                geom_uuid=geom_uuid,
                num_components=1,
                ref_time=ref_time,
                time_units=time_units,
                location=location,
                use_activity_as_null=True
            )
            output_datasets.append(writer)
            # Fall back to the library default when the input dataset defines no null value.
            writer.null_value = DSET_NULL_VALUE if dataset.null_value is None else dataset.null_value

        calculator = TransientDatasetStatisticCalculator(dataset, selected_operations, output_datasets, grid.ugrid,
                                                         self.logger)
        calculator.calculate()

        for writer in output_datasets:
            self.set_output_dataset(writer)
