"""Handles all the model global and arc data."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"

# 1. Standard Python modules
from typing import Type

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util
from xms.data_objects.parameters import Coverage, FilterLocation
from xms.gmi.data.generic_model import GenericModel
from xms.guipy.data.target_type import TargetType
from xms.srh.components.bc_component import BcComponent
from xms.srh.components.monitor_component import MonitorComponent
from xms.srh.components.srh_cov_component import SrhCoverageComponent
from xms.swmm.components.coverage_component import get_node_groups, StormDrainNodeComponent
from xms.swmm.data.model import get_swmm_model
from xms.swmm.dmi.xms_data import SwmmData
from xms.tool_core.table_definition import StringColumnType, TableDefinition

# 4. Local modules
from xms.srh_swmm_coupler.dmi.xms_data import CouplerData


def get_srh_coverage(
        coupler_data: CouplerData, model_name: str, coverage_type: str,
        component_type: Type[SrhCoverageComponent], unique_name: str
) -> tuple[Coverage, SrhCoverageComponent]:
    """Finds an SRH coverage linked to the coupled SRH simulation and loads its component.

    The first matching coverage found beneath the coupler's first SRH simulation tree node is used
    (pointer items are allowed).

    Args:
        coupler_data: a CouplerData class that was formed from a query.
        model_name: The coverage component model name.
        coverage_type: The coverage component type name.
        component_type: The coverage component class; it is constructed from the queried
            component's main file.
        unique_name: The coverage component unique name.

    Returns:
        tuple(Coverage, SrhCoverageComponent): The coverage, and a ``component_type`` instance whose
        point, arc, polygon, and arc group component ids have been loaded through the query.
    """
    srh_sim_uuid = coupler_data.srh_sim[0].uuid
    query = coupler_data.query

    # Walk the project tree from the SRH simulation node down to the first coverage of the wanted type.
    tree_root = query.copy_project_tree()
    srh_sim_node = tree_util.find_tree_node_by_uuid(tree_root, srh_sim_uuid)
    cov_tree_node = tree_util.descendants_of_type(
        srh_sim_node,
        model_name=model_name,
        coverage_type=coverage_type,
        allow_pointers=True,
        only_first=True,
    )
    cov_uuid = cov_tree_node.uuid

    cov = query.item_with_uuid(cov_uuid)
    queried_component = query.item_with_uuid(cov_uuid, model_name=model_name, unique_name=unique_name)

    # Re-create the component as the requested Python class so its data can be used directly.
    loaded_component = component_type(queried_component.main_file)
    query.load_component_ids(loaded_component, points=True, arcs=True, polygons=True, arc_groups=True)
    return cov, loaded_component


def get_node_coverage(coupler_data: CouplerData) -> tuple[Coverage, StormDrainNodeComponent]:
    """Gets the SWMM node coverage from the coupler sim data.

    The coverage is looked up through SwmmData for the coupler's first SWMM simulation.

    Args:
        coupler_data: a CouplerData class that was formed from a query.

    Returns:
        tuple(Coverage, StormDrainNodeComponent): The node coverage and its storm drain node
        component, exactly as provided by ``SwmmData.node_coverage``.
    """
    swmm_data = SwmmData(coupler_data.query)
    # SwmmData must know which SWMM simulation to read coverages from.
    swmm_data.sim_uuid = coupler_data.swmm_sim[0].uuid
    return swmm_data.node_coverage


def get_link_coverage(coupler_data: CouplerData):
    """Gets the SWMM link coverage from the coupler sim data.

    The coverage is looked up through SwmmData for the coupler's first SWMM simulation.

    Args:
        coupler_data: a CouplerData class that was formed from a query.

    Returns:
        The value of ``SwmmData.link_coverage`` for that simulation. NOTE(review): presumably a
        (Coverage, SWMM link component) pair mirroring the node coverage — confirm against SwmmData;
        it is not an SrhCoverageComponent.
    """
    swmm_data = SwmmData(coupler_data.query)
    # SwmmData must know which SWMM simulation to read coverages from.
    swmm_data.sim_uuid = coupler_data.swmm_sim[0].uuid
    return swmm_data.link_coverage


class Model:
    """Builds the coupler's generic model from the linked SRH-2D and SWMM simulations.

    The SWMM inlet nodes, SRH-2D internal sinks, and SRH-2D monitor lines are gathered when
    get_coupled_model() is called and are then exposed through the read-only properties.
    """

    def __init__(self, query: Query) -> None:
        """Initializes the class.

        Args:
            query (Query): Interprocess communication object.
        """
        super().__init__()
        self._query = query
        self._inlets = []
        self._internal_sinks = []
        self._monitor_lines = []
        self._default_sink = None
        self._default_monitor = None
        self._coupler_data = CouplerData(query)

    @property
    def inlets(self) -> list:
        """Returns the names of the SWMM inlet (non-outfall) nodes; empty until get_coupled_model() runs."""
        return self._inlets

    @property
    def internal_sinks(self) -> list:
        """Returns the internal sink labels from the SRH model; empty until get_coupled_model() runs."""
        return self._internal_sinks

    @property
    def monitor_lines(self) -> list:
        """Returns the monitor line labels from the SRH model; empty until get_coupled_model() runs."""
        return self._monitor_lines

    @property
    def default_sink(self):
        """Returns the default internal sink choice.

        None until get_coupled_model() runs; afterwards the first internal sink label, or '<None>'
        when the SRH model has no internal sinks.
        """
        return self._default_sink

    @property
    def default_monitor(self):
        """Returns the default monitor line choice.

        None until get_coupled_model() runs; afterwards the first monitor line label, or '<None>'
        when the SRH model has no monitor lines.
        """
        return self._default_monitor

    def get_coupled_model(self) -> GenericModel:
        """Builds the coupler's generic model.

        As a side effect, this populates inlets, internal_sinks, monitor_lines, default_sink, and
        default_monitor.

        Returns:
            GenericModel: A model whose global parameters hold the general coupling options and the
            SWMM inlet to SRH internal sink/monitor line mapping table.
        """
        model = GenericModel()
        self._add_global_parameters(model)
        return model

    def _get_swmm_inlets(self):
        """Stores the names of all SWMM nodes that are not outfalls in self._inlets."""
        node_cov, node_component = get_node_coverage(self._coupler_data)
        node_component: StormDrainNodeComponent
        node_coverage_parameters = get_swmm_model().point_parameters
        points = node_cov.get_points(FilterLocation.PT_LOC_DISJOINT)
        node_dict = get_node_groups(points, node_component, node_coverage_parameters)
        # A node is excluded when any of its groups has type 'outfall' (group[1]). Filtering instead of
        # deleting from the dict avoids a KeyError when a node has more than one outfall group, and
        # keeps the original node order.
        self._inlets = [
            node_name for node_name, groups in node_dict.items()
            if not any(group[1] == 'outfall' for group in groups)
        ]

    def _get_monitor_lines(self):
        """Stores one label per arc of the SRH monitor coverage in self._monitor_lines.

        Arcs without a component id or without a label are listed as 'Monitor'.
        """
        mon_coverage, mon_component = self._get_coverage('SRH-2D', 'Monitor', MonitorComponent, 'Monitor_Component')
        mon_component: MonitorComponent
        self._monitor_lines = []
        for arc in mon_coverage.arcs:
            comp_id = mon_component.get_comp_id(TargetType.arc, arc.id)
            label = None
            if comp_id and comp_id > 0:
                label = mon_component.data.monitor_arc_param_from_id(comp_id).label
            self._monitor_lines.append(label if label else 'Monitor')

    def _get_internal_sinks(self):
        """Stores the labels of the 'Internal sink' arcs of the SRH BC coverage in self._internal_sinks.

        Internal sink arcs without a label are listed as 'Internal sink'; other BC arcs are skipped.
        """
        bc_coverage, bc_component = self._get_coverage('SRH-2D', 'Boundary Conditions', BcComponent, 'Bc_Component')
        bc_component: BcComponent
        self._internal_sinks = []
        for arc in bc_coverage.arcs:
            comp_id = bc_component.get_comp_id(TargetType.arc, arc.id)
            if not comp_id or comp_id <= 0:
                continue
            # Component ids map to BC ids, which in turn map to the BC parameters.
            bc_id = bc_component.data.bc_id_from_comp_id(comp_id)
            bc_data_param = bc_component.data.bc_data_param_from_id(bc_id)
            if bc_data_param.bc_type == 'Internal sink':
                self._internal_sinks.append(bc_data_param.label if bc_data_param.label else 'Internal sink')

    def _get_coverage(
            self, model_name: str, coverage_type: str, component_type: Type[SrhCoverageComponent], unique_name: str
    ) -> tuple[Coverage, SrhCoverageComponent]:
        """Returns the SRH coverage and loaded component for this coupler (see get_srh_coverage)."""
        return get_srh_coverage(self._coupler_data, model_name, coverage_type, component_type, unique_name)

    def _add_global_parameters(self, model: GenericModel):
        """Adds the coupler's general options and the inlet mapping table to the model's global parameters.

        Args:
            model: The generic model to add the 'general' parameter group to.
        """
        section = model.global_parameters
        group = section.add_group('general', 'General')
        group.add_boolean('sync_times', 'Sync times for models', True,
                          'Checking this option will sync the times for the models.')
        group.add_float('damping_factor', 'Flow damping factor', 1.0, 0.0, 1.0,
                        'A multiplication factor applied when assigning SWMM inlet flows from the SRH solution and SRH '
                        'internal sink flows from the SWMM solution.')
        # Gather the table rows and choices from the linked simulations.
        self._get_swmm_inlets()
        self._get_internal_sinks()
        self._get_monitor_lines()
        self._default_sink = self._internal_sinks[0] if self._internal_sinks else '<None>'
        self._default_monitor = self._monitor_lines[0] if self._monitor_lines else '<None>'
        definition = [
            StringColumnType(header='SWMM Inlet Node', enabled=False),
            StringColumnType(header='SRH Internal Sink', choices=self._internal_sinks, default=self._default_sink),
            StringColumnType(header='SRH Monitor Line', choices=self._monitor_lines, default=self._default_monitor)
        ]
        table_def = TableDefinition(definition, fixed_row_count=len(self._inlets))
        # The inlet column is read-only and the row count is fixed to the number of inlets, so the default
        # must provide exactly one row per inlet name (previously a single hard-coded 'Inlet1' row).
        default = [[inlet, self._default_sink, self._default_monitor] for inlet in self._inlets]
        group.add_table(
            'srh_swmm_map',
            'Mapping of SWMM Inlet Nodes to SRH Internal Sinks',
            default=default,
            table_definition=table_def
        )
