"""Functions dealing with boundary conditions."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import copy
from dataclasses import dataclass, field
import logging
from typing import NamedTuple, Sequence

# 2. Third party modules
from geopandas import GeoDataFrame

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.data_objects.parameters import Arc
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.gssha.components.bc_coverage_component import BcCoverageComponent
from xms.gssha.data import bc_generic_model
from xms.gssha.data.bc_coverage_data import ArcLinks

# Type aliases
FeatureBc = dict[tuple[int, str], 'Group']  # (feature id, geometry type) -> Generic Model Group
NodeArcs = dict[int, list[NamedTuple]]  # Node id -> coverage rows (itertuples() NamedTuples) of the adjacent arcs


@dataclass()
class BcData:
    """Data about the bcs.

    The group in feature_bc can be used for reading the parameter values, but not to change them.

    node_arcs and arc_links are only for stream networks. get_bc_data fills arc_links from the component data; it may
    also be computed and set in project_writer.
    """
    feature_bc: FeatureBc = field(default_factory=dict)  # (feature id, geometry type) -> Generic Model Group
    node_arcs: NodeArcs = field(default_factory=dict)  # Node id -> coverage rows of the adjacent arcs
    arc_links: ArcLinks = field(default_factory=dict)  # arc.id -> link number
    coverage: GeoDataFrame | None = None  # Coverage GeoDataFrame the features come from

    def feature_from_id(self, feature_id: int, feature_type: str) -> NamedTuple:
        """Return the coverage row (NamedTuple from itertuples()) for a feature ID and geometry type.

        Args:
            feature_id: The feature ID (matched against the coverage 'id' column).
            feature_type: The geometry type (matched against the coverage 'geometry_types' column), e.g. 'Arc'.

        Returns:
            The first matching coverage row.

        Raises:
            ValueError: If coverage is not set or has no matching feature.
        """
        if self.coverage is None:
            raise ValueError('BcData.coverage is not set.')
        result = feature_and_index_from_id_and_type(self.coverage, feature_id, feature_type)
        if result is None:
            raise ValueError(f'Feature {feature_id} of type "{feature_type}" not found in the coverage.')
        return result[0]


def get_bc_data(coverage: GeoDataFrame, bc_comp: BcCoverageComponent, bc_types: set[str]) -> BcData:
    """Collects the BC data of the coverage features whose BC type is one of bc_types.

    Arc data (plus the component's arc links) is gathered when 'channel' or 'overland_flow' is requested. Point data
    is gathered only for 'overland_flow'. query.load_component_ids() must have been called before.

    Args:
        coverage: BC coverage
        bc_comp: The coverage component
        bc_types: The GMI group names (e.g. 'channel')

    Returns:
        The collected BcData, with its coverage set to the given coverage.
    """
    logging.getLogger('xms.gssha').info('Getting stream data...')
    result = BcData(coverage=coverage)
    wants_overland = 'overland_flow' in bc_types
    if 'channel' in bc_types or wants_overland:
        _get_data(coverage, bc_comp, bc_types, TargetType.arc, result)
        result.arc_links = bc_comp.data.get_arc_links()
    if wants_overland:
        _get_data(coverage, bc_comp, bc_types, TargetType.point, result)
    return result


def get_stream_data(query: Query, coverage: GeoDataFrame, bc_comp: BcCoverageComponent | None = None) -> BcData:
    """Returns the stream ('channel') BC data for the coverage.

    Args:
        query: Object for communicating with XMS.
        coverage: BC coverage. Its attrs['uuid'] is used to look up the component when bc_comp is not given.
        bc_comp: The coverage component, or None to look it up through the query.

    Returns:
        The BcData for the coverage's 'channel' arcs.
    """
    # Compare against None explicitly; only a missing component should trigger the lookup
    if bc_comp is None:
        do_comp = query.item_with_uuid(coverage.attrs['uuid'], model_name='GSSHA', unique_name='BcCoverageComponent')
        bc_comp = BcCoverageComponent(do_comp.main_file)
    query.load_component_ids(bc_comp, points=False, arcs=True, polygons=False)
    return get_bc_data(coverage, bc_comp, {'channel'})


def get_stream_arcs(query: Query, coverage: GeoDataFrame, bc_comp: BcCoverageComponent) -> tuple[list[Arc], NodeArcs]:
    """Loads the arc component IDs, then returns the coverage's channel arcs together with their connectivity.

    Args:
        query: Object for communicating with XMS
        coverage: BC coverage
        bc_comp: The coverage component.

    Returns:
        Tuple of (stream arc keys as (feature id, geometry type) pairs, dict of node id -> adjacent arcs).
    """
    query.load_component_ids(bc_comp, arcs=True)
    return stream_arcs_from_bc_data(get_bc_data(coverage, bc_comp, {'channel'}))


def find_most_downstream_arc(stream_data: BcData) -> 'tuple[int, str] | str':
    """Finds and returns the most downstream stream arc, or returns an error string if there were problems.

    Exactly one stream arc must have its 'most_downstream_arc' parameter set. That arc must also have no other
    stream arcs attached to at least one of its ends.

    Args:
        stream_data: Data about the streams.

    Returns:
        The (feature id, geometry type) key of the most downstream stream arc, or an error message.
    """
    arcs, node_arcs = stream_arcs_from_bc_data(stream_data)
    if not arcs:
        return 'No stream arcs found.'

    most_downstream_arcs = _most_downstream_arcs_from_bc_data(stream_data)
    if not most_downstream_arcs:
        return 'No stream arcs have been set as the most downstream arc.'
    if len(most_downstream_arcs) > 1:
        return 'Multiple stream arcs have been set as the most downstream arc. There can be only one.'

    most_downstream_arc = most_downstream_arcs[0]
    # An arc with adjacent stream arcs on both ends cannot be the most downstream one
    if not _is_end_arc(most_downstream_arc, node_arcs, stream_data.coverage):
        return (
            'The arc set as the most downstream arc has adjacent stream arcs on both ends (and is therefore '
            'not the most downstream).'
        )
    return most_downstream_arc


def compute_link_numbers(stream_data: BcData) -> 'ArcLinks | str':
    """Computes and returns stream link numbers, or returns an error string if there was a problem.

    The walk starts at the most downstream arc and moves upstream through each arc's start node. Every arc is
    numbered higher than all arcs upstream of it. Arcs that cannot be reached this way get no link number.
    Assumes arc directions are all pointing downhill.

    Args:
        stream_data: Data about the streams.

    Returns:
        Dict of arc.id -> link number or an error string.
    """
    rv = find_most_downstream_arc(stream_data)
    if isinstance(rv, str):
        return rv

    arcs, node_arcs = stream_arcs_from_bc_data(stream_data)
    arc_links: ArcLinks = {}
    next_link = len(arcs)
    visited: set[tuple] = set()  # (feature id, geometry type) keys of arcs already numbered
    # Use a stack to ensure that downstream link numbers are higher than upstream ones: an arc is numbered when
    # popped, and the arcs upstream of it are only pushed afterwards.
    arc_stack = [rv]
    while arc_stack:
        arc = arc_stack.pop()
        if arc in visited:
            # In a downhill-pointing tree no arc is reached twice. A revisit means a loop, a split, or a reversed
            # arc; without this check the walk would never terminate.
            return (
                'The stream arcs could not be numbered because the stream network contains a loop or a split, '
                'or an arc is not pointing downstream.'
            )
        visited.add(arc)
        feature_arc = stream_data.feature_from_id(arc[0], arc[1])
        arc_links[feature_arc.id] = next_link
        next_link -= 1
        for adj_arc in node_arcs[feature_arc.start_node]:
            if adj_arc.id != feature_arc.id:
                arc_stack.append((adj_arc.id, adj_arc.geometry_types))
    return arc_links


def stream_arcs_from_bc_data(stream_data: BcData) -> tuple[list[Arc], NodeArcs]:
    """Returns the stream arcs held in stream_data along with their node connectivity.

    Args:
        stream_data: Data about the streams.

    Returns:
        Tuple of (list of the feature_bc keys, i.e. (feature id, geometry type) pairs, and stream_data.node_arcs).
    """
    arc_keys = list(stream_data.feature_bc)
    return arc_keys, stream_data.node_arcs


def _most_downstream_arcs_from_bc_data(stream_data: BcData) -> list[Arc]:
    """Returns the feature_bc keys of every arc whose 'most_downstream_arc' parameter value is truthy."""
    return [
        key for key, group in stream_data.feature_bc.items()
        if group.parameter('most_downstream_arc').value
    ]


def _is_end_arc(arc: tuple, node_arcs: NodeArcs, coverage: GeoDataFrame) -> bool:
    """Returns true if the arc has no adjacent arcs on at least one end.

    Args:
        arc: The (feature id, geometry type) key of the arc.
        node_arcs: Dict of node id -> list of the arcs attached to that node.
        coverage: Coverage GeoDataFrame used to look up the arc's start and end nodes.

    Returns:
        True if only this arc is attached to its start node or to its end node.
    """
    row, _index = feature_and_index_from_id_and_type(coverage, arc[0], arc[1])
    # Both node lookups are always performed before the comparison
    counts = [len(node_arcs[node]) for node in (row.start_node, row.end_node)]
    return 1 in counts


def _features_from_type(coverage: GeoDataFrame, feature_type: TargetType):
    """Returns the coverage rows whose 'geometry_types' matches the given target type.

    Raises:
        ValueError: If feature_type is not a point, arc, or polygon target type.
    """
    geometry_names = (
        (TargetType.point, 'Point'),
        (TargetType.arc, 'Arc'),
        (TargetType.polygon, 'Polygon'),
    )
    for target, geometry_name in geometry_names:
        if feature_type == target:
            return coverage[coverage['geometry_types'] == geometry_name]
    raise ValueError(f'Unsupported feature type: "{str(feature_type)}".')


def _get_data(
    coverage: GeoDataFrame, bc_comp: BcCoverageComponent, bc_types: set[str], feature_type: TargetType, bc_data: BcData
) -> None:
    """Stores the BC data of every feature_type feature whose BC type is in bc_types into bc_data.feature_bc.

    Features without a valid (non-negative) component ID are skipped. When feature_type is an arc,
    bc_data.node_arcs is also rebuilt from the keys in bc_data.feature_bc.
    """
    rows = _features_from_type(coverage, feature_type)
    model = bc_generic_model.create()
    section = model.section_from_target_type(feature_type)  # point_parameters etc

    for row in rows.itertuples():
        comp_id = bc_comp.get_comp_id(feature_type, row.id)
        if comp_id is None or comp_id < 0:
            continue
        att_type, values = bc_comp.data.feature_type_values(feature_type, comp_id)
        if att_type in bc_types:
            section.restore_values(values)
            # The section is reused for every feature, so keep a private copy of this feature's group
            bc_data.feature_bc[row.id, row.geometry_types] = copy.deepcopy(section.group(att_type))

    if feature_type == TargetType.arc:
        bc_data.node_arcs = build_node_arcs(coverage, list(bc_data.feature_bc))


def feature_and_index_from_id_and_type(coverage: GeoDataFrame, feature_id: int,
                                       feature_type: str) -> tuple[NamedTuple, int] | None:
    """Returns the feature information and its index from the coverage, feature ID, and the feature type.

    Args:
        coverage: The coverage as a GeoDataFrame.
        feature_id: The feature ID, matched against the 'id' column.
        feature_type: The geometry type, matched against the 'geometry_types' column (e.g. 'Arc').

    Returns:
        tuple: The first matching row (a NamedTuple from itertuples()) and its GeoDataFrame index, or None if no
        row matches.
    """
    # Filter with a vectorized mask rather than building a NamedTuple for every row; this is called inside loops.
    mask = (coverage['id'] == feature_id) & (coverage['geometry_types'] == feature_type)
    # A plain boolean array selects rows positionally, independent of index labels
    for feature in coverage[mask.to_numpy()].itertuples():
        return feature, feature.Index
    return None


def build_node_arcs(coverage: GeoDataFrame, arcs: Sequence[tuple]) -> NodeArcs:
    """Builds the NodeArcs dict which holds arc connectivity info.

    Args:
        coverage: The coverage as a GeoDataFrame. Rows are matched on 'id' and 'geometry_types'; arc rows supply
            'start_node' and 'end_node'.
        arcs: (feature id, geometry type) pairs identifying the arcs to include.

    Returns:
        Dict of node id -> list of coverage rows (NamedTuples from itertuples()) for the arcs touching that node,
        in the order the arcs were given.

    Raises:
        ValueError: If one of the arcs is not in the coverage.
    """
    node_arcs: NodeArcs = {}
    arc_keys = list(arcs)
    if not arc_keys:
        return node_arcs

    # Index the coverage once, so each arc lookup is O(1) instead of a scan of the whole coverage per arc.
    # setdefault keeps the first matching row, as a front-to-back scan would.
    rows_by_key = {}
    for row in coverage.itertuples():
        rows_by_key.setdefault((row.id, row.geometry_types), row)

    for feature_id, feature_type in arc_keys:
        arc = rows_by_key.get((feature_id, feature_type))
        if arc is None:
            raise ValueError(f'Arc {feature_id} ({feature_type}) not found in the coverage.')
        node_arcs.setdefault(arc.start_node, []).append(arc)
        node_arcs.setdefault(arc.end_node, []).append(arc)
    return node_arcs
