"""SrhInternalSinkFlowsFromSwmmInletPipeFlows class."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os

# 2. Third party modules
import pandas as pd
from pyswmm import Nodes, Simulation

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.guipy.data.target_type import TargetType
from xms.srh.components.bc_component import BcComponent
from xms.swmm.data.model import get_swmm_model
from xms.swmm.dmi.xms_data import SwmmData

# 4. Local modules
from xms.srh_swmm_coupler.data.model import get_srh_coverage
from xms.srh_swmm_coupler.dmi.xms_data import CouplerData
from xms.srh_swmm_coupler.gui.difference_viewer import DifferenceViewer


def make_comparison(current_inflows: pd.DataFrame, new_inflows: pd.DataFrame, diff_results: dict, object_label,
                    time_label: str, time_str: str):
    """Stores a table of the differences between two inflow dataframes in diff_results.

    The dataframes are compared with ``DataFrame.compare``, so only the rows and columns that differ are reported.
    For every differing column the existing and new values are listed, followed by the change (new - existing).

    Args:
        current_inflows: The current inflows.
        new_inflows: New inflows for comparison to current inflows.
        diff_results: The dictionary where the resulting difference is stored, keyed by object_label. The stored
            value is an empty dataframe when nothing differs and None when a ValueError is raised while comparing
            (pandas raises it when the two dataframes are not identically labeled).
        object_label: Label of the object for the data being compared.
        time_label: The time label used in the output.
        time_str: The name of the column in current_inflows that holds the time values.
    """
    # pandas labels the two sides of a comparison 'self' and 'other'; use friendlier names in the output.
    side_labels = {'self': 'existing', 'other': 'new'}
    try:
        df_compare = current_inflows.compare(new_inflows)
        compare_dict = {}
        if len(df_compare) != 0:
            index_list = df_compare.index.tolist()
            compare_dict[time_label] = current_inflows.loc[index_list, time_str].tolist()
            columns = df_compare.columns
            for col in columns:
                compare_dict[f'{col[0]}, {side_labels.get(col[1], col[1])}'] = df_compare[col].tolist()
            # The compare columns come in (self, other) pairs, one pair per differing column. Give every pair its
            # own delta column so that one pair does not overwrite another. The plain 'Delta' name is kept when
            # only one column differs.
            single_pair = len(columns) == 2
            for i in range(1, len(columns), 2):
                delta_key = 'Delta' if single_pair else f'{columns[i][0]}, Delta'
                compare_dict[delta_key] = (df_compare[columns[i]] - df_compare[columns[i - 1]]).tolist()
        diff_results[object_label] = pd.DataFrame(compare_dict)
    except ValueError:
        diff_results[object_label] = None


class SrhInternalSinkFlowsFromSwmmInletPipeFlows:
    """Assigns SRH-2D internal sink BC flows from the inlet flows computed by a SWMM simulation."""

    def __init__(self, query: Query, table: list, damping_factor: float):
        """Initializes the class.

        Args:
            query (Query): Interprocess communication object.
            table (list): Table linking each inlet to a sink BC arc and a monitor line arc from SRH. Only the first
                two entries of each row are used here: the SWMM inlet node id and the label of the sink BC.
            damping_factor (float): A multiplication factor for assigning SWMM inlet flows to SRH internal sink flows.
        """
        self._coupler_data = CouplerData(query)
        self._bc_cov, self._bc_component = get_srh_coverage(self._coupler_data, 'SRH-2D', 'Boundary Conditions',
                                                            BcComponent, 'Bc_Component')
        self._table = table
        swmm_data = SwmmData(query)
        sim_data = swmm_data.sim_data
        self._global_parameters = get_swmm_model().global_parameters
        self._global_parameters.restore_values(sim_data.global_values)
        self._damping_factor = damping_factor

    def run(self):
        """Runs the tool.

        Runs the SWMM simulation to get the ground inflow at every inlet in the table, writes those flows to the
        matching SRH internal sink BC arcs as time series, shows the changes in a dialog if any sink was updated, and
        commits the BC component data.
        """
        swmm_tree_item_name = self._coupler_data.swmm_sim[0].name
        srh_tree_item_name = self._coupler_data.srh_sim[0].name
        swmm_inp_file = self._swmm_input_file(swmm_tree_item_name)
        group = self._global_parameters.group('time_steps')
        time_step = group.parameter('routing_time_step').value

        sim_times, ground_inflows = self._simulate_ground_inflows(swmm_inp_file, time_step)
        sink_bc_name_to_inlet_name = {row[1]: row[0] for row in self._table}
        diff_results = self._update_internal_sinks(sim_times, ground_inflows, sink_bc_name_to_inlet_name)

        if diff_results:
            dlg = DifferenceViewer(diff_results, f'{srh_tree_item_name} Internal Sink BC Change Info',
                                   'Internal Sink')
            dlg.exec()
        self._bc_component.data.commit()

    @staticmethod
    def _swmm_input_file(swmm_tree_item_name):
        """Returns the path to the 'swmm.inp' file of the SWMM simulation with the given tree item name.

        Args:
            swmm_tree_item_name: Name of the SWMM simulation tree item, which is also the name of its folder.
        """
        # The project file path is read from the XMS_PYTHON_APP_PROJECT_PATH environment variable.
        proj_file_name = os.environ.get('XMS_PYTHON_APP_PROJECT_PATH')
        proj_name = os.path.splitext(os.path.basename(proj_file_name))[0]
        proj_path = os.path.dirname(proj_file_name)
        sim_dir = os.path.join(proj_path, f'{proj_name}_models', 'SWMM', swmm_tree_item_name)
        if not os.path.isdir(sim_dir):
            # Fall back to a folder named after the project without the '_models' suffix.
            sim_dir = os.path.join(proj_path, f'{proj_name}', 'SWMM', swmm_tree_item_name)
        return os.path.join(sim_dir, 'swmm.inp')

    def _simulate_ground_inflows(self, swmm_inp_file, time_step):
        """Runs the SWMM simulation and records the ground inflow at each inlet in the table.

        Args:
            swmm_inp_file: Path to the SWMM input file.
            time_step: The number of seconds the simulation is advanced between recorded values.

        Returns:
            A tuple of (simulation times in hours, dict of inlet node id -> list of ground inflows). Each list of
            ground inflows has one value per simulation time.
        """
        sim_times = []
        ground_inflows = {}
        # The simulation is only kept open while stepping. Everything that follows works on the collected lists.
        with Simulation(swmm_inp_file) as sim:
            sim_nodes = []
            for row in self._table:
                sim_nodes.append(Nodes(sim)[row[0]])
                ground_inflows[row[0]] = []
            # Must start True: it marks the first step, where a value at time 0.0 may have to be added.
            first = True
            for step in sim:
                sim_time = (step.getCurrentSimulationTime() - sim.start_time).total_seconds() / 3600.0
                add_initial = first and sim_time != 0.0
                first = False
                if add_initial:
                    # Add an initial time of 0.0 since the simulation time for the first time step is not 0.0
                    sim_times.append(0.0)
                sim_times.append(sim_time)
                for sim_node in sim_nodes:
                    pipe_inflow = sim_node.total_inflow - sim_node.lateral_inflow
                    ground_inflow = sim_node.total_outflow - pipe_inflow
                    if add_initial:
                        # Use the first computed ground inflow as the value at time 0.0 as well
                        ground_inflows[sim_node.nodeid].append(ground_inflow)
                    ground_inflows[sim_node.nodeid].append(ground_inflow)
                sim.step_advance(int(time_step))
        return sim_times, ground_inflows

    def _update_internal_sinks(self, sim_times, ground_inflows, sink_bc_name_to_inlet_name):
        """Assigns the simulated ground inflows to the internal sink BC arcs linked to an inlet.

        Args:
            sim_times: The simulation times in hours.
            ground_inflows: Dict of inlet node id -> list of ground inflows, one per simulation time.
            sink_bc_name_to_inlet_name: Dict of sink BC label -> inlet node id.

        Returns:
            Dict of sink BC label -> result stored by make_comparison for every sink that was updated.
        """
        diff_results = {}
        flow_key = 'vol_per_sec'
        bc_data = self._bc_component.data
        for arc in self._bc_cov.arcs:
            comp_id = self._bc_component.get_comp_id(TargetType.arc, arc.id)
            if not (comp_id and comp_id > 0):
                continue
            bc_id = bc_data.bc_id_from_comp_id(comp_id)
            bc_data_param = bc_data.bc_data_param_from_id(bc_id)
            if bc_data_param.bc_type != 'Internal sink':
                continue
            # Sinks without a label are looked up under the default name.
            sink_label = bc_data_param.label if bc_data_param.label else 'Internal sink'
            if sink_label not in sink_bc_name_to_inlet_name:
                continue
            inlet_name = sink_bc_name_to_inlet_name[sink_label]
            if inlet_name not in ground_inflows:
                continue

            internal_sink = bc_data_param.internal_sink
            current_inflows = internal_sink.time_series_q.copy()
            current_inflows.reset_index(drop=True, inplace=True)
            swmm_inflows = pd.DataFrame({'hrs': sim_times, flow_key: ground_inflows[inlet_name]})
            new_inflows = swmm_inflows.copy()
            if internal_sink.sink_flow_type != 'Time series':
                internal_sink.sink_flow_type = 'Time series'
            srh_flows = current_inflows[flow_key]
            swmm_flows = swmm_inflows[flow_key]
            if len(srh_flows) == len(swmm_flows):
                # Only move part of the way from the existing flows to the SWMM flows. When the two series have
                # different lengths they cannot be blended, so the SWMM flows are used unchanged.
                new_inflows[flow_key] = srh_flows + (swmm_flows - srh_flows) * self._damping_factor
            internal_sink.time_series_q = new_inflows.copy()
            bc_data.set_bc_data_with_id(bc_data_param, bc_id)
            # Make comparison between existing and simulated ground inflows
            make_comparison(current_inflows, new_inflows, diff_results, sink_label, 'Time (hours)', 'hrs')
        return diff_results
