"""PointSpacingTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.gdal.utilities import gdal_utils as gu
from xms.grid.ugrid import UGrid
from xms.mesher.meshing.mesh_utils import size_function_from_edge_lengths
from xms.tool_core import IoDirection, Tool

# 4. Local modules

# Indices into the argument list built by PointSpacingTool.initial_arguments().
ARG_GRID_IN = 0  # 'input_grid' grid argument
ARG_DATASET_OUT = 1  # 'output_dataset' output dataset argument


class PointSpacingTool(Tool):
    """Tool to calculate average edge length for each grid point.

    Writes a per-point spacing dataset on the input grid. When the tool's default
    projection (``self.default_wkt``) is geographic, that dataset is suffixed
    ``_degrees`` and a second dataset suffixed ``_meters`` is computed from the
    point locations converted to UTM.
    """

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Point Spacing')

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments, ordered to match
            ARG_GRID_IN and ARG_DATASET_OUT.
        """
        arguments = [
            self.grid_argument(name='input_grid', description='Input grid'),
            self.dataset_argument(name='output_dataset', description='Output dataset', io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}

        # Validate the input grid is readable, has cells, and is all 2D
        self._validate_input_grid(errors, arguments[ARG_GRID_IN])

        return errors

    def _validate_input_grid(self, errors, argument):
        """Validate grid is specified, has at least one cell, and is 2D.

        Only the first failing check is reported so a more specific message is
        never overwritten by a later one.

        Args:
            errors (dict): Dictionary of errors keyed by argument name (updated in place).
            argument (GridArgument): The grid argument.
        """
        key = argument.name
        input_ugrid = self.get_input_grid(argument.text_value)
        if input_ugrid is None:
            errors[key] = 'Could not read grid.'
        elif input_ugrid.ugrid.cell_count < 1:
            errors[key] = 'Must have cells defined to calculate spacing.'
        elif not input_ugrid.check_all_cells_2d():
            errors[key] = 'Must have all 2D cells.'

    @staticmethod
    def _compute_point_spacing(ugrid):
        """Compute the spacing value of every point in a UGrid.

        Args:
            ugrid (UGrid): The grid whose point spacing is computed.

        Returns:
            (np.ndarray): One value per point from size_function_from_edge_lengths,
            with NaN entries replaced by 0.0 (presumably points with no attached
            edges, which run() reports as disjoint).
        """
        spacing = np.array(size_function_from_edge_lengths(ugrid))
        # ``nan`` must be passed by keyword: the second positional parameter of
        # np.nan_to_num is ``copy``, not the fill value.
        return np.nan_to_num(spacing, nan=0.0)

    def _write_spacing_dataset(self, co_grid, name, values):
        """Write a single-timestep, one-component dataset on the grid and hand it to set_output_dataset.

        Args:
            co_grid: The input constrained grid; its ``uuid`` is the dataset geometry.
            name (str): Name of the output dataset.
            values (np.ndarray): One value per grid point.
        """
        builder = self.get_output_dataset_writer(
            name=name,
            geom_uuid=co_grid.uuid,
            num_components=1,
        )
        builder.append_timestep(0.0, values)
        builder.appending_finished()
        self.set_output_dataset(builder)

    def run(self, arguments):
        """Override to run the tool.

        Args:
            arguments (list): The tool arguments.
        """
        input_co_grid = self.get_input_grid(arguments[ARG_GRID_IN].text_value)
        ug = input_co_grid.ugrid

        self.logger.info(f'Calculating spacing for {ug.point_count} points.')
        point_spacing = self._compute_point_spacing(ug)
        disjoint_idxs = np.where(point_spacing == 0.0)[0]
        if len(disjoint_idxs) > 0:
            self.logger.info(f'{len(disjoint_idxs)} disjoint points found and assigned value of 0.0.')
            pt_ids = disjoint_idxs + 1  # report 1-based point ids
            self.logger.info(f'Disjoint point ids: {pt_ids}')

        sr = gu.wkt_to_sr(self.default_wkt) if self.default_wkt else None
        is_geographic = bool(sr and sr.IsGeographic())
        base_name = arguments[ARG_DATASET_OUT].text_value

        # Spacing in the grid's native units (degrees when geographic)
        dataset_name = f'{base_name}_degrees' if is_geographic else base_name
        self._write_spacing_dataset(input_co_grid, dataset_name, point_spacing)

        if is_geographic:
            # Also provide spacing in meters by measuring edges on a UTM copy of the grid
            locs, _utm = gu.convert_lat_lon_pts_to_utm(ug.locations)
            utm_ugrid = UGrid(locs, ug.cellstream)
            meter_spacing = self._compute_point_spacing(utm_ugrid)
            self._write_spacing_dataset(input_co_grid, base_name + '_meters', meter_spacing)
