"""Merge2dUGridsTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
import uuid

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint import GridType, UGrid2d, UnconstrainedGrid
from xms.constraint.ugrid_activity import CellToPointActivityCalculator
from xms.gdal.utilities import gdal_utils as gu
from xms.grid.ugrid import UGrid
from xms.tool_core import IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.ugrids import UGrid2dMerger

# Fallback null value for merged datasets.
XM_NODATA = -999999

# Positions of the tool arguments in the list built by Merge2dUGridsTool.initial_arguments().
# The order here must match the order of that list.
(
    ARG_PRIMARY_GRID,
    ARG_SECONDARY_GRID,
    ARG_POINT_TOLERANCE,
    ARG_BUFFER_OPTION,
    ARG_BUFFER_DISTANCE,
    ARG_STITCH_GRIDS,
    ARG_MERGED_GRID,
) = range(7)


class Merge2dUGridsTool(Tool):
    """Tool to merge two 2D grids.

    The first grid argument is labelled the priority grid. The geometric merge itself is
    delegated to UGrid2dMerger. When both inputs carry datasets with matching paths, location
    and reference time, those datasets are merged onto the output grid as well.
    """

    def __init__(self, name='Merge 2D UGrids'):
        """Initializes the class.

        Args:
            name (str): Display name of the tool.
        """
        super().__init__(name=name)
        self._force_ugrid = True  # forwarded to set_output_grid() as force_ugrid
        self._geom_txt = 'UGrid'  # singular geometry noun used in argument descriptions and messages
        self._geom_txt_plural = 'UGrids'
        self._merged_grid_uuid = None  # assigned by run() once the merged grid has been built

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments, in the order given by the ARG_* indices.
        """
        arguments = [
            self.grid_argument(name='primary_grid', description=f'Priority {self._geom_txt}'),
            self.grid_argument(name='secondary_grid', description=f'Secondary {self._geom_txt}'),
            self.float_argument(name='point_tolerance', description='Duplicate point tolerance (ft/m)', value=1.0e-9,
                                min_value=0.0),
            self.string_argument(name='buffer_option', description='Buffer distance option', value='Default',
                                 choices=['Default', 'Specified']),
            self.float_argument(name='buffer_distance', description='Buffer distance (ft/m)', min_value=0.0),
            self.bool_argument(name='stitch_grids',
                               description=f'Stitch non-overlapping {self._geom_txt_plural} '
                               'with matching boundary points', value=False),
            self.grid_argument(name='merged_grid', description=f'Merged {self._geom_txt}',
                               io_direction=IoDirection.OUTPUT, optional=True)
        ]
        return arguments

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        The buffer distance is only shown when the buffer option is 'Specified'.

        Args:
            arguments(list): The tool arguments.
        """
        arguments[ARG_BUFFER_DISTANCE].show = arguments[ARG_BUFFER_OPTION].value == 'Specified'

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Also fills in a default output name ('Merged_<primary>_<secondary>') when none was given.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}

        # Validate primary and secondary grids are specified and 2D
        self._validate_input_grid(errors, arguments[ARG_PRIMARY_GRID])
        self._validate_input_grid(errors, arguments[ARG_SECONDARY_GRID])

        # Make sure output name specified
        if arguments[ARG_MERGED_GRID].text_value == '':
            g1 = arguments[ARG_PRIMARY_GRID].text_value.split('/')[-1].replace(' ', '_')
            g2 = arguments[ARG_SECONDARY_GRID].text_value.split('/')[-1].replace(' ', '_')
            arguments[ARG_MERGED_GRID].value = f'Merged_{g1}_{g2}'
        return errors

    def _validate_input_grid(self, errors, argument):
        """Validate grids are specified and 2D.

        Args:
            errors (dict): Dictionary of errors keyed by argument name; updated in place.
            argument (GridArgument): The grid argument.
        """
        key = argument.name
        grid = self.get_input_grid(argument.text_value)
        if not grid:
            errors[key] = f'Could not read {self._geom_txt}.'
        else:
            if not grid.check_all_cells_2d():
                errors[key] = 'Must have all 2D cells.'

    def run(self, arguments):
        """Override to run the tool.

        Merges the primary and secondary grids, sends the merged grid as the tool output,
        and then merges matching datasets from the two inputs onto it.

        Args:
            arguments (list): The tool arguments.
        """
        factor = self._length_unit_factor()

        primary_ugrid = self.get_input_grid(arguments[ARG_PRIMARY_GRID].text_value)
        secondary_ugrid = self.get_input_grid(arguments[ARG_SECONDARY_GRID].text_value)
        point_tolerance = arguments[ARG_POINT_TOLERANCE].value * factor
        # With the 'Default' option None is passed; presumably the merger then picks its own buffer.
        buffer_distance = None
        if arguments[ARG_BUFFER_OPTION].value == 'Specified':
            buffer_distance = arguments[ARG_BUFFER_DISTANCE].value * factor
        stitch = arguments[ARG_STITCH_GRIDS].value

        merger = UGrid2dMerger(primary_ugrid, secondary_ugrid, point_tolerance, buffer_distance,
                               self.logger, stitch_grids=stitch)
        merged_xm_ugrid = merger.merge_grids()
        if not merged_xm_ugrid:
            return

        merged_co_grid = self._make_merged_grid(merged_xm_ugrid, primary_ugrid, secondary_ugrid)
        self._merged_grid_uuid = str(uuid.uuid4())
        merged_co_grid.uuid = self._merged_grid_uuid
        self.set_output_grid(merged_co_grid, arguments[ARG_MERGED_GRID], force_ugrid=self._force_ugrid)

        self._merge_grid_datasets(arguments, merger, primary_ugrid, secondary_ugrid)

    def _length_unit_factor(self):
        """Get the factor applied to user-entered lengths (point tolerance and buffer distance).

        Returns:
            (float): 1/111000 (a rough meters-to-degrees conversion) when the default projection is
            geographic with degree units, otherwise 1.0.
        """
        if not gu.valid_wkt(self.default_wkt):
            return 1.0
        # look in GdalUtility.cpp - LaUnit imUnitFromProjectionWKT(const std::string wkt)
        sr = gu.wkt_to_sr(self.default_wkt)
        if sr.IsGeographic() and sr.GetAngularUnitsName().upper() in ('DEGREE', 'DS'):
            # rough conversion from meters to degrees
            return 1.0 / 111000.0
        return 1.0

    @staticmethod
    def _make_merged_grid(merged_xm_ugrid, primary_ugrid, secondary_ugrid):
        """Wrap the merged xmugrid in a constrained-grid object.

        Args:
            merged_xm_ugrid (UGrid): The grid returned by UGrid2dMerger.merge_grids().
            primary_ugrid: The primary input grid.
            secondary_ugrid: The secondary input grid.

        Returns:
            UGrid2d if both inputs are GridType.ugrid_2d and every merged cell is a quad, pixel or
            triangle; otherwise an UnconstrainedGrid.
        """
        merged_co_grid = UnconstrainedGrid(ugrid=merged_xm_ugrid)
        if primary_ugrid.grid_type == GridType.ugrid_2d and secondary_ugrid.grid_type == GridType.ugrid_2d:
            mesh_types = [UGrid.cell_type_enum.QUAD, UGrid.cell_type_enum.PIXEL,
                          UGrid.cell_type_enum.TRIANGLE]
            if merged_co_grid.check_all_cells_are_of_types(mesh_types):
                merged_co_grid = UGrid2d(ugrid=merged_xm_ugrid)
        return merged_co_grid

    def _merge_grid_datasets(self, arguments, merger, primary_ugrid, secondary_ugrid):
        """Merge datasets that exist on both input grids onto the merged grid.

        Args:
            arguments (list): The tool arguments.
            merger (UGrid2dMerger): The merger that produced the merged grid.
            primary_ugrid: The primary input grid.
            secondary_ugrid: The secondary input grid.
        """
        primary_datasets = self.get_grid_datasets(arguments[ARG_PRIMARY_GRID].value)
        secondary_datasets = self.get_grid_datasets(arguments[ARG_SECONDARY_GRID].value)
        # NOTE(review): at least two datasets are required on each grid (presumably one of them is
        # always the elevation dataset) - confirm.
        if len(primary_datasets) < 2 or len(secondary_datasets) < 2:
            return
        if not merger.can_merge_datasets():
            msg = (f'Unable to merge datasets because new points were added to the merged {self._geom_txt} that '
                   f'are not present in the primary or secondary {self._geom_txt}.')
            self.logger.info(msg)
            return

        # Datasets are paired by path, with the primary grid name swapped for the secondary one.
        # Only the first occurrence (presumably the grid prefix of the path) is swapped, so a dataset
        # whose own name contains the grid name still finds its partner.
        primary_name = arguments[ARG_PRIMARY_GRID].text_value
        secondary_name = arguments[ARG_SECONDARY_GRID].text_value
        for p_dset in primary_datasets:
            s_dset = p_dset.replace(primary_name, secondary_name, 1)
            if s_dset in secondary_datasets:
                self._merge_dataset_pair(p_dset, s_dset, merger, primary_ugrid, secondary_ugrid)

    def _merge_dataset_pair(self, p_dset, s_dset, merger, primary_ugrid, secondary_ugrid):
        """Merge one primary/secondary dataset pair and send it as tool output.

        The pair is skipped when location or reference time differ, or when both are the unmodified
        point elevations of their grids.

        Args:
            p_dset (str): Path of the primary dataset.
            s_dset (str): Path of the matching secondary dataset.
            merger (UGrid2dMerger): The merger that produced the merged grid.
            primary_ugrid: The primary input grid.
            secondary_ugrid: The secondary input grid.
        """
        p_reader = self.get_input_dataset(p_dset)
        s_reader = self.get_input_dataset(s_dset)
        if not (p_reader.location == s_reader.location and p_reader.ref_time == s_reader.ref_time):
            return

        # Point-located datasets get a CellToPointActivityCalculator built from their own grid.
        if p_reader.location == 'points':
            p_reader.activity_calculator = CellToPointActivityCalculator(primary_ugrid.ugrid)
            s_reader.activity_calculator = CellToPointActivityCalculator(secondary_ugrid.ugrid)

        # don't make a copy of the elevation dataset (checked before creating any output writer)
        if self._is_unmodified_elevation(p_reader, s_reader, primary_ugrid, secondary_ugrid):
            return

        dset_name = os.path.basename(p_dset)
        # 0.0 is a legitimate null value, so fall back to XM_NODATA only when none is defined.
        null_value = p_reader.null_value if p_reader.null_value is not None else float(XM_NODATA)
        writer = self.get_output_dataset_writer(
            name=dset_name,
            geom_uuid=self._merged_grid_uuid,
            num_components=p_reader.num_components,
            ref_time=p_reader.ref_time,
            time_units=p_reader.time_units,
            null_value=null_value,
            use_activity_as_null=True,
            location=p_reader.location
        )
        if merger.merge_datasets(p_reader, s_reader, writer):
            self.set_output_dataset(writer)
            self.logger.info(f'Merged {dset_name}.')

    @staticmethod
    def _is_unmodified_elevation(p_reader, s_reader, primary_ugrid, secondary_ugrid):
        """Check whether both datasets are exactly their grid's point elevations.

        Only scalar, point-located datasets are compared, and only the first timestep is examined.

        Args:
            p_reader: Reader for the primary dataset.
            s_reader: Reader for the secondary dataset.
            primary_ugrid: The primary input grid.
            secondary_ugrid: The secondary input grid.

        Returns:
            (bool): True if the first timestep of each dataset equals its grid's point elevations.
        """
        if p_reader.num_components != 1 or p_reader.location != 'points':
            return False
        p_values = p_reader.timestep_with_activity(0)[0]
        p_elevs = np.fromiter(primary_ugrid.point_elevations, dtype=p_values.dtype)
        if not np.array_equal(p_elevs, p_values):
            return False
        s_values = s_reader.timestep_with_activity(0)[0]
        s_elevs = np.fromiter(secondary_ugrid.point_elevations, dtype=s_values.dtype)
        return np.array_equal(s_values, s_elevs)
