"""StreamOrienter class."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import logging
from pathlib import Path
import shutil

# 2. Third party modules
from geopandas import GeoDataFrame
from shapely import LineString

# 3. Aquaveo modules
from xms.api.dmi import Query

# 4. Local modules
from xms.gssha.components import dmi_util, gmi_util
from xms.gssha.components.bc_coverage_component import BcCoverageComponent
from xms.gssha.data import bc_util
from xms.gssha.data.bc_util import BcData, feature_and_index_from_id_and_type, NodeArcs


def run(query: Query, bc_cov: GeoDataFrame, new_name: str, logger: logging.Logger | None) -> GeoDataFrame | None:
    """Runs the StreamOrienter on the given BC coverage.

    Convenience wrapper that constructs a StreamOrienter and calls its run() method.

    Args:
        query: Object for communicating with XMS
        bc_cov: BC coverage
        new_name: Name for the new coverage
        logger: The logger messages are sent to; passed straight through to StreamOrienter.

    Returns:
        The result of StreamOrienter.run(): the new coverage, or None.
    """
    return StreamOrienter(query, bc_cov, new_name, logger).run()


class StreamOrienter:
    """Orients the stream arcs of a copy of a BC coverage so they point downstream."""
    def __init__(self, query: Query, bc_cov: GeoDataFrame, new_name: str, logger: logging.Logger | None) -> None:
        """Initializer.

        Args:
            query: Object for communicating with XMS
            bc_cov: BC coverage
            new_name: Name for the new coverage
            logger: The logger progress and warning messages are sent to. If None, this module's logger is used.
        """
        self._query = query
        self._bc_cov = bc_cov
        self._new_name = new_name
        # Every step below logs unconditionally, so never store None here.
        self._logger = logger if logger is not None else logging.getLogger(__name__)

        self._bc_comp: BcCoverageComponent | None = dmi_util.get_bc_coverage_component(
            self._bc_cov.attrs['uuid'], self._query
        ) if (self._bc_cov is not None) else None
        self._new_bc_cov: GeoDataFrame | None = None
        self._new_bc_comp: BcCoverageComponent | None = None
        self._stream_data: BcData | None = None
        self._most_downstream_arc: tuple | None = None
        self._node_arcs: NodeArcs | None = None
        self._changes_made: bool = False

    def run(self) -> GeoDataFrame | None:
        """Copies the coverage, orients the copy's stream arcs downstream, and adds the copy to the query if changed.

        Returns:
            The new coverage if any arcs were reversed (it has been added to the query), otherwise None.
        """
        try:
            self._copy_coverage_and_component()
            self._get_stream_data()
            self._get_most_downstream_arc()
            self._reorient_arcs()
            self._add_to_query_if_changes()
        except RuntimeError as error:
            # Raised without a message to stop early after a warning was already logged; log anything else
            # instead of swallowing it silently.
            if str(error):
                self._logger.error(str(error))
            # The copy was not (successfully) added to the query: delete its files and don't hand it back.
            self._discard_new_coverage()
        return self._new_bc_cov

    def _copy_coverage_and_component(self) -> None:
        """Copies the coverage and the coverage component, and gives the copy the new name."""
        self._logger.info('Copying the coverage...')
        rv = gmi_util.copy_coverage_and_component(self._bc_cov, self._bc_comp, self._query)
        self._new_bc_cov, self._new_bc_comp = rv[0], rv[1]
        self._new_bc_cov.attrs['name'] = self._new_name

    def _get_stream_data(self) -> None:
        """Gets the stream data and the node -> attached arcs connectivity from the copied coverage."""
        self._logger.info('Getting the stream data...')
        self._stream_data = bc_util.get_stream_data(self._query, self._new_bc_cov, self._new_bc_comp)
        _arcs, self._node_arcs = bc_util.stream_arcs_from_bc_data(self._stream_data)

    def _get_most_downstream_arc(self) -> None:
        """Gets the most downstream arc.

        Raises:
            RuntimeError: If bc_util.find_most_downstream_arc returned an error message (a str) instead of an arc;
                the message is logged as a warning first.
        """
        self._most_downstream_arc = bc_util.find_most_downstream_arc(self._stream_data)
        if isinstance(self._most_downstream_arc, str):
            self._logger.warning(self._most_downstream_arc)
            raise RuntimeError()

    def _reorient_arcs(self) -> None:
        """Reorients the arcs of the copied coverage (the one that gets added to the query)."""
        self._logger.info('Reorienting arcs...')
        # Must edit the copy: the stream data and node_arcs were built from it, and it is what gets added to the
        # query. Editing self._bc_cov would mutate the caller's coverage and leave the new one unchanged.
        self._changes_made = _reorient_arcs(self._new_bc_cov, self._most_downstream_arc, self._node_arcs)

    def _add_to_query_if_changes(self) -> None:
        """Add new coverage and component to the query if we made changes, or warn and cleanup if we didn't."""
        if self._changes_made:
            # Add to query
            self._logger.info('Adding coverage...')
            unique_name = 'BcCoverageComponent'
            coverage_type = 'Boundary Conditions'
            gmi_util.add_to_query(self._new_bc_cov, self._new_bc_comp, unique_name, coverage_type, 'GSSHA', self._query)
        else:
            self._logger.warning('All stream arc directions are already pointing downhill. No changes made.')
            self._discard_new_coverage()

    def _discard_new_coverage(self) -> None:
        """Deletes the copied component's directory (if it exists on disk) and forgets the copied coverage."""
        main_file = self._new_bc_comp.main_file if self._new_bc_comp is not None else None
        if main_file and Path(main_file).exists():
            shutil.rmtree(Path(main_file).parent)  # Delete the new cov comp directory
        self._new_bc_cov = None


def _reorient_arcs(coverage: GeoDataFrame, most_downstream_arc: tuple, node_arcs: NodeArcs) -> bool:
    """Reorients the arcs so they point downstream.

    Starting at the most downstream arc, walks upstream through the arc network and reverses every arc whose end
    node is not the node it shares with the arc just downstream of it.

    Args:
        coverage: The coverage. Modified in place: each reversed arc gets its 'geometry' coordinates reversed and
            its 'start_node' / 'end_node' values swapped.
        most_downstream_arc: (arc id, geometry type) of the most downstream arc.
        node_arcs: Node ID -> arcs attached to that node, for arc connectivity. Each arc is a row record exposing
            ``Index`` (its row label in coverage), ``id``, ``geometry_types``, ``start_node`` and ``end_node``.

    Returns:
        True if any arc was reversed.
    """
    changes_made = False

    # Reorient the most downstream arc if necessary: if its start node is a dangling end (only this arc attached)
    # while its end node joins other arcs, it points away from the outlet.
    feature_arc, feature_index = feature_and_index_from_id_and_type(
        coverage, most_downstream_arc[0], most_downstream_arc[1]
    )
    start_arc_count = len(node_arcs[feature_arc.start_node])
    end_arc_count = len(node_arcs[feature_arc.end_node])
    if start_arc_count == 1 and end_arc_count > 1:
        _reverse_arc(coverage, feature_index)
        changes_made = True

    # Reorient the arcs upstream from the selected arc (depth-first). Tracking visited rows guarantees each arc is
    # handled once, so a cycle in the network (e.g. a braided channel) cannot make this loop forever, and an arc's
    # node_arcs record (which is not updated when the arc is reversed) is only consulted before it is reversed.
    visited = {feature_index}
    stack = [most_downstream_arc]
    while stack:
        downstream_arc = stack.pop()
        down_feature_arc, _ = feature_and_index_from_id_and_type(coverage, downstream_arc[0], downstream_arc[1])
        junction_node = down_feature_arc.start_node
        for adj_arc in node_arcs[junction_node]:
            if adj_arc.Index in visited:
                continue  # The downstream arc itself, or an arc already oriented.
            visited.add(adj_arc.Index)

            # An arc upstream of the junction must end at it.
            if adj_arc.end_node != junction_node:
                _reverse_arc(coverage, adj_arc.Index)
                changes_made = True
            stack.append((adj_arc.id, adj_arc.geometry_types))
    return changes_made


def _reverse_arc(coverage: GeoDataFrame, index: object) -> None:
    """Reverses the arc at row label *index*: swaps its start/end nodes and reverses its geometry's coordinates."""
    start_node = coverage.at[index, 'start_node']
    end_node = coverage.at[index, 'end_node']
    coverage.at[index, 'start_node'] = end_node
    coverage.at[index, 'geometry'] = LineString(coverage.at[index, 'geometry'].coords[::-1])
    coverage.at[index, 'end_node'] = start_node
