"""PrimitiveWeightingTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.gdal.utilities import gdal_utils as gu
from xms.grid.ugrid import UGrid
from xms.mesher.meshing import mesh_utils
from xms.tool_core import IoDirection, Tool

# 4. Local modules

# Positions of each tool argument in the list returned by initial_arguments();
# the order of this tuple must match the order arguments are created there.
(
    ARG_INPUT_GRID,
    ARG_INPUT_CRITICAL_AVERAGE_NODE_SPACING,
    ARG_INPUT_CRITICAL_DEPTH,
    ARG_INPUT_TAU_DEFAULT,
    ARG_INPUT_TAU_DEEP,
    ARG_INPUT_TAU_SHALLOW,
    ARG_OUTPUT_DATASET,
) = range(7)

# Null value handed to the output dataset writer.
DSET_NULL_VALUE = -999999.0


class PrimitiveWeightingTool(Tool):
    """Tool that computes primitive weighting (tau0) coefficients as a dataset on a grid.

    Each grid point is assigned ``tau_default`` where the size function is below the critical
    average node spacing. Otherwise it gets ``tau_deep`` when its depth (the negated z coordinate)
    exceeds the critical depth, and ``tau_shallow`` if not.
    """

    # Default argument values, shared by __init__ and initial_arguments so the two cannot drift apart.
    _DEFAULT_CRITICAL_SPACING = 1750.0
    _DEFAULT_CRITICAL_DEPTH = 10.0
    _DEFAULT_TAU_DEFAULT = 0.03
    _DEFAULT_TAU_DEEP = 0.005
    _DEFAULT_TAU_SHALLOW = 0.02

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Primitive Weighting')
        self._cogrid = None  # grid object from get_input_grid(); set in validate_arguments()
        self._grid = None  # UGrid used for the calculation (projected to UTM for geographic input)
        self._grid_uuid = None
        self._builder = None  # output dataset writer
        self._dataset_vals = None  # computed coefficient per grid point
        self._critical_spacing = self._DEFAULT_CRITICAL_SPACING
        self._critical_depth = self._DEFAULT_CRITICAL_DEPTH
        self._tau_default = self._DEFAULT_TAU_DEFAULT
        self._tau_deep = self._DEFAULT_TAU_DEEP
        self._tau_shallow = self._DEFAULT_TAU_SHALLOW
        # Emit a progress message every this many points; a non-positive value disables them.
        self.log_frequency = 10000

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.grid_argument(name='grid', description='Target grid'),
            self.float_argument(name='critical_spacing', description='Critical average node spacing',
                                value=self._DEFAULT_CRITICAL_SPACING, min_value=0.0),
            self.float_argument(name='critical_depth', description='Critical depth',
                                value=self._DEFAULT_CRITICAL_DEPTH, min_value=0.0),
            self.float_argument(name='tau_default', description='Tau default',
                                value=self._DEFAULT_TAU_DEFAULT, min_value=0.0),
            self.float_argument(name='tau_deep', description='Tau deep',
                                value=self._DEFAULT_TAU_DEEP, min_value=0.0),
            self.float_argument(name='tau_shallow', description='Tau shallow',
                                value=self._DEFAULT_TAU_SHALLOW, min_value=0.0),
            self.dataset_argument(name='tau0_dataset', description='Primitive weighting dataset',
                                  value='Tau0', io_direction=IoDirection.OUTPUT),
        ]
        self.enable_arguments(arguments)
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Also caches the opened input grid in ``self._cogrid`` for use by run().

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}
        # Validate input data
        self._cogrid = self.get_input_grid(arguments[ARG_INPUT_GRID].text_value)
        if not self._cogrid:
            errors[arguments[ARG_INPUT_GRID].name] = 'Could not open target grid.'
        return errors

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        This tool has no dynamic arguments, so nothing is changed.

        Args:
            arguments(list): The tool arguments.
        """
        pass

    def run(self, arguments):
        """Override to run the tool.

        Reads the numeric arguments, computes the coefficients and writes them as the output dataset.
        Relies on ``self._cogrid`` having been set by validate_arguments().

        Args:
            arguments (list): The tool arguments.
        """
        # Set up some of the grid variables
        self._grid_uuid = self._cogrid.uuid
        self._grid = self._cogrid.ugrid

        # Set up some of the tool arguments
        self._tau_default = float(arguments[ARG_INPUT_TAU_DEFAULT].text_value)
        self._critical_spacing = float(arguments[ARG_INPUT_CRITICAL_AVERAGE_NODE_SPACING].text_value)
        self._critical_depth = float(arguments[ARG_INPUT_CRITICAL_DEPTH].text_value)
        self._tau_deep = float(arguments[ARG_INPUT_TAU_DEEP].text_value)
        self._tau_shallow = float(arguments[ARG_INPUT_TAU_SHALLOW].text_value)

        self._dataset_vals = self._calculate_primitive_weighting()

        # Write out the dataset
        self._setup_output_dataset_builder(arguments)
        self._add_output_datasets()

    def _setup_output_dataset_builder(self, arguments):
        """Set up the dataset writer for the tool output.

        Args:
            arguments (list): The tool arguments.
        """
        # Create a place for the output dataset file
        dataset_name = arguments[ARG_OUTPUT_DATASET].text_value
        self._builder = self.get_output_dataset_writer(
            name=dataset_name,
            geom_uuid=self._grid_uuid,
            null_value=DSET_NULL_VALUE,
        )

    def _add_output_datasets(self):
        """Write the computed values as a single-time-step dataset and send it back to XMS."""
        self.logger.info('Adding output dataset...')
        if self._builder is not None:
            self.logger.info('Writing output primitive weighting dataset to XMDF file...')
            self._builder.write_xmdf_dataset([0.0], [self._dataset_vals])
            # Send the dataset back to XMS
            self.set_output_dataset(self._builder)

    def _calculate_primitive_weighting(self):
        """Calculates the primitive weighting coefficients.

        For a geographic projection, ``self._grid`` is replaced with a copy whose points are
        converted to UTM before the size function is computed.

        Returns:
            numpy.ndarray: One coefficient per grid point (also stored in ``self._dataset_vals``).
        """
        # Calculate the size function and set the dataset values default
        self.logger.info('Calculating size function...')
        sr = gu.wkt_to_sr(self.default_wkt) if self.default_wkt else None
        if sr and sr.IsGeographic():
            # Presumably so edge lengths come out in linear units comparable to critical_spacing;
            # verify against mesh_utils.size_function_from_edge_lengths.
            locs, _utm = gu.convert_lat_lon_pts_to_utm(self._grid.locations)
            self._grid = UGrid(locs, self._grid.cellstream)
        size_func = mesh_utils.size_function_from_edge_lengths(self._grid)

        self._fill_array(size_func)
        return self._dataset_vals

    def _classify(self, sizes, depths):
        """Return the coefficient for each entry of parallel size-function and depth arrays.

        Args:
            sizes (numpy.ndarray): Size function values.
            depths (numpy.ndarray): Point depths (negated z coordinates), same length as ``sizes``.

        Returns:
            numpy.ndarray: ``tau_default`` where size < critical spacing, otherwise ``tau_deep`` where
            depth > critical depth and ``tau_shallow`` elsewhere.
        """
        # Written as "not (size < spacing)" so a NaN size falls through to the depth test, exactly
        # as a scalar `size < spacing` check would.
        coarse = ~(sizes < self._critical_spacing)
        deep = depths > self._critical_depth
        return np.where(coarse, np.where(deep, self._tau_deep, self._tau_shallow), self._tau_default)

    def _fill_array(self, size_func):
        """Calculates the primitive weighting coefficient for every grid point.

        The result is stored in ``self._dataset_vals`` (one float per point of ``self._grid``).
        Points without a matching size-function value keep ``tau_default``, and a warning is logged.

        Args:
            size_func (Sequence[float]): Size function value per grid point, in the same order as
                ``self._grid.locations``.
        """
        locations = self._grid.locations
        num_points = len(locations)
        sizes = np.asarray(size_func, dtype=float).ravel()
        if sizes.size != num_points:
            self.logger.warning(f'Size function has {sizes.size} values but the grid has {num_points} '
                                f'points; unmatched points keep the default tau.')
        count = min(num_points, sizes.size)

        # Set up the default dataset values
        self._dataset_vals = np.full(num_points, self._tau_default, dtype=float)

        log_frequency = self.log_frequency
        self.logger.info(f'Processing point 1 of {num_points}...')
        if count == 0:
            return

        # Negate z because the critical-depth test expects depth rather than the z coordinate.
        depths = -np.asarray(locations, dtype=float)[:count, 2]

        # Classify in chunks of log_frequency so progress is still reported on very large grids.
        chunk = log_frequency if log_frequency > 0 else count
        for start in range(0, count, chunk):
            stop = min(start + chunk, count)
            self._dataset_vals[start:stop] = self._classify(sizes[start:stop], depths[start:stop])
            if log_frequency > 0 and stop % log_frequency == 0:
                self.logger.info(f'Processing point {stop} of {num_points}...')