"""ChezyFrictionTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os

# 2. Third party modules
import pandas

# 3. Aquaveo modules
from xms.tool_core import ALLOW_ONLY_POINT_MAPPED, ALLOW_ONLY_SCALARS, IoDirection, Tool

# 4. Local modules
from xms.tool.utilities.landuse_codes import get_ccap_chezy_friction_values, get_nlcd_chezy_friction_values
from xms.tool.utilities.landuse_mapper import LanduseMapper

# Positional indices into the tool's argument list. The order must match the list built in
# ChezyFrictionTool.initial_arguments(), since arguments are looked up by index (arguments[ARG_...]).
(
    ARG_INPUT_LANDUSE,  # landuse raster
    ARG_INPUT_LANDUSE_TYPE,  # 'NLCD', 'C-CAP', or 'Other'
    ARG_INPUT_CSV,  # custom landuse -> Chezy friction mapping table
    ARG_INPUT_GRID,  # target grid
    ARG_INPUT_DEFAULT_OPTION,  # 'Constant' or 'Dataset'
    ARG_INPUT_DEFAULT,  # constant default friction value
    ARG_INPUT_DEFAULT_DSET,  # default friction dataset
    ARG_INPUT_LND,  # locked nodes dataset
    ARG_OUTPUT_DATASET,  # output Chezy friction dataset
) = range(9)


class ChezyFrictionTool(Tool):
    """Tool that maps a landuse raster onto a grid to create a Chezy friction dataset."""

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Chezy Friction')
        self._cogrid = None  # Grid object returned by get_input_grid() for the target grid argument
        self._grid = None  # The target grid's ugrid
        self._grid_uuid = None  # UUID of the target grid; the output dataset is written on this geometry
        self._builder = None  # Output dataset writer
        self._raster = None  # Input landuse raster
        self._raster_values = None  # Pixel values of the landuse raster
        self._raster_null_value = -999999.0  # Raster nodata value; also used as the output dataset's null value
        self._mapping_table = {}
        self._dataset_vals = None  # Chezy friction values computed for the output dataset
        self._default_value = 60.0  # Constant default Chezy friction value
        self._dataset_values = None
        self._default_dset = None  # Optional default Chezy friction dataset (only when option is 'Dataset')
        self._lnd = None  # Optional locked nodes dataset

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.raster_argument(name='landuse_raster', description='Input landuse raster'),
            self.string_argument(name='landuse_type', description='Landuse raster type',
                                 choices=['NLCD', 'C-CAP', 'Other'], value='NLCD'),
            self.file_argument(name='mapping_csv', description='Landuse to Chezy friction mapping table',
                               optional=True),
            self.grid_argument(name='grid', description='Target grid'),
            self.string_argument(name='default_value_option', description='Default Chezy friction option',
                                 choices=['Constant', 'Dataset'], value='Constant'),
            self.float_argument(name='default_value', description='Default Chezy friction value', value=60.0,
                                min_value=0.0),
            self.dataset_argument(name='default_value_dataset', description='Default Chezy friction dataset',
                                  optional=True, filters=[ALLOW_ONLY_SCALARS, ALLOW_ONLY_POINT_MAPPED]),
            self.dataset_argument(name='locked_nodes_dataset', description='Locked nodes dataset (optional)',
                                  optional=True, filters=[ALLOW_ONLY_SCALARS, ALLOW_ONLY_POINT_MAPPED]),
            self.dataset_argument(name='friction_dataset', description='Output Chezy friction dataset',
                                  value='ChezyFric', io_direction=IoDirection.OUTPUT),
        ]
        self.enable_arguments(arguments)
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Side effect: opens the target grid and stores it on self._cogrid for use by run().

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name.
        """
        errors = {}
        # Validate input data
        self._cogrid = self.get_input_grid(arguments[ARG_INPUT_GRID].text_value)
        if not self._cogrid:
            errors[arguments[ARG_INPUT_GRID].name] = 'Could not open target grid.'

        # The mapping table is only used (and only shown) when the landuse type is "Other". A stale path left in
        # the hidden argument must not block runs that use the built-in NLCD/C-CAP tables.
        land_use_type = arguments[ARG_INPUT_LANDUSE_TYPE].text_value or ''
        if land_use_type == 'Other':
            csv_file = arguments[ARG_INPUT_CSV].text_value or ''
            if not csv_file:
                errors[arguments[ARG_INPUT_CSV].name] = \
                    'Must select a Chezy friction mapping table CSV file when land use type is "Other".'
            elif not os.path.isfile(csv_file):  # isfile, not exists: a directory cannot be read as a CSV
                errors[arguments[ARG_INPUT_CSV].name] = 'Could not open Landuse to Chezy friction mapping table.'

        # When the default comes from a dataset, one must actually be selected.
        if arguments[ARG_INPUT_DEFAULT_OPTION].text_value == 'Dataset' and \
                not arguments[ARG_INPUT_DEFAULT_DSET].text_value:
            errors[arguments[ARG_INPUT_DEFAULT_DSET].name] = \
                'Must select a default Chezy friction dataset when the default option is "Dataset".'

        return errors

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        Args:
            arguments(list): The tool arguments.
        """
        default_dset = arguments[ARG_INPUT_DEFAULT_OPTION].text_value == 'Dataset'
        arguments[ARG_INPUT_DEFAULT].hide = default_dset  # hide constant edit field if option is dataset
        arguments[ARG_INPUT_DEFAULT_DSET].hide = not default_dset  # hide dataset field if option is constant

        custom_codes = arguments[ARG_INPUT_LANDUSE_TYPE].text_value == 'Other'
        arguments[ARG_INPUT_CSV].hide = not custom_codes

    def run(self, arguments):
        """Override to run the tool.

        Relies on validate_arguments() having already opened the target grid into self._cogrid.

        Args:
            arguments (list): The tool arguments.
        """
        # Set up some of the grid variables
        self._grid_uuid = self._cogrid.uuid
        self._grid = self._cogrid.ugrid

        # Initialize some of the raster properties, values, etc.
        self._initialize_raster(arguments)
        self._default_value = float(arguments[ARG_INPUT_DEFAULT].text_value)
        self._initialize_optional_datasets(arguments)

        # Get the Chezy friction info
        code_vals, _, chezy_friction_vals = self.get_friction_info(arguments)

        # Map the landuse codes to Chezy friction values on the grid
        friction_creator = LanduseMapper(self.logger, self._raster, self._grid, self._default_value,
                                         code_vals, chezy_friction_vals,
                                         self._default_dset, self._lnd)
        self._dataset_vals = friction_creator.process_points()

        # Write out the dataset
        self._setup_output_dataset_builder(arguments)
        self._add_output_datasets()

    def _initialize_optional_datasets(self, arguments):
        """Read the optional default-value and locked-nodes datasets selected by the user.

        The default dataset is only read when the default option is 'Dataset', because the hidden argument may
        still hold a stale selection. Without this step, those two arguments would have no effect on the output.

        NOTE(review): this passes the readers returned by Tool.get_input_dataset() to LanduseMapper. Confirm
        that LanduseMapper expects reader objects rather than raw value arrays.

        Args:
            arguments (list): The tool arguments.
        """
        self._default_dset = None
        if arguments[ARG_INPUT_DEFAULT_OPTION].text_value == 'Dataset':
            default_dset_name = arguments[ARG_INPUT_DEFAULT_DSET].text_value
            if default_dset_name:
                self.logger.info('Retrieving default Chezy friction dataset...')
                self._default_dset = self.get_input_dataset(default_dset_name)

        self._lnd = None
        lnd_name = arguments[ARG_INPUT_LND].text_value
        if lnd_name:
            self.logger.info('Retrieving locked nodes dataset...')
            self._lnd = self.get_input_dataset(lnd_name)

    def _initialize_raster(self, arguments):
        """Get the input landuse raster, its nodata value, and its pixel values.

        Args:
            arguments (list): The tool arguments.
        """
        self.logger.info('Retrieving input raster...')
        self._raster = self.get_input_raster(arguments[ARG_INPUT_LANDUSE].text_value)
        self._raster_null_value = self._raster.nodata_value
        self._raster_values = self._raster.get_raster_values()

    def _setup_output_dataset_builder(self, arguments):
        """Set up dataset builders for selected tool outputs.

        Args:
            arguments (list): The tool arguments.
        """
        # Create a place for the output dataset file
        dataset_name = arguments[ARG_OUTPUT_DATASET].text_value
        self._builder = self.get_output_dataset_writer(
            name=dataset_name,
            geom_uuid=self._grid_uuid,
            null_value=self._raster_null_value,
        )

    def _add_output_datasets(self):
        """Add datasets created by the tool to be sent back to XMS."""
        self.logger.info('Adding output dataset...')
        if self._builder is not None:
            self.logger.info('Writing output Chezy friction dataset to XMDF file...')
            # Single time step at t = 0.0
            self._builder.write_xmdf_dataset([0.0], [self._dataset_vals])
            # Send the dataset back to XMS
            self.set_output_dataset(self._builder)

    def _get_raster_value(self, col, row):
        """Get the raster value for a pixel.

        Args:
            col (int): Column of the pixel value
            row (int): Row of the pixel value

        Returns:
            float: The raster value for the pixel or the default Chezy friction value if out of bounds
        """
        # Negative indices would silently wrap to the opposite edge of the raster, so treat them as out of bounds.
        if row < 0 or col < 0:
            return self._default_value
        try:
            return self._raster_values[row][col]
        except (IndexError, TypeError):  # past the far edge, or raster values not loaded yet
            return self._default_value

    def get_friction_info(self, arguments):
        """Gets the Chezy friction information, either built in or from a user CSV file.

        For a user CSV, the first column holds the landuse codes. It must be followed by exactly two columns:
        description and Chezy friction value. The first row is treated as a header.

        Arguments:
            arguments (list): The tool arguments.

        Returns:
            tuple of lists:  The code values, descriptions, and Chezy friction values for each landuse type.
            All three lists are empty for an unrecognized landuse type.

        Raises:
            ValueError: If a user CSV does not have exactly two columns after the code column.
        """
        landuse_type = arguments[ARG_INPUT_LANDUSE_TYPE].text_value
        if landuse_type == 'NLCD':
            return get_nlcd_chezy_friction_values()
        if landuse_type == 'C-CAP':
            return get_ccap_chezy_friction_values()
        if landuse_type == 'Other':
            # The user chose a .csv file.  Read it as a dataframe, and rename the columns.
            mapping_df = pandas.read_csv(arguments[ARG_INPUT_CSV].text_value, index_col=0, header=0)
            if len(mapping_df.columns) != 2:
                raise ValueError(
                    'Chezy friction mapping table must have exactly three columns: landuse code, description, and '
                    f'Chezy friction value (found {len(mapping_df.columns) + 1}).'
                )
            mapping_df.columns = ['Description', 'Friction']
            return mapping_df.index.tolist(), mapping_df['Description'].tolist(), mapping_df['Friction'].tolist()
        return [], [], []
