"""SWMM coverage components and helpers that collect their node and link groups."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'
__all__ = ['CoverageComponent', 'get_node_groups', 'get_link_groups', 'StormDrainLinkComponent',
           'StormDrainNodeComponent']

# 1. Standard Python modules
import copy
from functools import cached_property
from pathlib import Path
from typing import Optional

# 2. Third party modules
from PySide2.QtWidgets import QWidget

# 3. Aquaveo modules
from xms.components.display.display_options_helper import MULTIPLE_TYPES, UNASSIGNED_TYPE
from xms.gmi.components2.gmi_coverage_component_base import GmiCoverageComponentBase
from xms.gmi.data.generic_model import GenericModel, Group, Section
from xms.gmi.gui.section_dialog import SectionDialog
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.swmm.data.coverage_data import CoverageData
from xms.swmm.data.model import get_swmm_model


class CoverageComponent(GmiCoverageComponentBase):
    """A Dynamic Model Interface (DMI) component base for the SWMM model."""

    # Groups that carry a per-feature 'name' parameter, per target type. When several features are assigned at
    # once, each feature keeps its own name in whichever of these groups is being assigned.
    _NAMED_GROUPS: dict[TargetType, tuple[str, ...]] = {
        TargetType.point: ('junction', 'outfall'),
        TargetType.arc: ('conduit',),
    }

    def __init__(self, main_file: Optional[str | Path]):
        """
        Initialize the component.

        Args:
            main_file: The component's main-file.
        """
        super().__init__(main_file)
        # self._section_dialog_keyword_args |= {'enable_unchecked_groups': True, 'hide_checkboxes': True}

    @property
    def features(self) -> dict[TargetType, list[tuple[str, str]]]:
        """
        The features this coverage supports.

        Returns:
            For points and arcs, the ``(group_name, label)`` pairs of every group in the SWMM model section for
            that target type.
        """
        model = get_swmm_model()
        feature_dict = {}
        for target in (TargetType.point, TargetType.arc):
            section = model.section_from_target_type(target)
            feature_dict[target] = [
                (group_name, section.group(group_name).label) for group_name in section.group_names
            ]
        return feature_dict

    def _assign_feature(
            self, parent: QWidget, dialog_name: str, window_title: str, target: TargetType, feature_ids: list[int]
    ):
        """
        Display the Assign feature dialog and persist data if accepted.

        The dialog is initialized from the first selected feature's stored values, if it has any. When several
        features are selected, the dialog's values are applied to each of them. Each feature still keeps its own
        name (falling back to ``<group>_<feature id>`` when it has none), so the assignment does not give every
        feature the same name.

        Args:
            parent: Parent widget for any dialog windows created.
            dialog_name: Suggested name for the dialog. Used to store dialog settings in the registry.
            window_title: Title of the dialog window.
            target: Point, arc, or polygon.
            feature_ids: IDs of selected features. Nothing happens if this is empty.
        """
        if not feature_ids:
            return  # Nothing selected; feature_ids[0] below would raise IndexError.

        section = self._get_section(target)
        component_id = self.get_comp_id(target, feature_ids[0])
        values = self.data.feature_values(target, component_id)
        if values:
            section.restore_values(values)

        dlg = SectionDialog(
            parent=parent,
            section=section,
            is_interior=False,
            dlg_name=dialog_name,
            window_title=window_title,
            **self._section_dialog_keyword_args,
        )
        if not dlg.exec():
            return

        # Setting a parameter value does not change which group is active, so this is computed once.
        active_group = dlg.section.active_group_name(UNASSIGNED_TYPE, MULTIPLE_TYPES)
        if len(feature_ids) > 1:
            named_groups = self._NAMED_GROUPS.get(target, ())
            for feature_id in feature_ids:
                previous_group, previous_name = self._previous_name(section, target, feature_id)
                # Put the name on the group actually being assigned. Fall back to the feature's previous group
                # when the dialog did not activate exactly one named group.
                name_group = active_group if active_group in named_groups else previous_group
                if name_group:
                    name = previous_name or f'{name_group}_{feature_id}'
                    dlg.section.group(name_group).parameter('name').value = name
                values = dlg.section.extract_values()
                component_id = self.data.add_feature(target, values, active_group)
                self.assign_feature_ids(target, active_group, [feature_id], component_id)
        else:
            values = dlg.section.extract_values()
            component_id = self.data.add_feature(target, values, active_group)
            self.assign_feature_ids(target, active_group, feature_ids, component_id)
        # This effectively drops nothing, but somehow it makes xarray realize that the values column really
        # isn't a float column and fixes a crash.
        ids = list(self.data.component_ids(target))
        self.data.drop_unused_features(target, ids)
        self.data.commit()

    def _previous_name(self, section: Section, target: TargetType, feature_id: int) -> tuple[str, str]:
        """
        Find the named group and the name a feature was previously assigned.

        Args:
            section: Section that is copied to serve as scratch space for the feature's stored values.
            target: Point, arc, or polygon.
            feature_id: ID of the feature.

        Returns:
            ``(group_name, name)``:

            - ``group_name`` is the first active group in ``_NAMED_GROUPS[target]``. If none is active, it is the
              last candidate. It is ``''`` if the target has no named groups.
            - ``name`` is ``''`` when the feature has no stored values or no name.
        """
        group_names = self._NAMED_GROUPS.get(target, ())
        if not group_names:
            return '', ''
        component_id = self.get_comp_id(target, feature_id)
        values = self.data.feature_values(target, component_id)
        temp_section = section.copy()
        if values:
            temp_section.restore_values(values)
        group_name = next((name for name in group_names if temp_section.group(name).is_active), group_names[-1])
        if not values:
            # Without stored values the copy only reflects whatever `section` currently holds (dialog edits or
            # another feature's name), so its name must not be reused for this feature.
            return group_name, ''
        return group_name, temp_section.group(group_name).parameter('name').value or ''

    def _get_section(self, target: TargetType) -> Section:
        """
        Get the SWMM model section for a target type.

        Args:
            target: Point, arc, or polygon.

        Returns:
            The section returned by the SWMM model's ``section_from_target_type``.
        """
        return get_swmm_model().section_from_target_type(target)


class StormDrainLinkComponent(CoverageComponent):
    """DMI coverage component for SWMM storm drain links; uses the CoverageComponent behavior unchanged."""


class StormDrainNodeComponent(CoverageComponent):
    """DMI coverage component for SWMM storm drain nodes, backed by the SWMM CoverageData class."""

    @cached_property
    def data(self) -> CoverageData:
        """SWMM coverage data manager for this component's main file, created on first access and then cached."""
        main_file = self.main_file
        return CoverageData(main_file)


def _group_from_id(coverage_component: CoverageComponent, coverage_parameters: GenericModel, target_type: TargetType,
                   feature_id: int, group_name: str) -> Group:
    """
    Get a copy of one group from a feature's stored parameters.

    Args:
        coverage_component: Component that stores the feature's values.
        coverage_parameters: Parameters to restore the feature's values into. Callers in this module pass a
            ``Section``. This object is mutated and is left holding the feature's values.
        target_type: Target type of the feature (callers in this module pass ``TargetType.point`` or
            ``TargetType.arc``).
        feature_id: ID of the feature.
        group_name: Name of the group to return.

    Returns:
        A deep copy of the requested group, so later restores into ``coverage_parameters`` do not change it.

    NOTE(review): when the feature has no stored values, what remains in ``coverage_parameters`` depends on how
    ``restore_values`` handles empty input. Confirm it does not keep a previously restored feature's values.
    """
    component_id = coverage_component.get_comp_id(target_type, feature_id)
    values = coverage_component.data.feature_values(target_type, component_id)
    coverage_parameters.restore_values(values)
    # Deep copy: coverage_parameters is reused by callers for the next feature.
    return copy.deepcopy(coverage_parameters.group(group_name))


def get_node_groups(points: list, node_component: StormDrainNodeComponent, node_coverage_parameters: Section) -> dict:
    """
    Gets the node group, group name, and point ID info from a list of points.

    Each point is checked as a junction first and, failing that, as an outfall. Points where neither group is
    active are skipped. Spaces in node names are replaced by underscores to form the dictionary keys.

    Args:
        points: Points to inspect; each must expose an ``id`` attribute.
        node_component: Component holding the node attributes.
        node_coverage_parameters: Section that each point's values are restored into (it is mutated).

    Returns:
        Dict mapping node name to a list of ``(group, group_name, point_id)`` tuples.
    """
    nodes_by_name = {}
    for point in points:
        for group_name in ('junction', 'outfall'):
            group = _group_from_id(node_component, node_coverage_parameters, TargetType.point, point.id, group_name)
            if group.is_active:
                break
        else:
            continue  # Neither a junction nor an outfall.
        key = group.parameter('name').value.replace(' ', '_')
        nodes_by_name.setdefault(key, []).append((group, group_name, point.id))
    return nodes_by_name


def get_link_groups(arcs: list, link_component: StormDrainLinkComponent, link_coverage_parameters: Section) -> dict:
    """
    Gets the link group, group name, and arc ID info from a list of arcs.

    Only arcs whose 'conduit' group is active are included. Spaces in link names are replaced by underscores to
    form the dictionary keys.

    Args:
        arcs: Arcs to inspect; each must expose an ``id`` attribute.
        link_component: Component holding the link attributes.
        link_coverage_parameters: Section that each arc's values are restored into (it is mutated).

    Returns:
        Dict mapping link name to a list of ``(group, group_name, arc_id)`` tuples.
    """
    group_name = 'conduit'
    links_by_name = {}
    for arc in arcs:
        group = _group_from_id(link_component, link_coverage_parameters, TargetType.arc, arc.id, group_name)
        if not group.is_active:
            continue
        key = group.parameter('name').value.replace(' ', '_')
        links_by_name.setdefault(key, []).append((group, group_name, arc.id))
    return links_by_name
