"""DatasetCalculator tool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import keyword

# 2. Third party modules
import numexpr.necompiler as nec
import numpy as np
import pandas as pd

# 3. Aquaveo modules
from xms.tool_core import ALLOW_ONLY_CELL_MAPPED, ALLOW_ONLY_POINT_MAPPED, ALLOW_ONLY_SCALARS, IoDirection, Tool
from xms.tool_core.table_definition import ChoicesColumnType, StringColumnType, TableDefinition

# 4. Local modules
from xms.tool.algorithms.datasets.dataset_calculator import DatasetCalculator, get_activity_type

# Column headers for table argument dataframe
DATASET_HEADER = "Dataset"
VARIABLE_HEADER = "Variable Name"
TIMESTEP_SELECTION_HEADER = "Timestep"
CHOICES_COLUMN_HEADER = "Choices Column"

# Positions of the columns above within the dataset table definition
DATASET_COLUMN_INDEX = 0
VARIABLE_COLUMN_INDEX = 1
TIMESTEP_COLUMN_INDEX = 2

# Positions of the tool arguments in the list returned by DatasetCalculatorTool.initial_arguments()
ARG_GRID_SOURCE = 0
ARG_DATA_LOCATION = 1
ARG_TABLE = 2
ARG_INPUT_EXPRESSION = 3
ARG_OUTPUT_DATASET = 4


# Names that cannot be used as dataset variable names: operators and function names understood by the
# expression parser, plus all Python keywords.
# NOTE(review): function names shadow variables of the same name when numexpr parses an expression, so every
# numexpr function is listed here, not just the common math ones.
RESERVED_VARIABLE_NAMES = [
    # Bitwise operators
    '&', '|', '~', '^',
    # Comparison operators
    '<', '<=', '==', '!=', '>=', '>',
    # Arithmetic operators
    '+', '-', '*', '**', '%', '<<', '>>',
    # Math functions
    'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan', 'arctan2', 'sinh', 'cosh', 'tanh', 'arcsinh', 'arccosh',
    'arctanh', 'log', 'log10', 'log1p', 'exp', 'expm1', 'sqrt', 'abs', 'fmod', 'floor', 'ceil',
    # Complex number functions
    'real', 'imag',
    # Conditional, reduction and misc functions
    'where', 'contains', 'conj', 'complex', 'sum', 'prod', 'min', 'max', 'copy', 'ones_like',
]
RESERVED_VARIABLE_NAMES.extend(keyword.kwlist)  # Add all the Python keywords.


class DatasetCalculatorTool(Tool):
    """Tool that evaluates a mathematical expression over one or more datasets on a grid.

    Each row of the dataset table binds a scalar dataset on the selected grid to a variable name and a time step
    selection. validate_arguments() parses the expression with numexpr and collects the per-row datasets, variable
    names and time steps; run() passes them to DatasetCalculator and writes the output dataset. When the table is
    empty, the expression is evaluated without any input datasets.
    """

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Dataset Calculator')
        # Copy of the dataset table from the last enable_arguments() call; used to detect added rows and changed
        # dataset selections.
        self.previous_table_val = None
        # Grid source and data location seen on the last enable_arguments() call; a change rebuilds the table.
        self.previous_grid_val = ""
        self.previous_filter_val = ""
        # Suffix for the default variable name ("d<N>") given to the next row added to the table.
        self.current_variable_num = 1
        # Parallel per-row lists filled by validate_arguments() and consumed by run().
        self.datasets_list = []
        self.variable_names_list = []
        self.timesteps_list = []

    def validate_from_history(self, arguments):
        """Validates arguments when loading arguments from history.

        Rows whose dataset can no longer be found on the grid are reset to the "Select a dataset" placeholder, and
        the dataset column choices are rebuilt for the current grid and data location.

        Args:
            arguments: list of Tool arguments.

        Returns:
            True if arguments are valid, False if the grid source is no longer available.
        """
        self.previous_grid_val = arguments[ARG_GRID_SOURCE].value

        # Check if UGrid was removed
        if self.previous_grid_val is None:
            return False

        table = arguments[ARG_TABLE].value
        datasets = self._data_handler.get_grid_datasets(self.previous_grid_val)
        # Reset any row whose dataset no longer exists on the grid
        for index, row in table.iterrows():
            if f"{self.previous_grid_val}/{row[DATASET_HEADER]}" not in datasets:
                table.at[index, DATASET_HEADER] = "Select a dataset"
                table.at[index, TIMESTEP_SELECTION_HEADER] = ""
                table.at[index, CHOICES_COLUMN_HEADER] = []

        self.previous_filter_val = arguments[ARG_DATA_LOCATION].value
        self.current_variable_num = len(table) + 1
        # The stored table definition is rebuilt with from_dict so its dataset choices can be refreshed.
        arguments[ARG_TABLE].table_definition = TableDefinition.from_dict(arguments[ARG_TABLE].table_definition)
        arguments[ARG_TABLE].table_definition.column_types[DATASET_COLUMN_INDEX].choices = self.get_datasets(arguments)

        return True

    def create_table_definition(self):
        """Creates a blank/default table definition.

        Returns:
            (TableDefinition): Definition with dataset, variable name, time step and time step choices columns.
        """
        columns = [
            StringColumnType(header=DATASET_HEADER, tool_tip="Dataset Selection", default="Select a dataset",
                             choices=[]),
            StringColumnType(header=VARIABLE_HEADER, tool_tip="Variable Name", default="Variable name"),
            # NOTE(review): choices=3 appears to point this column at the choices column (index 3) for its
            # drop-down options — confirm against StringColumnType.
            StringColumnType(header=TIMESTEP_SELECTION_HEADER, tool_tip="Time step Selection", choices=3,
                             default=""),
            ChoicesColumnType(header=CHOICES_COLUMN_HEADER, default=[[""]]),
        ]
        return TableDefinition(columns)

    def get_timesteps(self, dataset):
        """Called to retrieve a dataset's timestep array.

        Args:
            dataset (str): Full dataset path (grid name + "/" + dataset name).

        Returns:
            (list[pandas.Series]): One Series per column of the dataset's times wrapped in a DataFrame; for a 1-D
            times array this is a single Series of time values. Empty if no dataset is found.
        """
        dataset_value = self.get_input_dataset(dataset)
        if dataset_value is None:
            return []  # No dataset selected for this row yet.
        times = pd.DataFrame(dataset_value.times)
        return [times[column] for column in times]

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        self.previous_table_val = None

        table_def = self.create_table_definition()
        df = table_def.to_pandas()

        grid_options = self._data_handler.get_available_grids()
        no_grids_found = len(grid_options) == 0
        if no_grids_found:
            grid_options = [""]

        default_grid = grid_options[0]

        # Default to cells when the default grid has any cell-mapped scalar datasets
        cell_datasets = self._data_handler.get_grid_datasets(default_grid, [ALLOW_ONLY_SCALARS, ALLOW_ONLY_CELL_MAPPED])
        data_location_default = 'Cells' if cell_datasets else 'Points'

        # Every argument after the grid selector is hidden until a grid is available.
        arguments = [
            self.grid_argument(name='grid_source_selection', description='Grid source selection', value=default_grid),
            self.string_argument(name='data_location_selection', description='Data location selection',
                                 choices=['Cells', 'Points'], value=data_location_default, hide=no_grids_found),
            self.table_argument(name='dataset_table', description="Dataset table", value=df, optional=True,
                                table_definition=table_def, hide=no_grids_found),
            self.string_argument(name='expression', description='Mathematical expression', optional=False,
                                 hide=no_grids_found),
            self.dataset_argument(name='output_dataset', description='Output dataset name', optional=False,
                                  io_direction=IoDirection.OUTPUT, hide=no_grids_found),
        ]
        return arguments

    def get_datasets(self, arguments):
        """Get available datasets filtered by chosen data location and from chosen UGrid.

        Args:
            arguments: list of tool arguments

        Returns:
            list[str]: dataset names relative to the grid (the "<grid>/" prefix is removed)
        """
        if arguments[ARG_DATA_LOCATION].value == "Cells":
            filters_list = [ALLOW_ONLY_SCALARS, ALLOW_ONLY_CELL_MAPPED]
        else:
            filters_list = [ALLOW_ONLY_SCALARS, ALLOW_ONLY_POINT_MAPPED]

        ugrid = arguments[ARG_GRID_SOURCE].value
        datasets = self._data_handler.get_grid_datasets(ugrid, filters_list)
        prefix_length = len(ugrid) + 1  # Length of "<grid>/"
        return [dataset[prefix_length:] for dataset in datasets]

    def enable_change_filter_selections(self, arguments):
        """Checks for changes in either the grid source selection or data location selection inputs.

        If there are changes, rebuild the dataset choices for the new grid/location and replace the dataset table
        with a fresh one from its definition. If no grid is selected, the remaining arguments are hidden.

        Args:
            arguments: list of tool arguments

        Returns:
            False if the grid source or data location changed, True if not.
        """
        grid_source_selected = arguments[ARG_GRID_SOURCE].value
        current_filter_value = arguments[ARG_DATA_LOCATION].value
        changed_grid = grid_source_selected != self.previous_grid_val
        changed_filter = current_filter_value != self.previous_filter_val
        if not (changed_grid or changed_filter):
            return True

        # The remaining arguments are only shown once a grid source has been selected.
        has_grid = grid_source_selected != ""
        for argument_index in (ARG_DATA_LOCATION, ARG_TABLE, ARG_INPUT_EXPRESSION, ARG_OUTPUT_DATASET):
            arguments[argument_index].show = has_grid
        if not has_grid:
            self.previous_grid_val = ""
            return False

        self.current_variable_num = 1
        filtered_choices = self.get_datasets(arguments)

        table_def = arguments[ARG_TABLE].table_definition
        dataset_column = table_def.column_types[DATASET_COLUMN_INDEX]
        if filtered_choices:
            dataset_column.choices = filtered_choices
            dataset_column.default = "Select a dataset"
        else:
            dataset_column.choices = ["No datasets found"]
            dataset_column.default = "No datasets found"

        arguments[ARG_TABLE].value = table_def.to_pandas()

        self.previous_grid_val = grid_source_selected
        self.previous_filter_val = current_filter_value
        return False

    def enable_table_change(self, arguments):
        """Checks for changes on the table: updates cells or the whole table as needed.

        A newly added row gets the next default variable name. When no row was added, every row whose dataset
        selection changed has its time step choices refreshed.

        Args:
            arguments: list of arguments.
        """
        new_table_val = arguments[ARG_TABLE].value

        # Nothing to compare against on the first change made to the table
        if self.previous_table_val is None:
            return

        length_old = self.previous_table_val.shape[0]
        length_new = new_table_val.shape[0]

        if length_old < length_new:
            # Name the newly added (last) row. Use its actual label instead of assuming a 1-based contiguous
            # index, which would otherwise create a phantom row.
            new_table_val.at[new_table_val.index[-1], VARIABLE_HEADER] = f"d{self.current_variable_num}"
            arguments[ARG_TABLE].value = new_table_val
            self.current_variable_num += 1
        elif length_old == length_new:
            ugrid_name = arguments[ARG_GRID_SOURCE].value + "/"
            previous_datasets = self.previous_table_val[DATASET_HEADER]
            for index, row in new_table_val.iterrows():
                # .get() treats a row label missing from the previous table as a changed selection.
                if row[DATASET_HEADER] != previous_datasets.get(index):
                    self.update_timesteps(index, ugrid_name + row[DATASET_HEADER], new_table_val)
                    arguments[ARG_TABLE].value = new_table_val

    def update_timesteps(self, index, dataset_name, new_table_val):
        """Retrieves timestep options for a given dataset and populates the timestep cell on table for that dataset.

        Datasets with several time steps default to "All time steps" (also offered as the first choice); a
        single-time-step dataset defaults to its only time.

        Args:
            index: row label in the dataframe
            dataset_name: full path of the chosen dataset (grid name + "/" + dataset name)
            new_table_val: dataframe that is modified in place
        """
        time_columns = self.get_timesteps(dataset_name)
        choices = []
        if time_columns:  # Empty when no dataset has been selected for this row yet.
            times = time_columns[0]
            if len(times) > 1:
                choices.append("All time steps")
                new_table_val.at[index, TIMESTEP_SELECTION_HEADER] = "All time steps"
            elif len(times) == 1:
                new_table_val.at[index, TIMESTEP_SELECTION_HEADER] = str(times.iloc[0])
            choices.extend(str(time) for time in times)
        new_table_val.at[index, CHOICES_COLUMN_HEADER] = choices

    def enable_arguments(self, arguments):
        """Enable arguments.

        Restricts the data location for grid types that support only one location, reacts to grid/data location
        changes, and otherwise reacts to edits of the dataset table.

        Args:
            arguments (list): The tool arguments.
        """
        # Limit the Cartesian Grid source to Cells and Scatter Sets to Points. The grid value may be unset.
        grid_type = (arguments[ARG_GRID_SOURCE].value or '').split('/')[0]
        if grid_type == 'Cartesian Grid Data':
            arguments[ARG_DATA_LOCATION].choices = ['Cells']
            arguments[ARG_DATA_LOCATION].value = 'Cells'
        elif grid_type == 'Scatter Data':
            arguments[ARG_DATA_LOCATION].choices = ['Points']
            arguments[ARG_DATA_LOCATION].value = 'Points'
        else:
            arguments[ARG_DATA_LOCATION].choices = ['Cells', 'Points']

        # Table edits are only examined when the grid source and data location are unchanged.
        if self.enable_change_filter_selections(arguments):
            self.enable_table_change(arguments)
        # Remember the table so the next call can detect edits.
        self.previous_table_val = arguments[ARG_TABLE].value.copy()

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Also fills datasets_list, variable_names_list and timesteps_list (one entry per table row) for run(). The
        lists are cleared first so repeated validations never accumulate stale entries.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}
        self.reset_variable_lists()

        # Validate grid geometry
        ugrid = self.get_input_grid(arguments[ARG_GRID_SOURCE].value).ugrid
        if arguments[ARG_DATA_LOCATION].value == 'Cells' and ugrid.cell_count == 0:
            errors[arguments[ARG_DATA_LOCATION].name] = "Grid has no cells."

        df = arguments[ARG_TABLE].value
        expression_name = arguments[ARG_INPUT_EXPRESSION].name

        try:
            # An empty table means the expression is evaluated without input datasets (constant dataset).
            if not df.empty:
                self._collect_table_rows(arguments, errors)
                self._validate_dataset_compatibility(self.datasets_list, arguments, errors)

            # Check if the input expression is empty
            expression = arguments[ARG_INPUT_EXPRESSION].value or ""
            if expression.strip() == "":
                errors[expression_name] = "Missing mathematical expression"
                raise ValueError

            # Check if any of the variable names are reserved keywords.
            self._check_for_reserved_variable_names(arguments, errors)

            # Retrieve all provided variable names in the mathematical expression.
            try:
                exp = nec.stringToExpression(expression, {}, {})
                ast = nec.expressionToAST(exp)
                expression_variables = [node.value for node in nec.typeCompileAst(ast).allOf('variable')]
            except Exception:
                errors[expression_name] = "Invalid character(s) in mathematical expression"
                raise ValueError

            # Make sure any provided variable names correspond to a chosen dataset
            defined_variables = set(df[VARIABLE_HEADER])
            if any(variable not in defined_variables for variable in expression_variables):
                errors[expression_name] = "Invalid variable name(s) used in expression"
                raise ValueError

        except ValueError:
            self.reset_variable_lists()

        return errors

    def _collect_table_rows(self, arguments, errors):
        """Append each table row's dataset, variable name and time step selection to the per-row lists.

        Args:
            arguments (list): The tool arguments.
            errors (dict): Dictionary of errors keyed by argument name. Gets modified if there are errors.

        Raises:
            ValueError: If a row has no usable dataset or time step selection (after recording the error).
        """
        table_name = arguments[ARG_TABLE].name
        grid_source = arguments[ARG_GRID_SOURCE].value
        for _, row in arguments[ARG_TABLE].value.iterrows():
            dataset_name = row[DATASET_HEADER]
            if dataset_name == "Select a dataset":
                errors[table_name] = "Make sure each row has a dataset selected."
                raise ValueError

            if dataset_name == "No datasets found":
                errors[table_name] = ("No datasets found for "
                                      f"{arguments[ARG_DATA_LOCATION].value.lower()} on grid.")
                raise ValueError

            dataset = self.get_input_dataset(f"{grid_source}/{dataset_name}")
            if dataset is None:
                errors[table_name] = f"Unable to find dataset {dataset_name} on grid."
                raise ValueError
            self.datasets_list.append(dataset)
            self.variable_names_list.append(row[VARIABLE_HEADER])

            timestep_choice = row[TIMESTEP_SELECTION_HEADER]
            if timestep_choice == "All time steps":
                self.timesteps_list.append("All")
                continue

            # Every row must contribute a time step entry, otherwise the three lists fall out of alignment.
            timestep_index = self._find_timestep_index(dataset, timestep_choice)
            if timestep_index is None:
                errors[table_name] = f"Time step {timestep_choice} not found on {dataset_name}."
                raise ValueError
            self.timesteps_list.append(timestep_index)

    @staticmethod
    def _find_timestep_index(dataset, timestep_text):
        """Find the index of the selected time step in a dataset's times.

        Args:
            dataset: Dataset whose times are searched.
            timestep_text: Time step selection from the table (the string form of one of the dataset's times).

        Returns:
            The index of the first matching time, or None if the text is not numeric or matches no time.
        """
        try:
            selected_time = float(timestep_text)
        except (TypeError, ValueError):
            return None
        matches = np.where(np.array(dataset.times) == selected_time)[0]
        return matches[0] if matches.size else None

    def reset_variable_lists(self):
        """Resets the tool's per-row dataset, variable name and time step lists.

        Called at the start of every validation and after a validation error, so the tool can be run again
        without reopening the tool dialog.
        """
        self.datasets_list = []
        self.variable_names_list = []
        self.timesteps_list = []

    def _check_for_reserved_variable_names(self, arguments, errors):
        """Check that no variable name in the dataset table is a reserved name.

        Reserved names are those in RESERVED_VARIABLE_NAMES (expression operators, functions and Python keywords).

        Args:
            arguments (list): The tool arguments.
            errors (dict): Dictionary of errors keyed by argument name. Gets modified if there are errors.

        Raises:
            ValueError: If any reserved variable names are found (after recording the error).
        """
        df = arguments[ARG_TABLE].value
        reserved = df[df[VARIABLE_HEADER].isin(RESERVED_VARIABLE_NAMES)]
        if reserved.empty:
            return
        errors[arguments[ARG_TABLE].name] = (f'Invalid variable names detected: {reserved[VARIABLE_HEADER].to_list()}'
                                             '\nPlease change the variable name in the dataset table and update any '
                                             'references to it in the mathematical expression.')
        raise ValueError

    def _validate_dataset_compatibility(self, datasets, arguments, errors):
        """Check that datasets match the grid and each other.

        Datasets used with "All time steps" that have more than one time step must share identical times and
        reference times. Every dataset must have one value per cell/point of the grid (per the current data
        location) and a valid activity array; on cells, any activity must be cell based.

        Args:
            datasets (list): list of datasets, one per table row
            arguments (list): The tool arguments.
            errors (dict): Dictionary of errors keyed by argument name. Gets modified if there are errors.

        Raises:
            ValueError: On the first incompatibility found (after recording the error).
        """
        master_times = None
        master_ref_time = None

        table_name = arguments[ARG_TABLE].name
        ugrid = self.get_input_grid(arguments[ARG_GRID_SOURCE].value).ugrid
        # Use the current data location rather than the cached previous selection.
        on_cells = arguments[ARG_DATA_LOCATION].value == "Cells"
        expected_count = ugrid.cell_count if on_cells else ugrid.point_count
        location_word = "cells" if on_cells else "points"
        timestep_choices = arguments[ARG_TABLE].value[TIMESTEP_SELECTION_HEADER].array

        for dataset, choice in zip(datasets, timestep_choices):
            # Constant (single time step) datasets don't need to match the other datasets' times, but they still
            # go through the value and activity checks below.
            if choice == "All time steps" and len(dataset.times) != 1:
                if master_times is None:
                    # First multi-timestep dataset
                    master_times = np.array(dataset.times)
                    master_ref_time = dataset.ref_time
                else:
                    if not np.array_equal(np.array(dataset.times), master_times):
                        errors[table_name] = "Datasets' timestep arrays must match."
                        raise ValueError
                    if dataset.ref_time != master_ref_time:
                        errors[table_name] = "Datasets' reference times do not all match."
                        raise ValueError

            # Check value arrays
            if np.array(dataset.values).shape[1] != expected_count:
                errors[table_name] = (f"Number of dataset values on {dataset.name} "
                                      f"does not match the number of {location_word} on the grid.")
                raise ValueError

            # Check activity arrays are all valid
            activity_type = get_activity_type(ugrid, dataset.activity)
            if activity_type == "invalid":
                errors[table_name] = f"Invalid activity array on {dataset.name}"
                raise ValueError

            if activity_type and on_cells and activity_type != "cells":
                errors[table_name] = ("Non-matching activity array on "
                                      f"{dataset.name}.\nFound points activity "
                                      "when activity array should be on cells")
                raise ValueError

    def run(self, arguments):
        """Override to run the tool.

        Evaluates the expression with the data collected by validate_arguments(). On success the output dataset is
        sent back to XMS; on error the tool fails with a description of where invalid values were produced.

        Args:
            arguments (list): The tool arguments.
        """
        grid = self.get_input_grid(arguments[ARG_GRID_SOURCE].value)
        expression = arguments[ARG_INPUT_EXPRESSION].text_value
        location = arguments[ARG_DATA_LOCATION].value

        writer_kwargs = {
            'name': arguments[ARG_OUTPUT_DATASET].text_value,
            'geom_uuid': grid.uuid,
            'num_components': 1,
            'location': location,
            'use_activity_as_null': True,
        }
        if self.datasets_list:
            # The output takes its time reference from the first input dataset.
            first_dataset = self.datasets_list[0]
            writer_kwargs['ref_time'] = first_dataset.ref_time
            writer_kwargs['time_units'] = first_dataset.time_units
        output_dataset = self.get_output_dataset_writer(**writer_kwargs)

        dataset_calculator = DatasetCalculator(self.datasets_list,
                                               self.variable_names_list,
                                               self.timesteps_list,
                                               grid.ugrid,
                                               expression,
                                               output_dataset,
                                               self.logger)

        result = dataset_calculator.calculate()

        if result.result_type == "error":
            # Fail the tool from running, and display the error message to the user
            self.fail(self._build_failure_message(result, location))
        elif result.result_type == "success":
            # Send the dataset back to XMS
            self.set_output_dataset(output_dataset)

    @staticmethod
    def _build_failure_message(result, location):
        """Build the message shown when the expression produced invalid values.

        For non-constant results, up to 5 time steps are listed, each with up to 5 point/cell ids; any further
        entries are abbreviated with "...".

        Args:
            result: Result returned by DatasetCalculator.calculate().
            location (str): "Cells" or "Points".

        Returns:
            (str): The failure message.
        """
        fail_string = ("\nInvalid values in the results of the mathematical expression\n"
                       "i.e., dividing by zero, square root of negative numbers\n")

        # Only non-constant expressions report per-time-step details
        if result.value_count != "multiple":
            return fail_string

        location_type = location[0:-1]  # Remove 's' from end of location ("Cells" -> "Cell")
        for timestep_count, (timestep, values) in enumerate(result.error_info.items()):
            if timestep_count == 5:
                fail_string += "...\n"
                break

            displayed_values = ", ".join(map(str, values[:5]))
            if len(values) > 5:
                displayed_values += "..."

            fail_string += f"Timestep: {timestep}, {location_type} id(s): {displayed_values}\n"
        return fail_string
