"""Mesh2dFromUGrid2dTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint.grid import GridType
from xms.tool_core import dataset_filters, IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.mesh_2d.mesh_from_ugrid import MeshFromUGrid

# Positions of the tool arguments in the list built by Mesh2dFromUGrid2dTool.initial_arguments():
# the input UGrid, the "Create from" source option, and the output mesh.
ARG_INPUT_UGRID, ARG_INPUT_SOURCE, ARG_OUTPUT_MESH = range(3)


class Mesh2dFromUGrid2dTool(Tool):
    """Tool to convert 2D UGrids to SMS 2D mesh module geometries.

    The mesh is created from either the UGrid cell centers or the UGrid points (the 'source_option'
    argument). Datasets on the input UGrid that match the chosen source location are copied to the mesh.
    """
    SOURCE_OPT_CENTROIDS = 0  # Create the mesh from the UGrid cell centers (default)
    SOURCE_OPT_POINTS = 1  # Create the mesh from the UGrid points

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='2D Mesh from 2D UGrid')
        self._mesh_name = ''
        self._source_opt = self.SOURCE_OPT_CENTROIDS
        self._input_cogrid = None
        self._input_ugrid = None
        self._output_mesh = None
        self._output_mesh_uuid = ''
        self._args = []
        self._available_datasets = []
        self._datasets = []
        self._ug_name = ''
        self._new_cellstream = None
        self._new_cell_idx = None

    def _validate_input_grid(self, argument, node_based):
        """Validate grid is specified and 2D.

        Also caches the input grid and its unwrapped UGrid on the instance for later use by run().

        Args:
            argument (GridArgument): The grid argument.
            node_based (bool): True if source is UGrid points, False if cell centers

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name.
        """
        errors = {}
        self._input_cogrid = self.get_input_grid(argument.text_value)
        if self._input_cogrid:
            # If we have a grid, perform all the checks
            grid_errors = []
            self._input_ugrid = self._input_cogrid.ugrid  # Only unwrap this once.
            got_node_data = node_based and self._input_ugrid.point_count > 0
            got_cell_data = not node_based and self._input_ugrid.cell_count > 0
            if not got_node_data and not got_cell_data:
                grid_errors.append('No source data to convert.')
            if self._input_cogrid.grid_type in (GridType.ugrid_3d, GridType.rectilinear_3d, GridType.quadtree_3d):
                grid_errors.append('Must have all 2D cells.')
            error = '\n'.join(grid_errors)
            if error:
                errors[argument.name] = error
        return errors

    def _get_inputs(self):
        """Get the user inputs.

        Resolves the output mesh name, the source option, and the list of candidate datasets on the
        input UGrid whose location matches the source option.
        """
        self.logger.info('Retrieving input data from SMS...')
        # If user didn't give an output mesh name, use the name of the input UGrid.
        user_input = self._args[ARG_OUTPUT_MESH].text_value
        self._ug_name = self._args[ARG_INPUT_UGRID].text_value
        ug_name = os.path.basename(self._ug_name)
        self._mesh_name = os.path.basename(user_input) if user_input else ug_name
        self._args[ARG_OUTPUT_MESH].value = self._mesh_name
        if self._args[ARG_INPUT_SOURCE].text_value == 'UGrid points':
            self._source_opt = self.SOURCE_OPT_POINTS
            filters = [dataset_filters.ALLOW_ONLY_POINT_MAPPED]
        else:
            # Reset explicitly so a previous 'UGrid points' run on this instance does not leak into this one.
            self._source_opt = self.SOURCE_OPT_CENTROIDS
            filters = [dataset_filters.ALLOW_ONLY_CELL_MAPPED]
        self._available_datasets = self.get_grid_datasets(self._ug_name, filters=filters)

    @staticmethod
    def _relative_dataset_name(dset_name, ug_name):
        """Return a dataset's name relative to the input UGrid's tree path.

        Args:
            dset_name (str): Dataset tree path, normally starting with '<ug_name>/'.
            ug_name (str): The input UGrid's tree path.

        Returns:
            (str): dset_name with the leading '<ug_name>/' removed.
        """
        prefix = f'{ug_name}/'
        if dset_name.startswith(prefix):
            # Strip only the leading occurrence; str.replace would also delete matching text deeper in the
            # path (e.g. UGrid 'a' with dataset 'a/Gamma/depth' would become 'Gammdepth').
            return dset_name[len(prefix):]
        # Not rooted at the UGrid path: keep the previous substitution behavior.
        return dset_name.replace(prefix, '')

    def _remap_cell_activity(self, activity):
        """Reorder per-cell activity so it lines up with the cells of the output mesh.

        Args:
            activity (np.ndarray): Activity indexed as [time step, input cell].

        Returns:
            (np.ndarray): Integer activity shaped (time steps, output mesh cell count).
        """
        num_cells = self._output_mesh.ugrid.cell_count
        # NOTE(review): element [1] of each _new_cell_idx entry is taken as the input cell index for the
        # output cell at the same position - confirm against MeshFromUGrid.convert.
        old_cells = np.array([idx[1] for idx in self._new_cell_idx], dtype=int)
        activity = np.asarray(activity)
        new_act = np.empty((activity.shape[0], num_cells), dtype=int)
        # Assigning (rather than returning the gathered view) casts to int and raises if the number of
        # remapped cells does not match the mesh cell count.
        new_act[:] = activity[:, old_cells]
        return new_act

    def _create_datasets(self):
        """Copy compatible datasets from the input UGrid to the output mesh.

        Only datasets whose geometry UUID matches the input UGrid are copied. The resulting dataset
        writers are stored in self._datasets for _set_outputs().
        """
        self.logger.info('Creating datasets...')
        old_geom_uuid = self._input_cogrid.uuid
        has_new_cells = self._new_cellstream is not None and len(self._new_cellstream) > 0
        datasets = []
        for dset_name in self._available_datasets:
            reader = self.get_input_dataset(dset_name)
            if reader.geom_uuid != old_geom_uuid:
                continue  # Not a descendant of our original geometry.
            new_dataset_name = self._relative_dataset_name(dset_name, self._ug_name)
            self.logger.info(f'Copying "{dset_name}" to output mesh...')
            writer = self.get_output_dataset_writer(
                name=new_dataset_name,
                geom_uuid=self._output_mesh_uuid,
                num_components=reader.num_components,
                ref_time=reader.ref_time,
                time_units=reader.time_units,
                null_value=reader.null_value,
            )
            ds_data = reader.values[:]
            activity = reader.activity[:] if reader.activity is not None else None
            if activity is not None:
                if reader.location in ['cells', ('cells', 'cells')]:
                    # Cell data, cell activity: don't create a point activity array for the 2D mesh.
                    writer.use_activity_as_null = False
                elif reader.num_activity_values != reader.num_values:
                    # Point data, cell activity: keep the activity array and let it act as nulls.
                    writer.use_activity_as_null = True
                    writer.timestep_maxs = reader.maxs
                    writer.timestep_mins = reader.mins
                    if has_new_cells:  # update the activity with the newly created cells
                        activity = self._remap_cell_activity(activity)
                else:
                    # Point data, point activity.
                    writer.use_activity_as_null = False
                if not writer.use_activity_as_null:
                    # Activity is not treated as nulls, so write the null value into inactive entries.
                    writer.null_value = -999.0
                    ds_data[activity == 0] = writer.null_value

            writer.write_xmdf_dataset(times=reader.times[:], data=ds_data, activity=activity)
            datasets.append(writer)
        self._datasets = datasets

    def _set_outputs(self):
        """Set outputs from the tool: the output mesh and any copied datasets."""
        self.set_output_grid(self._output_mesh, self._args[ARG_OUTPUT_MESH], force_ugrid=False)
        if not self._datasets:
            self.logger.warning('No datasets created from source UGrid.')
            return
        for ds in self._datasets:
            self.set_output_dataset(ds)

    def _create_ugrid(self):
        """Create the output mesh from the input UGrid and record its cellstream and UUID."""
        converter = MeshFromUGrid()
        self._output_mesh, self._new_cell_idx = converter.convert(
            self._source_opt, self._input_ugrid, self.logger, self._input_cogrid
        )
        self._new_cellstream = self._output_mesh.ugrid.cellstream
        self._output_mesh_uuid = self._output_mesh.uuid

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override. The order of this list matches ARG_INPUT_UGRID, ARG_INPUT_SOURCE, ARG_OUTPUT_MESH.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.grid_argument(name='input_ugrid', description='2D UGrid', io_direction=IoDirection.INPUT),
            self.string_argument(name='source_option', description='Create from',
                                 choices=['UGrid cell centers', 'UGrid points'], value='UGrid cell centers'),
            self.grid_argument(name='mesh', description='Mesh name', optional=True, io_direction=IoDirection.OUTPUT),
        ]
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        grid_arg = arguments[ARG_INPUT_UGRID]
        if grid_arg.value is None:  # Case of no selected UGrid already handled in base
            return {}
        node_based = arguments[ARG_INPUT_SOURCE].text_value == 'UGrid points'
        return self._validate_input_grid(grid_arg, node_based)

    def run(self, arguments):
        """Override to run the tool.

        Args:
            arguments (list): The tool arguments.
        """
        self.logger.info(f'PID: {os.getpid()}')
        self._args = arguments
        self._get_inputs()
        self._create_ugrid()
        self._create_datasets()
        self._set_outputs()
