"""ReorientStreamArcsTool class."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
from xms.tool_core import ALLOW_ONLY_COVERAGE_TYPE, ALLOW_ONLY_MODEL_NAME, Argument, IoDirection, Tool

# 4. Local modules
from xms.gssha.components import dmi_util
from xms.gssha.data import bc_util
from xms.gssha.tools import tool_util
from xms.gssha.tools.algorithms import stream_orienter

# Constants
# Indices into the argument list returned by ReorientStreamArcsTool.initial_arguments(); they must match its order.
ARG_INPUT_COVERAGE = 0
ARG_OUTPUT_COVERAGE_NAME = 1  # The optional 'output_cov_name' string argument
# Kept for backward compatibility; misleading name - index 1 is the *output* coverage name, not an input one.
ARG_INPUT_COVERAGE_NAME = ARG_OUTPUT_COVERAGE_NAME
ARG_OUTPUT_COVERAGE = 2


class ReorientStreamArcsTool(Tool):
    """Tool to create a new coverage in which the stream arc directions are such that they point downhill."""

    def __init__(self) -> None:
        """Initializes the class."""
        super().__init__(name='Reorient Stream Arcs')
        # Only assigned by set_data_handler() when the data handler exposes a query; may stay None.
        self._query = None

        # For testing
        # import os
        # os.environ['XMSTOOL_GUI_TESTING'] = 'YES'

    def initial_arguments(self) -> list[Argument]:
        """Get initial arguments for tool.

        Must override.

        When a query is available, the input coverage argument defaults to the node returned by
        dmi_util.get_default_bc_coverage_node(); otherwise it has no default value.

        Returns:
            (list): A list of the initial tool arguments, in the order given by the ARG_* index constants.
        """
        cov_filters = {ALLOW_ONLY_MODEL_NAME: 'GSSHA', ALLOW_ONLY_COVERAGE_TYPE: 'Boundary Conditions'}

        # set_data_handler() leaves self._query as None for data handlers without a query, so don't
        # dereference it unconditionally.
        default_cov_path = None
        if self._query is not None:
            bc_coverage_node = dmi_util.get_default_bc_coverage_node(self._query.project_tree)
            default_cov_path = tool_util.argument_path_from_node(bc_coverage_node)

        cov_desc = 'Input GSSHA Boundary Conditions coverage'
        arguments = [
            self.coverage_argument(name='input_cov', description=cov_desc, value=default_cov_path, filters=cov_filters),
            self.string_argument(name='output_cov_name', description='Output coverage name', optional=True),
            self.coverage_argument(name='output_cov', hide=True, optional=True, io_direction=IoDirection.OUTPUT),
        ]
        return arguments

    def set_data_handler(self, data_handler) -> None:
        """Set up query attribute if we have a XMSDataHandler."""
        super().set_data_handler(data_handler)
        if hasattr(self._data_handler, "_query"):
            self._query = self._data_handler._query

    def validate_arguments(self, arguments: list[Argument]) -> dict[str, str]:
        """Called to determine if arguments are valid.

        Checks that the input coverage could be retrieved and that bc_util.find_most_downstream_arc() accepts
        its stream data (the coverage must have one and only one downstream arc).

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name. Empty when everything is valid.
        """
        errors: dict[str, str] = {}
        input_arg = arguments[ARG_INPUT_COVERAGE]

        # Use an identity check: the coverage object has dataframe-like 'attrs', so its truthiness may be ambiguous.
        coverage = self.get_input_coverage(input_arg.value)
        if coverage is None:
            errors[input_arg.name] = 'Could not read the input coverage.'
            return errors

        # Make sure coverage has one and only one downstream arc. A str result is an error message, not an arc.
        stream_data = bc_util.get_stream_data(self._query, coverage)
        rv = bc_util.find_most_downstream_arc(stream_data)
        if isinstance(rv, str):
            errors[input_arg.name] = rv
        return errors

    def run(self, arguments: list[Argument]) -> None:
        """Override to run the tool.

        Passes the input coverage and the output coverage name to stream_orienter.run(), which builds and adds
        the new coverage.

        Args:
            arguments: The tool arguments.
        """
        # Get input argument values
        coverage = self.get_input_coverage(arguments[ARG_INPUT_COVERAGE].value)
        # Despite the constant's name, index ARG_INPUT_COVERAGE_NAME holds the optional *output* coverage name.
        output_name = arguments[ARG_INPUT_COVERAGE_NAME].value

        # Build the coverage, falling back to the input coverage's name when no output name was given.
        new_name = output_name if output_name else coverage.attrs['name']
        stream_orienter.run(self._query, coverage, new_name, self.logger)

        # Don't need to call self.set_output_coverage() because coverage was added by stream_orienter.run()
