"""Gssha2dUgridTool class."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import math

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.tool_core import ALLOW_ONLY_COVERAGE_TYPE, ALLOW_ONLY_MODEL_NAME, Argument, IoDirection, Tool

# 4. Local modules
from xms.gssha.components import dmi_util
from xms.gssha.mapping import map_util
from xms.gssha.tools import tool_util
from xms.gssha.tools.algorithms import gssha_2d_ugrid_creator

# Constants

# Tool argument indices; the order matches the argument list built in initial_arguments().
(
    ARG_INPUT_COVERAGE,
    ARG_INPUT_CELL_SIZE,
    ARG_INPUT_GRID_NAME,
    ARG_INPUT_USE_RASTER,
    ARG_INPUT_RASTER,
    ARG_INPUT_USE_STREAMS,
    ARG_OUTPUT_GRID,
) = range(7)

MAX_CELLS = 10_000_000  # Upper limit on the total number of grid cells (columns * rows)
DEFAULT_CELLS_IN_XY = 60  # Target cell count along the smaller coverage dimension when picking a default cell size


class Gssha2dUgridTool(Tool):
    r"""A tool to create a 2D UGrid oriented the way GSSHA requires, given a coverage and cell size.

    ::

        *-- j
        |  *----*----*----*
        i  |    |    |    |
           *----*----*----*
           |    |    |    |
           *----*----*----*

    """
    def __init__(self) -> None:
        """Initializes the class."""
        super().__init__(name='Create GSSHA 2D UGrid')
        # XMS query object; set by set_data_handler() when the data handler provides one.
        self._query = None

        # For testing
        # import os
        # os.environ['XMSTOOL_GUI_TESTING'] = 'YES'

    def initial_arguments(self) -> list[Argument]:
        """Get initial arguments for tool.

        Must override.

        The order of the returned list must match the ARG_* index constants.

        NOTE(review): reads self._query.project_tree, so this assumes set_data_handler() has
        already supplied a query object — confirm against the tool framework's call order.

        Returns:
            (list): A list of the initial tool arguments.
        """
        cov_filters = {ALLOW_ONLY_MODEL_NAME: 'GSSHA', ALLOW_ONLY_COVERAGE_TYPE: 'Boundary Conditions'}
        bc_coverage_node = dmi_util.get_default_bc_coverage_node(self._query.project_tree)
        default_cov_path = tool_util.argument_path_from_node(bc_coverage_node)
        default_cell_size = _get_default_cell_size(self._query, bc_coverage_node)
        cov_desc = 'Input GSSHA Boundary Conditions coverage'
        arguments = [
            self.coverage_argument(name='input_cov', description=cov_desc, value=default_cov_path, filters=cov_filters),
            self.float_argument(name='cell_size', description='Cell size', value=default_cell_size, min_value=0.0),
            self.string_argument(name='output_grid_name', description='Output grid name', optional=True),
            self.bool_argument(name='use_raster', description='Get cell elevations from raster', value=True),
            self.raster_argument(name='input_raster', description='Raster', optional=True),
            self.bool_argument(name='use_streams', description='Include streams in cell elevations', value=True),
            self.grid_argument(name='output_grid', hide=True, optional=True, io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def set_data_handler(self, data_handler) -> None:
        """Set up query attribute if we have a XMSDataHandler.

        Args:
            data_handler: The data handler; its ``_query`` attribute is copied to self._query when present.
        """
        super().set_data_handler(data_handler)
        if hasattr(self._data_handler, "_query"):
            self._query = self._data_handler._query

    def validate_arguments(self, arguments: list[Argument]) -> dict[str, str]:
        """Called to determine if arguments are valid.

        Checks run in order (coverage, cell size, raster) and stop at the first failure, so at
        most one error is reported.

        Args:
            arguments (list): The tool arguments, indexed by the ARG_* constants.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name. Empty if valid.
        """
        errors: dict[str, str] = {}

        # Validate input coverage
        argument = arguments[ARG_INPUT_COVERAGE]
        coverage = self.get_input_coverage(argument.text_value)
        extents = map_util.coverage_extents(coverage)
        if not self._valid_coverage_extents(extents, argument.name, errors):
            return errors

        # Validate cell size. Empty or non-numeric text becomes a validation error instead of an exception.
        argument = arguments[ARG_INPUT_CELL_SIZE]
        try:
            cell_size = float(argument.text_value)
        except (TypeError, ValueError):
            errors[argument.name] = 'Cell size must be a number.'
            return errors
        if not self._valid_num_cells(cell_size, extents, argument.name, errors):
            return errors

        # Validate raster: one must be chosen when raster elevations are requested
        arg_use_raster = arguments[ARG_INPUT_USE_RASTER]
        arg_raster = arguments[ARG_INPUT_RASTER]
        if arg_use_raster.value and not arg_raster.value:
            errors[arg_raster.name] = 'Using raster but no raster chosen.'

        return errors

    def _valid_coverage_extents(self, extents: map_util.Extents, key: str, errors: dict[str, str]) -> bool:
        """Returns True if the coverage has valid extents (it's not empty).

        An empty or missing coverage is recognized by a minimum x of +infinity.

        Args:
            extents: Coverage extents as ``(min_xy, max_xy)``.
            key: Argument name used as the key in ``errors``.
            errors: Receives an error message when the extents are invalid.

        Returns:
            True if valid, False otherwise.
        """
        if extents[0][0] == float('inf'):
            errors[key] = 'Coverage not found or is empty.'
            return False
        return True

    def _valid_num_cells(self, cell_size: float, extents: map_util.Extents, key: str, errors: dict[str, str]) -> bool:
        """Returns True if the cell size results in a valid number of cells.

        Args:
            cell_size: The requested cell size; must be positive.
            extents: Coverage extents as ``(min_xy, max_xy)``.
            key: Argument name used as the key in ``errors``.
            errors: Receives an error message when the cell size is invalid.

        Returns:
            True if valid, False otherwise.
        """
        if cell_size <= 0.0:
            errors[key] = 'Cell size must be greater than 0.0.'
            return False

        ncells_x = gssha_2d_ugrid_creator.num_cells_from_cell_size(extents[1][0] - extents[0][0], cell_size)
        ncells_y = gssha_2d_ugrid_creator.num_cells_from_cell_size(extents[1][1] - extents[0][1], cell_size)
        total_cells = ncells_x * ncells_y
        if total_cells > MAX_CELLS:
            # The ',' format spec adds a thousands separator without calling locale.setlocale(),
            # which would change process-wide locale state as a side effect of validation.
            errors[key] = f'Cell size would result in {total_cells:,} cells (max is {MAX_CELLS:,}).'
            return False
        return True

    def run(self, arguments: list[Argument]) -> None:
        """Override to run the tool.

        Builds the grid from the coverage and sets it as the output grid. The grid is named after
        the output grid name argument, or after the coverage when no name is given.

        Args:
            arguments: The tool arguments.
        """
        # Get input argument values
        coverage = self.get_input_coverage(arguments[ARG_INPUT_COVERAGE].text_value)
        cell_size = float(arguments[ARG_INPUT_CELL_SIZE].text_value)
        grid_name = arguments[ARG_INPUT_GRID_NAME].value
        use_raster = arguments[ARG_INPUT_USE_RASTER].value
        raster = self.get_input_raster(str(arguments[ARG_INPUT_RASTER].value)) if use_raster else None
        use_streams = bool(arguments[ARG_INPUT_USE_STREAMS].value)

        # Build grid
        co_grid_2d = gssha_2d_ugrid_creator.build_co_grid(
            self._query, coverage, cell_size, raster, self.vertical_units, use_streams, self.logger
        )

        # Set output grid
        arguments[ARG_OUTPUT_GRID].value = grid_name if grid_name else coverage.name
        self.set_output_grid(co_grid_2d, arguments[ARG_OUTPUT_GRID])


def _get_default_cell_size(query: Query, bc_coverage_node) -> float:
    """Returns a good, default cell size based on the extents of the coverage.

    The size splits the smaller coverage dimension into about DEFAULT_CELLS_IN_XY cells. It is
    rounded down to a whole number when that keeps it positive. Otherwise the unrounded value is
    used, so small coverages never get a zero cell size.

    Args:
        query: Used to look up the coverage by the node's UUID.
        bc_coverage_node: Project tree node of the boundary conditions coverage, or None.

    Returns:
        A positive cell size. 1.0 if there is no node, the coverage is empty, or its extents
        have zero width or height.
    """
    if not bc_coverage_node:
        return 1.0

    coverage = query.item_with_uuid(bc_coverage_node.uuid)
    mn, mx = map_util.coverage_extents(coverage)
    if mn[0] == float('inf'):  # Empty coverage
        return 1.0

    min_xy_size = min(mx[0] - mn[0], mx[1] - mn[1])
    cell_size = min_xy_size / DEFAULT_CELLS_IN_XY
    if cell_size >= 1.0:
        # Round down to a whole number for a tidy default
        return float(math.floor(cell_size))
    if cell_size > 0.0:
        # Flooring would give 0.0, which the cell-size validation rejects; keep the fractional size
        return cell_size
    # Degenerate extents (zero width or height)
    return 1.0
