"""MapToWashBcsTool class."""

__copyright__ = "(C) Copyright Aquaveo 2023"
__license__ = "All rights reserved"

# 1. Standard Python modules
from pathlib import Path

# 2. Third party modules

# 3. Aquaveo modules
from xms.coverage import coverage_data
from xms.guipy.settings import get_file_browser_directory
from xms.tool_core import ALLOW_ONLY_CELL_MAPPED, ALLOW_ONLY_SCALARS, IoDirection, Tool

# 4. Local modules
from xms.wash.tools.wash_mapper import WashMapper

# Positions of each tool argument in the list built by MapToWashBcsTool.initial_arguments();
# the order here must match the order of the arguments in that list.
(
    ARG_INPUT_COVERAGE,
    ARG_INPUT_RW_COVERAGE,
    ARG_INPUT_GRID,
    ARG_INPUT_MATERIAL_DATASET,
    ARG_OUTPUT_FILE,
    ARG_MAX_SAT_K,
    ARG_MIN_RESIST_COEFF,
    ARG_NODE_POINT_TOL,
) = range(8)


class MapToWashBcsTool(Tool):
    """Tool to map a coverage to a UGrid and create a WASH123D 3bc file."""
    def __init__(self) -> None:
        """Initializes the class."""
        super().__init__(name='Map to WASH123D BCs')
        # Query borrowed from the data handler in set_data_handler(); stays None if the handler has none.
        self._query = None
        # For testing
        # import os
        # os.environ['XMSTOOL_GUI_TESTING'] = 'YES'

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override. The order of the returned list must match the ARG_* index constants.

        Returns:
            (list): A list of the initial tool arguments.
        """
        start_dir = Path(get_file_browser_directory())
        arguments = [
            self.coverage_argument(name='input_coverage', description='Input FEMWATER coverage', optional=True),
            self.coverage_argument(name='input_rw_coverage', description='Input Relief Well coverage', optional=True),
            self.grid_argument(name='input_grid', description='Input UGrid'),
            self.dataset_argument(
                name='input_material_dataset',
                description='Input material dataset',
                io_direction=IoDirection.INPUT,
                optional=True,
                filters=[ALLOW_ONLY_SCALARS, ALLOW_ONLY_CELL_MAPPED]
            ),
            # NOTE(review): this is the output file path, yet it is declared with IoDirection.INPUT; presumably
            # so the framework treats it as a plain user-entered path - confirm this is intentional.
            self.string_argument(
                name='output_3bc_file',
                description='Output 3bc file',
                io_direction=IoDirection.INPUT,
                optional=True,
                file=True,
                value=str(start_dir / 'model.3bc')
            ),
            self.float_argument(
                name='max_sat_k',
                description='Maximum Saturated Hydraulic Conductivity (1D Steady Flow In Well)',
                value=1e8
            ),
            self.float_argument(
                name='min_well_resist_coeff', description='Minimum Well Screen Resistance Coefficient', value=1e-10
            ),
            self.float_argument(name='node_point_tolerance', description='Node-Point Location Tolerance', value=1e-4)
        ]
        return arguments

    def set_data_handler(self, data_handler):
        """Set up query attribute if we have a XMSDataHandler.

        Args:
            data_handler: The data handler; if it exposes a ``_query`` attribute, that query is kept for run().
        """
        super().set_data_handler(data_handler)
        if hasattr(self._data_handler, "_query"):
            self._query = self._data_handler._query

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        At least one of the two coverages must be given; otherwise validation stops with that single error.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name.
        """
        errors: dict[str, str] = {}
        self._validate_input_coverage(errors, arguments[ARG_INPUT_COVERAGE])
        self._validate_input_coverage(errors, arguments[ARG_INPUT_RW_COVERAGE])
        if arguments[ARG_INPUT_COVERAGE].value is None and arguments[ARG_INPUT_RW_COVERAGE].value is None:
            errors[arguments[ARG_INPUT_COVERAGE].name] = "At least one coverage must be specified."
            return errors
        self._validate_input_grid(errors, arguments[ARG_INPUT_GRID])
        self._validate_output_file(errors, arguments[ARG_OUTPUT_FILE])
        return errors

    def _validate_input_coverage(self, errors, argument):
        """Validate that a selected coverage can be retrieved.

        An unset (empty) coverage argument is accepted here; whether at least one coverage was chosen is
        checked by validate_arguments().

        Args:
            errors (dict): Dictionary of errors keyed by argument name.
            argument (CoverageArgument): The coverage argument.
        """
        if not argument.text_value:
            return

        coverage = self.get_input_coverage(argument.text_value)
        if coverage is None:
            errors[argument.name] = 'Could not get coverage.'

    def _validate_input_grid(self, errors, argument):
        """Validate grid can be read, has cells, and that all its cells are 3D.

        Only the first problem found is reported.

        Args:
            errors (dict): Dictionary of errors keyed by argument name.
            argument (GridArgument): The grid argument.
        """
        if argument.value is None:
            return

        key = argument.name
        co_grid = self.get_input_grid(argument.text_value)
        if not co_grid:
            errors[key] = 'Could not read grid.'
            return

        if co_grid.ugrid.cell_count <= 0:
            # Stop here so the 3D check below cannot overwrite this more specific message.
            errors[key] = 'Grid has no cells.'
            return

        if not co_grid.check_all_cells_3d():
            errors[key] = 'Grid cells must all be 3D.'

    def _validate_output_file(self, errors, argument):
        """Validate that we can write to the file.

        Args:
            errors (dict): Dictionary of errors keyed by argument name.
            argument (StringArgument): The string argument.
        """
        if argument.value is None:
            return

        key = argument.name
        output_str = argument.text_value
        if output_str == '':
            errors[key] = 'You must specify an output filename.'
            return

        filepath = Path(output_str)
        if not self._is_filepath_writable(filepath):
            errors[key] = f'Cannot write to {filepath}.'

    def _is_filepath_writable(self, filepath):
        """Returns True if we can write to filepath.

        The probe opens the file in append mode so an existing file is NOT truncated during validation
        (validation may run without the tool ever being executed). If the probe had to create the file,
        the empty file is removed again.

        Args:
            filepath (Path): The file to test.

        Returns:
            (bool): True if the file could be opened for writing.
        """
        try:
            existed = filepath.exists()
            with filepath.open('a'):
                pass
        except (OSError, ValueError):  # ValueError: e.g. an embedded null character in the path.
            return False
        if not existed:
            try:
                filepath.unlink()
            except OSError:
                pass  # Best effort only; an empty leftover file does not affect the result.
        return True

    def run(self, arguments):
        """Override to run the tool.

        Collects the selected inputs and option values and passes them to WashMapper, which does the mapping
        and writes the output file.

        Args:
            arguments (list): The tool arguments, indexed by the ARG_* constants.
        """
        # FEMWATER coverage (optional): the mapper receives the coverage data read from the coverage's file.
        att_table_coverage = None
        if arguments[ARG_INPUT_COVERAGE].text_value:
            coverage_file = self.get_input_coverage_file(arguments[ARG_INPUT_COVERAGE].value)
            att_table_coverage = coverage_data.get_coverage_data(coverage_file)

        # Relief well coverage (optional): the mapper receives the coverage itself.
        bc_coverage = None
        if arguments[ARG_INPUT_RW_COVERAGE].text_value:
            bc_coverage = self.get_input_coverage(arguments[ARG_INPUT_RW_COVERAGE].value)

        co_grid = self.get_input_grid(arguments[ARG_INPUT_GRID].text_value)

        # Material dataset is optional; only look it up when one was selected.
        material_dataset = None
        if arguments[ARG_INPUT_MATERIAL_DATASET].text_value:
            material_dataset = self.get_input_dataset(arguments[ARG_INPUT_MATERIAL_DATASET].value)

        filepath = Path(arguments[ARG_OUTPUT_FILE].text_value)
        max_sat_k = arguments[ARG_MAX_SAT_K].value
        min_resist_coeff = arguments[ARG_MIN_RESIST_COEFF].value
        node_point_tolerance = arguments[ARG_NODE_POINT_TOL].value

        mapper = WashMapper(
            att_table_coverage,
            bc_coverage,
            co_grid,
            material_dataset,
            filepath,
            max_sat_k,
            min_resist_coeff,
            node_point_tolerance,
            self.logger,
            query=self._query
        )
        mapper.map()
