"""ExportCurvilinearGridTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint.grid import GridType
from xms.grid.ugrid import UGrid as XmUGrid
try:
    from xms.guipy.settings import get_file_browser_directory
    get_file_browser_directory_supported = True
except ImportError:  # pragma no cover - optional import
    get_file_browser_directory_supported = False
from xms.tool_core import ALLOW_ONLY_CELL_MAPPED, ALLOW_ONLY_SCALARS, IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.ugrids.curvilinear_grid_ij import GridIJCreator
from xms.tool.file_io.curvilinear.ch3d_grid_writer import Ch3dGridWriter
from xms.tool.file_io.curvilinear.efdc_grid_writer import EfdcGridWriter

# Positions of each tool argument in the list built by ExportCurvilinearGridTool.initial_arguments().
# The order of names below must match the order in which those arguments are created; unpacking from
# range() raises a ValueError at import time if the count here ever drifts from the number of names.
(
    ARG_INPUT_FORMAT,
    ARG_INPUT_GRID,
    ARG_INPUT_GEN_DSETS,
    ARG_INPUT_I_DSET,
    ARG_INPUT_J_DSET,
    ARG_INPUT_DEPTH_DSET,
    ARG_INPUT_ZROUGH_DSET,
    ARG_INPUT_VEG_TYPE_DSET,
    ARG_INPUT_WIND_SHELTER_DSET,
    ARG_OUTPUT_DIR,
    ARG_OUTPUT_CH3D_FILENAME,
    ARG_OUTPUT_DXDY_FILENAME,
    ARG_OUTPUT_LXLY_FILENAME,
    ARG_OUTPUT_CELL_FILENAME,
) = range(14)

# Smallest cell i/j index written for each output format.
CH3D_START_IJ = 1
EFDC_START_IJ = 3  # EFDC cells start at 3 because of the border buffer zone


class ExportCurvilinearGridTool(Tool):
    """Tool to export a curvilinear grid to a CH3D or EFDC file."""

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Export Curvilinear Grid')
        self._input_cogrid = None  # Grid object returned by get_input_grid(), set during validation
        self._input_ugrid = None  # The input grid's UGrid, unwrapped once during validation
        self._i_dset = None  # Will generate if not specified
        self._j_dset = None  # Will generate if not specified
        # Optional EFDC datasets, default to 0.0
        self._depth_dset = None
        self._zrough_dset = None
        self._veg_type_dset = None
        self._wind_shelter_dset = None

    def _reset_datasets(self):
        """If there is a problem with a selected input dataset or grid, reset references to null."""
        self._i_dset = None
        self._j_dset = None
        self._depth_dset = None
        self._zrough_dset = None
        self._veg_type_dset = None
        self._wind_shelter_dset = None

    def _set_efdc_datasets(self, ij_creator):
        """Add cell datasets to the i-j dataset creator so they get added to the output DataFrame.

        Datasets the user did not select are replaced by arrays of zeros with one value per input grid cell.
        Only ``values[0]`` of each selected dataset is used (presumably the first time step - confirm).

        Args:
            ij_creator (GridIJCreator): The grid i-j dataset creator
        """
        num_cells = self._input_ugrid.cell_count
        ij_creator.cell_datasets = {
            'depth': np.full(num_cells, 0.0) if self._depth_dset is None else self._depth_dset.values[0],
            'zrough': np.full(num_cells, 0.0) if self._zrough_dset is None else self._zrough_dset.values[0],
            'veg_type': np.full(num_cells, 0.0) if self._veg_type_dset is None else self._veg_type_dset.values[0],
            # NOTE(review): this default is an integer array while the other defaults are floats; confirm the
            # EFDC writer expects that.
            'wind_shelter':
                np.full(num_cells, 0) if self._wind_shelter_dset is None else self._wind_shelter_dset.values[0],
        }

    def _set_existing_ij_datasets(self, ij_creator, is_efdc):
        """Set existing i-j input datasets on the i-j dataset builder.

        For EFDC output the values are shifted by 2, because the minimum i or j value for EFDC cells is 3 due to
        the border buffer zone. The shift is done out-of-place so the selected input datasets are never modified.

        Args:
            ij_creator (GridIJCreator): The grid i-j dataset creator
            is_efdc (bool): True if reading EFDC format files
        """
        ivals = self._i_dset.values[0]
        jvals = self._j_dset.values[0]
        if is_efdc:
            # Do not use '+=': on a numpy array that would overwrite the dataset's own values in place, and on a
            # plain list it would raise a TypeError.
            ivals = np.asarray(ivals) + 2
            jvals = np.asarray(jvals) + 2
        ij_creator.set_cell_ij(ivals, jvals)

    def _get_inputs(self, arguments):
        """Get the user inputs and build the cell i-j DataFrame.

        Args:
            arguments (list): The tool arguments.

        Returns:
            pd.DataFrame: The grid cell i-j dataset indexed by 'i' and 'j' with a 'cell_idx' column containing the
            XmUGrid cellstream index for the i-j coordinate pair index. Returns None if error occurred.
        """
        self.logger.info('Retrieving input data...')
        ij_creator = GridIJCreator(self._input_ugrid, self.logger)
        is_efdc = arguments[ARG_INPUT_FORMAT].value == 'EFDC'
        start_ij = EFDC_START_IJ if is_efdc else CH3D_START_IJ
        # Use the selected cell i-j datasets when both are given; otherwise the creator generates them.
        # NOTE(review): ARG_INPUT_GEN_DSETS is not consulted here, so i-j datasets that remain selected while the
        # "Generate cell i-j datasets" option is checked (and their selectors hidden) would still be used. Confirm.
        if self._i_dset and self._j_dset:  # Using existing i-j datasets
            self._set_existing_ij_datasets(ij_creator, is_efdc)
        if is_efdc:
            self._set_efdc_datasets(ij_creator)
        return ij_creator.create_ij_dataset(start_ij)

    def _validate_input_grid(self, grid_arg):
        """Ensure the input grid has all 2D quad cells, no disjoint points, and no disjoint regions.

        Also caches the selected grid and its UGrid on the tool for later use.

        Args:
            grid_arg (GridArgument): The grid argument.

        Returns:
            dict: The dict of grid error messages, empty if checks passed
        """
        errors = {}
        key = grid_arg.name
        self._input_cogrid = self.get_input_grid(grid_arg.text_value)
        self._input_ugrid = self._input_cogrid.ugrid  # Only unwrap this once.
        grid_errors = []
        is_3d = self._input_cogrid.grid_type in (
            GridType.ugrid_3d, GridType.rectilinear_3d, GridType.quadtree_3d
        )
        if is_3d or not self._input_cogrid.check_all_cells_are_of_type(XmUGrid.cell_type_enum.QUAD):
            grid_errors.append('Input grid must have all 2D Quad cells.')
        if self._input_cogrid.check_has_disjoint_points():
            grid_errors.append('Input grid cannot have disjoint points.')
        # This can be supported with multi-block
        if not self._input_cogrid.check_contiguous_cells_2d():
            grid_errors.append('Input grid cannot have disjoint regions.')
        # We may want to add a check for valence not equal to 4
        error = '\n'.join(grid_errors)
        if error:
            errors[key] = error
        return errors

    def _validate_input_datasets(self, arguments, errors):
        """Ensure the input datasets have the same number of values as input grid cells, if specified.

        The EFDC-only datasets are checked only when the EFDC format is selected.

        Args:
            arguments (list): The tool arguments.
            errors (dict): Dictionary of errors keyed by argument name. Updated in place.
        """
        self._validate_ij_datasets(arguments, errors)
        if arguments[ARG_INPUT_FORMAT].value == 'EFDC':
            self._validate_efdc_datasets(arguments, errors)

    def _validate_ij_datasets(self, arguments, errors):
        """Ensure the i and j coordinate datasets have the same number of values as input grid cells, if specified.

        The two datasets must be given together: selecting only one of them is an error.

        Args:
            arguments (list): The tool arguments.
            errors (dict): Dictionary of errors keyed by argument name. Updated in place.
        """
        self._i_dset = self.get_input_dataset(arguments[ARG_INPUT_I_DSET].text_value)
        self._j_dset = self.get_input_dataset(arguments[ARG_INPUT_J_DSET].text_value)
        if self._i_dset:  # Will generate if not specified
            ierrs = []
            if not self._j_dset:  # But have to give the j-coordinate dataset as well if i-coordinate is specified
                ierrs.append('Must specify a cell j-coordinate dataset if the i-coordinate dataset is selected.')
            if self._i_dset.num_values != self._input_ugrid.cell_count:  # Need one value per input grid cell
                ierrs.append('If specified, the cell i-coordinate dataset must have the same number of values as cells '
                             'in the input UGrid.')
            if ierrs:
                errors[arguments[ARG_INPUT_I_DSET].name] = '\n'.join(ierrs)
        if self._j_dset:  # Will generate if not specified
            jerrs = []
            if not self._i_dset:  # But have to give the i-coordinate dataset as well if j-coordinate is specified
                jerrs.append('Must specify a cell i-coordinate dataset if the j-coordinate dataset is selected.')
            if self._j_dset.num_values != self._input_ugrid.cell_count:  # Need one value per input grid cell
                jerrs.append('If specified, the cell j-coordinate dataset must have the same number of values as cells '
                             'in the input UGrid.')
            if jerrs:
                errors[arguments[ARG_INPUT_J_DSET].name] = '\n'.join(jerrs)

    def _validate_efdc_datasets(self, arguments, errors):
        """Ensure the EFDC input datasets have the same number of values as input grid cells, if specified.

        Notes:
            If an input dataset is not specified, it will be defaulted to 0.0.

        Args:
            arguments (list): The tool arguments.
            errors (dict): Dictionary of errors keyed by argument name. Updated in place.
        """
        num_cells = self._input_ugrid.cell_count
        self._depth_dset = self.get_input_dataset(arguments[ARG_INPUT_DEPTH_DSET].text_value)
        if self._depth_dset and self._depth_dset.num_values != num_cells:
            errors[arguments[ARG_INPUT_DEPTH_DSET].name] = 'If specified, the cell depth dataset must have the same ' \
                                                           'number of values as cells in the input UGrid.'
        self._zrough_dset = self.get_input_dataset(arguments[ARG_INPUT_ZROUGH_DSET].text_value)
        if self._zrough_dset and self._zrough_dset.num_values != num_cells:
            errors[arguments[ARG_INPUT_ZROUGH_DSET].name] = 'If specified, the cell Z roughness dataset must have ' \
                                                            'the same number of values as cells in the input UGrid.'
        self._veg_type_dset = self.get_input_dataset(arguments[ARG_INPUT_VEG_TYPE_DSET].text_value)
        # Check the vegetation type dataset itself (not the Z roughness dataset).
        if self._veg_type_dset and self._veg_type_dset.num_values != num_cells:
            errors[arguments[ARG_INPUT_VEG_TYPE_DSET].name] = 'If specified, the cell vegetation type dataset must ' \
                                                              'have the same number of values as cells in the input ' \
                                                              'UGrid.'
        self._wind_shelter_dset = self.get_input_dataset(arguments[ARG_INPUT_WIND_SHELTER_DSET].text_value)
        if self._wind_shelter_dset and self._wind_shelter_dset.num_values != num_cells:
            errors[arguments[ARG_INPUT_WIND_SHELTER_DSET].name] = 'If specified, the cell wind shelter dataset must ' \
                                                                  'have the same number of values as cells in the ' \
                                                                  'input UGrid.'

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        grid_arg = arguments[ARG_INPUT_GRID]
        errors = {}
        if grid_arg.value is not None:  # Case of no selected UGrid already handled in base
            errors = self._validate_input_grid(grid_arg)
            # Don't try to validate input datasets if we don't have a valid input grid.
            if not errors:
                self._validate_input_datasets(arguments, errors)
            else:
                self._reset_datasets()
        return errors

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        Args:
            arguments(list): The tool arguments.
        """
        # Show/hide file selectors based on format
        is_ch3d = arguments[ARG_INPUT_FORMAT].value == 'CH3D'
        arguments[ARG_OUTPUT_CH3D_FILENAME].hide = not is_ch3d
        arguments[ARG_OUTPUT_DXDY_FILENAME].hide = is_ch3d
        arguments[ARG_OUTPUT_LXLY_FILENAME].hide = is_ch3d
        arguments[ARG_OUTPUT_CELL_FILENAME].hide = is_ch3d
        # Show/hide dataset selectors only applicable to EFDC format
        arguments[ARG_INPUT_DEPTH_DSET].hide = is_ch3d
        arguments[ARG_INPUT_ZROUGH_DSET].hide = is_ch3d
        arguments[ARG_INPUT_VEG_TYPE_DSET].hide = is_ch3d
        arguments[ARG_INPUT_WIND_SHELTER_DSET].hide = is_ch3d
        # Show/hide dataset selectors based on generate cell i-j option
        generate_dsets = arguments[ARG_INPUT_GEN_DSETS].value
        arguments[ARG_INPUT_I_DSET].hide = generate_dsets
        arguments[ARG_INPUT_J_DSET].hide = generate_dsets

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override. The order of the returned list matches the module-level ARG_* index constants.

        Returns:
            (list): A list of the initial tool arguments.
        """
        if get_file_browser_directory_supported:
            start_dir = get_file_browser_directory()
        else:  # pragma no cover - optional import
            start_dir = ''
        dset_filters = [ALLOW_ONLY_SCALARS, ALLOW_ONLY_CELL_MAPPED]
        arguments = [
            self.string_argument(name='format', description='Format of the data file', value='CH3D',
                                 choices=['CH3D', 'EFDC']),
            self.grid_argument(name='input_ugrid', description='Curvilinear UGrid', io_direction=IoDirection.INPUT),
            self.bool_argument(name='generate_ij', description='Generate cell i-j datasets', value=True),
            self.dataset_argument(name='i_dset', description='Cell i-coordinate dataset (optional)', optional=True,
                                  io_direction=IoDirection.INPUT, filters=dset_filters),
            self.dataset_argument(name='j_dset', description='Cell j-coordinate dataset (optional)', optional=True,
                                  io_direction=IoDirection.INPUT, filters=dset_filters),
            self.dataset_argument(name='depth_dset', description='Depth dataset (optional)', optional=True,
                                  io_direction=IoDirection.INPUT, filters=dset_filters),
            self.dataset_argument(name='zrough_dset', description='Z Roughness dataset (optional)', optional=True,
                                  io_direction=IoDirection.INPUT, filters=dset_filters),
            self.dataset_argument(name='veg_type_dset', description='Vegetation type dataset (optional)', optional=True,
                                  io_direction=IoDirection.INPUT, filters=dset_filters),
            self.dataset_argument(name='wind_shelter_dset', description='Wind shelter dataset (optional)',
                                  optional=True, io_direction=IoDirection.INPUT, filters=dset_filters),
            self.string_argument(name='output directory', description='Path to where files will be saved',
                                 value=start_dir),
            self.string_argument(name='ch3d_file', description='CH3D formatted grid file', optional=True,
                                 value='grid.inp'),
            self.string_argument(name='dxdy_file', description='EFDC formatted dxdy.inp file', optional=True,
                                 value='dxdy.inp'),
            self.string_argument(name='lxly_file', description='EFDC formatted lxly.inp file', optional=True,
                                 value='lxly.inp'),
            self.string_argument(name='cellinp_file', description='EFDC formatted cell.inp file',
                                 optional=True, value='cell.inp'),
        ]
        self.enable_arguments(arguments)
        return arguments

    def run(self, arguments):
        """Override to run the tool.

        Builds the cell i-j DataFrame and hands it to the writer for the selected format. The CH3D writer gets the
        output directory joined with the CH3D file name. The EFDC writer gets the output directory plus the dxdy,
        lxly, and cell file names.

        Args:
            arguments (list): The tool arguments.
        """
        ij_df = self._get_inputs(arguments)
        if ij_df is None:
            # _get_inputs() documents a None return when the i-j dataset could not be built; nothing to write.
            self.logger.error('Unable to build the cell i-j dataset. No files were written.')
            return
        if arguments[ARG_INPUT_FORMAT].value == 'CH3D':
            self.logger.info('Writing curvilinear grid to CH3D formatted file...')
            grid_name = os.path.basename(arguments[ARG_INPUT_GRID].value)
            filename = os.path.join(arguments[ARG_OUTPUT_DIR].value, arguments[ARG_OUTPUT_CH3D_FILENAME].value)
            writer = Ch3dGridWriter(filename, self._input_ugrid, ij_df, grid_name, self.logger)
        else:  # arguments[ARG_INPUT_FORMAT].value == 'EFDC':
            self.logger.info('Writing curvilinear grid to EFDC formatted files...')
            writer = EfdcGridWriter(arguments[ARG_OUTPUT_DIR].value, arguments[ARG_OUTPUT_DXDY_FILENAME].value,
                                    arguments[ARG_OUTPUT_LXLY_FILENAME].value,
                                    arguments[ARG_OUTPUT_CELL_FILENAME].value,
                                    self._input_ugrid, ij_df, self._input_cogrid.cell_elevations, self.logger)
        writer.write()
