"""tool utilities."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules

# 2. Third party modules
import numpy as np
import numpy.typing as npt

# 3. Aquaveo modules
from xms.api.tree import tree_util, TreeNode
from xms.gdal.rasters import RasterInput
from xms.gdal.utilities import gdal_utils
from xms.gdal.vectors import VectorInput
from xms.grid.ugrid import UGrid

# 4. Local modules
from xms.gssha.misc.type_aliases import IntArray, Pt3dArray

# Constants
# Sentinel written for cells that have no data: used for inactive cells when raster values are floats,
# and as the fallback marker for missing data when the raster does not define a nodata value of its own.
XM_NODATA = -9999999


def argument_path_from_node(tree_node: TreeNode) -> str:
    """Build the tree path of a node in the form expected by tool arguments.

    Args:
        tree_node: The tree node. May be None (or otherwise falsy).

    Returns:
        The node's tree path with a leading 'Project/' removed when present, or an empty string when
        tree_node is falsy.
    """
    project_prefix = 'Project/'
    full_path = tree_util.tree_path(tree_node) if tree_node else ''
    return full_path[len(project_prefix):] if full_path.startswith(project_prefix) else full_path


def compute_cell_centers(ugrid: UGrid) -> Pt3dArray:
    """Collect the centroid of every cell in the UGrid.

    Args:
        ugrid: The UGrid whose cells are visited in index order (0 .. cell_count - 1).

    Returns:
        A numpy array holding one centroid per cell, in cell index order.
    """
    centroids = []
    for cell_index in range(ugrid.cell_count):
        # The centroid is the second item of get_cell_centroid's result.
        # NOTE(review): the first item (presumably a success flag) is not checked — confirm this is intended.
        centroids.append(ugrid.get_cell_centroid(cell_index)[1])
    return np.asarray(centroids)


def raster_values_at_cells(
    cell_centers: Pt3dArray, on_off_cells: IntArray, raster: RasterInput, vertical_units: str
) -> 'npt.NDArray':
    """Returns the raster value at every cell center.

    For each active cell, the pixel containing the cell center is sampled directly (nearest neighbor, no
    interpolation or averaging). The sample is scaled by the multiplier that converts the raster's vertical
    units to ``vertical_units``.

    Args:
        cell_centers: Cell center coordinates (the x and y of each entry are used).
        on_off_cells: Model on/off cells; a truthy entry marks the cell as active.
        raster: The raster.
        vertical_units: Vertical units the returned values are converted to.

    Returns:
        One value per cell. Float rasters give a float array; all other rasters give an int array.
        Inactive cells get ``XM_NODATA`` (float rasters) or 0 (other rasters). Active cells that fall
        outside the raster, or on a nodata pixel, get the raster's nodata value (``XM_NODATA`` if the
        raster has none).
    """
    # Create the right type of array
    float_types = {
        RasterInput.GDT_Float32, RasterInput.GDT_Float64, RasterInput.GDT_CFloat32, RasterInput.GDT_CFloat64
    }
    if raster.data_type in float_types:
        values = np.zeros(len(cell_centers))
        off_value = float(XM_NODATA)  # Use a float for off cells when using floats (like elevation)
    else:
        values = np.zeros(len(cell_centers), dtype=int)
        off_value = 0  # Use 0 for off cell when using ints (like land use)

    nodata = raster.nodata_value
    # A nodata value of 0 is legitimate, so only fall back to XM_NODATA when the raster has no nodata value.
    missing_value = nodata if nodata is not None else XM_NODATA

    # Get the raster pixel value at each cell center (no interpolation or averaging is done)
    vm = gdal_utils.get_vertical_multiplier(gdal_utils.get_vert_unit_from_wkt(raster.wkt), vertical_units)
    for cell_idx, cell_center in enumerate(cell_centers):
        if not on_off_cells[cell_idx]:
            values[cell_idx] = off_value
            continue

        x_off, y_off = raster.coordinate_to_pixel(cell_center[0], cell_center[1])
        # NOTE(review): only negative offsets are treated as outside the raster; offsets past the right/bottom
        # edge are passed to get_raster_values unchecked — confirm it handles them.
        if x_off < 0 or y_off < 0:
            values[cell_idx] = missing_value
            continue

        raster_value = raster.get_raster_values(x_off, y_off, 1, 1, resample_alg='nearest_neighbor')
        pixel_value = raster_value[0]
        if pixel_value != nodata:
            values[cell_idx] = pixel_value * vm
        else:
            # Mark nodata pixels like out-of-raster cells instead of leaving a misleading 0 in the array
            values[cell_idx] = missing_value
    return values


def is_polygon_shapefile(vi: VectorInput) -> bool:
    """Returns True if the vector layer's geometry type is a polygon type, else False.

    The check is a pure test: on failure it only returns False and records nothing. Multi-polygon layer
    types are not in the accepted set.

    Args:
        vi: VectorInput object.

    Returns:
        True when ``vi.layer_type`` is a polygon type (plain, M, ZM or 25D); otherwise False.
    """
    polygon_types = {
        VectorInput.wkbPolygon, VectorInput.wkbPolygonM, VectorInput.wkbPolygonZM, VectorInput.wkbPolygon25D
    }
    return vi.layer_type in polygon_types
