"""VectorsFromScalarsTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.tool_core import ALLOW_ONLY_SCALARS, IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.datasets.vectors_from_scalars import vectors_from_scalars

# Positions of the tool arguments in the list built by VectorsFromScalarsTool.initial_arguments().
(
    ARG_INPUT_OPTION,  # 'xy_or_magdir_input': input type choice
    ARG_MAG_INPUT,  # 'x_or_mag_input_dataset': Vx or magnitude scalar dataset
    ARG_DIR_INPUT,  # 'y_or_dir_input_dataset': Vy or direction scalar dataset
    ARG_OUTPUT_DATASET,  # 'default_name': output vector dataset name
) = range(4)


class VectorsFromScalarsTool(Tool):
    """Tool that combines two scalar datasets into a single vector dataset.

    The user picks an input type ('Vx and Vy' or 'Magnitude and Direction'); that choice and the two scalar
    datasets are handed to ``vectors_from_scalars``, which produces the output dataset.
    """

    def __init__(self):
        """Set up the tool name and the per-run state."""
        super().__init__(name='Vectors from Scalars')
        # Dataset readers are filled in by validate_arguments() and consumed by run().
        self._x_reader = None
        self._y_reader = None
        # Result of the algorithm; sent back to XMS by _add_output_datasets().
        self._builder = None

    def initial_arguments(self):
        """Build the default argument list for the tool.

        The order of the returned list matches the ARG_* index constants at module level.

        Returns:
            (list): The input type choice, the two scalar input datasets and the output dataset name.
        """
        input_type = self.string_argument(
            name='xy_or_magdir_input',
            description='Input type',
            choices=['Vx and Vy', 'Magnitude and Direction'],
            value='Vx and Vy',
        )
        first_scalar = self.dataset_argument(
            name='x_or_mag_input_dataset',
            description='Input Vx or magnitude scalar dataset',
            filters=ALLOW_ONLY_SCALARS,
        )
        second_scalar = self.dataset_argument(
            name='y_or_dir_input_dataset',
            description='Input Vy or direction scalar dataset',
            filters=ALLOW_ONLY_SCALARS,
        )
        output_vector = self.dataset_argument(
            name='default_name',
            description='Output vector dataset name',
            value='new dataset',
            io_direction=IoDirection.OUTPUT,
        )

        tool_args = [input_type, first_scalar, second_scalar, output_vector]
        self.enable_arguments(tool_args)
        return tool_args

    def validate_arguments(self, arguments):
        """Check that both input datasets are usable together.

        The readers returned by ``_validate_input_dataset`` are stored on the instance so that run() can
        reuse them. When both readers are available, their time steps and the number of values in their
        first time step are compared.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Error messages keyed by argument name; empty when nothing was flagged.
        """
        errors = {}

        first_arg = arguments[ARG_MAG_INPUT]
        self._x_reader = self._validate_input_dataset(first_arg, errors)
        second_arg = arguments[ARG_DIR_INPUT]
        self._y_reader = self._validate_input_dataset(second_arg, errors)

        x_reader, y_reader = self._x_reader, self._y_reader
        if x_reader is not None and y_reader is not None:
            same_times = np.array_equal(x_reader.times, y_reader.times)
            if not same_times:
                errors[first_arg.name] = 'Mismatched time steps in the input datasets.'

            # Only the first time step is compared; jagged arrays are not allowed.
            x_count = len(x_reader.values[0])
            y_count = len(y_reader.values[0])
            if x_count != y_count:
                errors[second_arg.name] = 'Mismatched number of dataset values in the input datasets.'

        return errors

    def enable_arguments(self, arguments):
        """Hook for showing/hiding arguments or changing their values; this tool has nothing to adjust.

        Args:
            arguments(list): The tool arguments.
        """

    def run(self, arguments):
        """Run the conversion and register the resulting dataset as tool output.

        Uses the dataset readers stored by validate_arguments().

        Args:
            arguments (list): The tool arguments.
        """
        # Hand everything to the algorithm; arguments are evaluated in the order they are passed.
        self._builder = vectors_from_scalars(
            self._x_reader,
            self._y_reader,
            self.get_input_dataset_grid(arguments[ARG_MAG_INPUT].value),
            self.get_input_dataset_grid(arguments[ARG_DIR_INPUT].value),
            arguments[ARG_INPUT_OPTION].value,
            arguments[ARG_OUTPUT_DATASET].text_value,
            self.logger,
        )

        # Send whatever was created back to XMS.
        self._add_output_datasets()

    def _add_output_datasets(self):
        """Pass the dataset created by run() back to XMS, if there is one."""
        if self._builder is None:
            return
        self.set_output_dataset(self._builder)


# def main():
#     """Main function, for testing."""
#     from xms.tool_gui.tool_dialog import ToolDialog
#     from xms.guipy.dialogs.xms_parent_dlg import ensure_qapplication_exists
#
#     qapp = ensure_qapplication_exists()
#     tool = VectorsFromScalarsTool()
#     arguments = tool.initial_arguments()
#
#     test_xy = False
#     if test_xy:
#         arguments[ARG_INPUT_OPTION].value = 'Vx and Vy'  # xy_or_magdir_input
#         arguments[ARG_MAG_INPUT].value = 's2v_mag'  # x_or_mag_input_dataset
#         arguments[ARG_DIR_INPUT].value = 's2v_dir'  # y_or_dir_input_dataset
#         arguments[ARG_OUTPUT_DATASET].value = 's2v'  # default_name
#     else:
#         arguments[ARG_INPUT_OPTION].value = 'Magnitude and Direction'
#         arguments[ARG_MAG_INPUT].value = 's2v_mag'
#         arguments[ARG_DIR_INPUT].value = 's2v_dir'
#         arguments[ARG_OUTPUT_DATASET].value = 's2v_mag_dir'
#
#     tool_dialog = ToolDialog(None, arguments, tool.name, tool=tool)
#     if tool_dialog.exec():
#         tool.run_tool(tool_dialog.tool_arguments)
#
#
# if __name__ == "__main__":
#     main()
