"""CleanDamsTool class."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
from pathlib import Path
import shutil
from typing import Type
from unittest.mock import MagicMock

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util, TreeNode
from xms.core.filesystem import filesystem
from xms.datasets.dataset_reader import DatasetReader
from xms.datasets.dataset_writer import DatasetWriter
from xms.tool.utilities import xms_utils
from xms.tool_core import (
    ALLOW_ONLY_CELL_MAPPED, ALLOW_ONLY_COVERAGE_TYPE, ALLOW_ONLY_MODEL_NAME, ALLOW_ONLY_SCALARS, Argument, IoDirection,
    Tool
)

# 4. Local modules
from xms.gssha.components import dmi_util
from xms.gssha.components.sim_component import SimComponent
from xms.gssha.data import data_util, sim_generic_model
from xms.gssha.file_io import io_util
from xms.gssha.file_io.io_util import GrassData
from xms.gssha.misc.type_aliases import IntArray
from xms.gssha.tools import tool_util
from xms.gssha.tools.algorithms import clean_dams
from xms.gssha.tools.algorithms.clean_dams import ArgVals

# Constants: indices of the tool arguments, in the same order as the list built by
# CleanDamsTool.initial_arguments().
(
    ARG_INPUT_DATASET,  # 0 - input elevation dataset
    ARG_INPUT_COVERAGE,  # 1 - optional GSSHA Boundary Conditions coverage
    ARG_INPUT_FLAT_AREAS,  # 2 - 'Clean flat areas' flag
    ARG_INPUT_DIGITAL_DAMS,  # 3 - 'Clean digital dams' flag
    ARG_INPUT_FILL_LEFTOVERS,  # 4 - 'Fill leftover problems' flag
    ARG_INPUT_NEW_DATASET_NAME,  # 5 - optional output dataset name
    ARG_OUTPUT_DATASET,  # 6 - hidden output dataset
) = range(7)


class CleanDamsTool(Tool):
    """Tool that creates a cleaned copy of a cell-mapped elevation dataset.

    Depending on the arguments, flat areas and digital dams are cleaned and leftover problems are filled by
    ``clean_dams.run``. The result is read back from a GRASS file and set as the tool's output dataset.
    """
    def __init__(self) -> None:
        """Initializes the class."""
        super().__init__(name='Clean Dams')
        self._dataset_reader: 'DatasetReader | None' = None  # The input dataset (set while validating arguments)
        self._dataset_writer: 'DatasetWriter | None' = None  # The output dataset (saved for testing purposes)
        self._query: 'Query | None' = None  # Only set if the data handler provides a query

        # For testing
        # import os
        # os.environ['XMSTOOL_GUI_TESTING'] = 'YES'

    def initial_arguments(self) -> list[Argument]:
        """Get initial arguments for tool.

        The order of the returned list must match the ARG_* index constants.

        Returns:
            (list): A list of the initial tool arguments.
        """
        dataset_path = self._default_dataset_path()
        default_cov_path = self._default_coverage_path()
        dataset_filters = [ALLOW_ONLY_SCALARS, ALLOW_ONLY_CELL_MAPPED]
        cov_filters = {ALLOW_ONLY_MODEL_NAME: 'GSSHA', ALLOW_ONLY_COVERAGE_TYPE: 'Boundary Conditions'}
        cov_desc = 'Input GSSHA Boundary Conditions coverage'

        arguments = [
            self.dataset_argument('input_dataset', 'Input dataset', value=dataset_path, filters=dataset_filters),
            self.coverage_argument('input_cov', cov_desc, value=default_cov_path, filters=cov_filters, optional=True),
            self.bool_argument('input_flat_areas', 'Clean flat areas', value=True),
            self.bool_argument('input_digital_dams', 'Clean digital dams', value=True),
            self.bool_argument('input_fill_leftovers', 'Fill leftover problems', value=True),
            self.string_argument(name='output_dataset_name', description='Output dataset name', optional=True),
            self.dataset_argument(name='output_dataset', hide=True, optional=True, io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def _default_dataset_path(self) -> str:
        """Returns the default dataset path (the simulation's ELEVATION dataset), or ''.

        Without a query (no data handler providing one) there is nothing to look up, so '' is returned.
        """
        if self._query is None:
            return ''
        dataset_node = _default_elevation_dataset_node(self._query)
        return tool_util.argument_path_from_node(dataset_node)

    def _default_coverage_path(self) -> str:
        """Returns the default Boundary Conditions coverage path, or ''.

        Without a query (no data handler providing one) there is nothing to look up, so '' is returned.
        """
        if self._query is None:
            return ''
        bc_coverage_node = dmi_util.get_default_bc_coverage_node(self._query.project_tree)
        return tool_util.argument_path_from_node(bc_coverage_node)

    def set_data_handler(self, data_handler):
        """Set up query attribute if we have a XMSDataHandler."""
        super().set_data_handler(data_handler)
        if hasattr(self._data_handler, "_query"):
            self._query = self._data_handler._query

    def validate_arguments(self, arguments: list[Type[Argument]]) -> dict[str, str]:
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors: dict[str, str] = {}
        self._validate_dataset(arguments[ARG_INPUT_DATASET], errors)
        return errors

    def _validate_dataset(self, argument: Type[Argument], errors: dict[str, str]) -> bool:
        """Validates the input dataset argument and stores its reader in self._dataset_reader.

        Args:
            argument: The input dataset argument.
            errors: Dictionary of errors keyed by argument name; updated in place.

        Returns:
            True if no error was found, else False.
        """
        self._dataset_reader = self._validate_input_dataset(argument, errors)
        if argument.name in errors:
            return False  # Don't overwrite the more specific error from _validate_input_dataset

        # Make sure it's a UGrid dataset
        if xms_utils.tool_is_running_from_xms(self):  # TODO we would like a better way to do this kind of check
            ug_txt = argument.text_value or ''  # Guard against a missing (None) text value
            if not ug_txt.startswith('UGrid Data'):
                errors[argument.name] = 'Input dataset must be associated with a UGrid'
                return False
        return True

    def _get_argument_values(self, arguments: list[Argument]) -> ArgVals:
        """Gets the tool arguments that we need, packed into an ArgVals."""
        arg_vals = ArgVals()
        arg_vals.coverage = self.get_input_coverage(arguments[ARG_INPUT_COVERAGE].text_value)
        arg_vals.clean_flat_areas = arguments[ARG_INPUT_FLAT_AREAS].value
        arg_vals.clean_digital_dams = arguments[ARG_INPUT_DIGITAL_DAMS].value
        arg_vals.fill_leftover_problems = arguments[ARG_INPUT_FILL_LEFTOVERS].value
        arg_vals.new_dataset_name = arguments[ARG_INPUT_NEW_DATASET_NAME].value
        arg_vals.co_grid = self.get_input_dataset_grid(arguments[ARG_INPUT_DATASET].text_value)
        arg_vals.ugrid = arg_vals.co_grid.ugrid  # So we do this only once
        return arg_vals

    def run(self, arguments: list[Type[Argument]]) -> None:
        """Override to run the tool.

        Relies on validate_arguments having set self._dataset_reader.

        Args:
            arguments: The tool arguments.
        """
        # Run cleandam
        arg_vals = self._get_argument_values(arguments)
        elev_values = self._dataset_reader.values[0]
        out_file_path = clean_dams.run(elev_values, arg_vals, self._query, self.logger)

        # Read output and create a dataset. Always delete the temp dir, even if reading or conversion fails.
        try:
            grass_data = io_util.read_grass_file(out_file_path)
            orig_name = self._dataset_reader.name
            on_off_cells = data_util.get_on_off_cells(arg_vals.co_grid, arg_vals.ugrid)
            dataset_writer = _dataset_from_grass_data(orig_name, arg_vals, grass_data, on_off_cells)
        finally:
            _delete_temp_dir(out_file_path.parent, _make_temp_dir)

        # Set output
        self._dataset_writer = dataset_writer  # Save it for testing purposes
        self.set_output_dataset(self._dataset_writer)


def _make_temp_dir() -> Path:
    """Creates a temporary directory and returns its path.

    The path comes from ``filesystem.temp_filename()`` and is then passed to ``filesystem.make_or_clear_dir``.
    """
    new_dir = Path(filesystem.temp_filename())
    filesystem.make_or_clear_dir(new_dir)
    return new_dir


def _delete_temp_dir(temp_dir: Path, make_temp_dir_method) -> None:
    """Deletes the temporary directory unless we're testing."""
    if not isinstance(make_temp_dir_method, MagicMock):  # Leave this if we're testing
        shutil.rmtree(temp_dir)  # pragma no cover - If we're testing, we won't hit here


def _dataset_from_grass_data(
    orig_name: str, arg_vals: ArgVals, grass_data: GrassData, on_off_cells: IntArray
) -> DatasetWriter:
    """Builds the output dataset from the GRASS file data.

    Args:
        orig_name: Name of the input dataset; used to derive '<orig_name>-cleaned' when no new name is given.
        arg_vals: The tool argument values (new dataset name and co-grid are used).
        grass_data: Data read from the GRASS output file.
        on_off_cells: Cell activity array passed through to io_util.dataset_from_grass_data.

    Returns:
        The dataset writer created by io_util.dataset_from_grass_data.
    """
    dataset_name = arg_vals.new_dataset_name or f'{orig_name}-cleaned'
    grid_uuid = arg_vals.co_grid.uuid
    return io_util.dataset_from_grass_data(dataset_name, grid_uuid, grass_data, on_off_cells)


def _default_elevation_dataset_node(query: 'Query | None') -> 'TreeNode | None':
    """Returns the tree node of the active simulation's ELEVATION dataset, or None.

    Args:
        query: The XMS query; may be None, in which case None is returned.

    Returns:
        The dataset tree node, or None if there is no query or active simulation, the simulation item can't be
        found, no ELEVATION dataset is set, or the dataset isn't in the project tree.
    """
    if query is None:
        return None
    sim_node = dmi_util.get_active_sim_node(query.project_tree)
    if not sim_node:
        return None

    sim_item = query.item_with_uuid(item_uuid=sim_node.uuid, model_name='GSSHA', unique_name='Sim_Manager')
    if sim_item is None:
        return None
    sim_comp = SimComponent(sim_item.main_file)

    # Read the global parameters stored with the simulation to find the selected ELEVATION dataset
    gm = sim_generic_model.create(default_values=False)
    gm.global_parameters.restore_values(sim_comp.data.global_values)
    dataset_uuid = gm.global_parameters.group('overland_flow').parameter('ELEVATION').value
    if not dataset_uuid:
        return None  # No elevation dataset selected
    return tree_util.find_tree_node_by_uuid(query.project_tree, dataset_uuid)
