"""UGrid2dFromUGrid3dTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint import UGrid2dFromUGrid3dCreator, UnconstrainedGrid
from xms.tool_core import IoDirection, Tool

# 4. Local modules


# Positional indices into the tool's argument list. The order must match the
# list returned by UGrid2dFromUGrid3dTool.initial_arguments().
(
    ARG_INPUT_3D_UGRID,       # the 3D UGrid to extract from
    ARG_INPUT_TOP_OR_BOTTOM,  # 'Top' or 'Bottom' choice
    ARG_INPUT_2D_UGRID_NAME,  # optional name for the new 2D UGrid
    ARG_OUTPUT_2D_UGRID,      # hidden output argument receiving the 2D UGrid
) = range(4)


class UGrid2dFromUGrid3dTool(Tool):
    """Tool to create a 2D UGrid from the top or bottom of a 3D UGrid with point/cell ordering the same."""

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='2D UGrid from 3D UGrid')

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        The order of the returned list matches the module-level ARG_* index constants
        (3D UGrid input, top/bottom choice, optional 2D UGrid name, hidden 2D UGrid output).

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.grid_argument(name='ugrid_3d', description='3D UGrid', optional=False,
                               io_direction=IoDirection.INPUT),
            self.string_argument(name='top_or_bottom', description='Top or bottom', value='Top',
                                 choices=['Top', 'Bottom'], optional=False),
            self.string_argument(name='ugrid_2d_name', description='2D UGrid name', value='',
                                 optional=True),
            self.grid_argument(name='ugrid_2d', description='The 2D UGrid', hide=True, optional=True,
                               io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Only the 3D UGrid input argument is validated here.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name. Empty if there are none.
        """
        errors: dict[str, str] = {}
        self._validate_input_ugrid(errors, arguments[ARG_INPUT_3D_UGRID])
        return errors

    def _validate_input_ugrid(self, errors, argument):
        """Validate that the grid can be read, has cells, and that all of its cells are 3D.

        Only the first problem found is recorded, so a more specific message is not
        overwritten by a later check.

        Args:
            errors (dict): Dictionary of errors keyed by argument name. Updated in place.
            argument (GridArgument): The grid argument.
        """
        key = argument.name
        if argument.value is None:
            # Nothing selected. Presumably the framework reports missing required
            # (optional=False) arguments itself - NOTE(review): confirm.
            return

        cogrid = self.get_input_grid(argument.text_value)
        if not cogrid:
            errors[key] = 'Could not read grid.'
            return

        if cogrid.ugrid.cell_count <= 0:
            errors[key] = 'Grid has no cells.'
            return  # Don't let the 3D-cell check below overwrite this more specific message.
        if not cogrid.check_all_cells_3d():
            errors[key] = 'Grid cells must all be 3D.'

    def run(self, arguments):
        """Override to run the tool.

        Builds a 2D grid from the chosen face ('Top' or 'Bottom') of the input 3D grid
        and stores it in the output grid argument.

        Args:
            arguments (list): The tool arguments.
        """
        # Get arguments
        name_3d = arguments[ARG_INPUT_3D_UGRID].text_value
        cogrid_3d = self.get_input_grid(name_3d)
        top_or_bottom = arguments[ARG_INPUT_TOP_OR_BOTTOM].value
        name_2d = arguments[ARG_INPUT_2D_UGRID_NAME].value

        # Build the 2d grid
        creator = UGrid2dFromUGrid3dCreator()
        cogrid_2d = creator.create_2d_cogrid(cogrid_3d, top_or_bottom)
        # Re-wrap the created ugrid as an UnconstrainedGrid and give it the creator's UUID.
        cogrid_2d = UnconstrainedGrid(ugrid=cogrid_2d.ugrid)
        cogrid_2d.uuid = creator.get_uuid()

        # Set the output grid with the right name
        arguments[ARG_OUTPUT_2D_UGRID].value = self._get_ugrid_2d_name(name_2d, name_3d)
        self.set_output_grid(cogrid_2d, arguments[ARG_OUTPUT_2D_UGRID])

    @staticmethod
    def _get_ugrid_2d_name(name_2d, name_3d):
        """Returns a name for the 2D ugrid.

        If a name was provided (name_2d), it is returned unchanged. Otherwise " 2d" is
        appended to the last '/'-separated component of the 3D ugrid name.
        E.g.: "blah/blah/my 3D ugrid" -> "my 3D ugrid 2d"

        Args:
            name_2d (str): Name of the 2D UGrid. If provided (non-empty), it will just be returned.
            name_3d (str): Name of the 3D UGrid, possibly a '/'-separated path.

        Returns:
            (str): Name of the 2D ugrid.
        """
        if name_2d:
            return name_2d
        # Use the last path component of the 3D grid name + ' 2d'
        return f"{name_3d.split('/')[-1]} 2d"
