"""CellElevationsTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint import UnconstrainedGrid
from xms.tool_core import IoDirection, Tool

# 4. Local modules


class CellElevationsTool(Tool):
    """Tool to convert 2D UGrids to UGrids with average cell elevations."""
    ARG_INPUT_GRID = 0
    ARG_OUTPUT_GRID = 1

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Cell Elevations')
        self._args = None  # Argument list passed to run()
        self._input_cogrid = None  # Input grid, loaded in validation or lazily in run()
        self._out_grid = None  # UnconstrainedGrid built from the input grid's ugrid
        self._ug_name = ''  # text_value of the input grid argument

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.grid_argument(name='input_grid', description='Input grid'),
            self.grid_argument(name='output_grid', description='Output grid name', io_direction=IoDirection.OUTPUT),
        ]
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        As a side effect, loads the input grid into ``self._input_cogrid`` and
        records its name in ``self._ug_name``.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name.
        """
        errors = {}

        # Make sure the input grid is specified, readable, and made of 2D cells
        self._validate_input_grid(errors, arguments[self.ARG_INPUT_GRID])

        # Make sure output name specified
        if arguments[self.ARG_OUTPUT_GRID].text_value == '':
            errors[arguments[self.ARG_OUTPUT_GRID].name] = 'Grid name not specified.'

        return errors

    def _validate_input_grid(self, errors, argument):
        """Validate grid is specified and 2D.

        Args:
            errors (dict): Dictionary of errors keyed by argument name; updated in place.
            argument (GridArgument): The grid argument.
        """
        if argument.text_value == '':
            errors[argument.name] = 'Grid name not specified.'
            return

        self._ug_name = argument.text_value
        self._input_cogrid = self.get_input_grid(self._ug_name)

        if not self._input_cogrid:
            errors[argument.name] = 'Could not read grid.'
        elif not self._input_cogrid.check_all_cells_2d():
            errors[argument.name] = 'Grid cells must all be 2D.'

    def run(self, arguments):
        """Override to run the tool.

        Args:
            arguments (list): The tool arguments.
        """
        self._args = arguments
        self._get_inputs()
        self._create_output_grid()
        self._average_input_cells()
        self._set_outputs()

    def _set_outputs(self):
        """Set the outputs from the tool."""
        self.set_output_grid(self._out_grid, self._args[self.ARG_OUTPUT_GRID])

    def _get_inputs(self):
        """Get the inputs from the arguments.

        Loads the input grid if validate_arguments() has not already done so, and
        defaults the output grid name to the base name of the input grid when the
        user left it blank.
        """
        if self._input_cogrid is None:
            # run() must not depend on validate_arguments() having been called first.
            self._ug_name = self._args[self.ARG_INPUT_GRID].text_value
            self._input_cogrid = self.get_input_grid(self._ug_name)

        user_input = self._args[self.ARG_OUTPUT_GRID].text_value
        self._args[self.ARG_OUTPUT_GRID].value = user_input if user_input else os.path.basename(self._ug_name)

    def _create_output_grid(self):
        """Creates the output grid from the input grid's ugrid."""
        self._out_grid = UnconstrainedGrid(ugrid=self._input_cogrid.ugrid)

    @staticmethod
    def _mean_z(locations):
        """Return the mean of the third (Z) component of a sequence of point locations.

        Args:
            locations: Sequence of point locations, indexed so that ``location[2]`` is Z.

        Returns:
            (float): The average Z value.
        """
        return sum(location[2] for location in locations) / len(locations)

    def _average_input_cells(self):
        """Set each output cell's elevation to the average Z of that cell's points."""
        ugrid = self._out_grid.ugrid
        self._out_grid.cell_elevations = [
            self._mean_z(ugrid.get_cell_locations(cell_id)) for cell_id in range(ugrid.cell_count)
        ]
