"""BcDataDialogRunner class. A helper class for using the BcDataDialog."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
from pathlib import Path
import shutil

# 2. Third party modules
import pandas as pd
from PySide2.QtGui import QIcon
from PySide2.QtWidgets import QDialog

# 3. Aquaveo modules
from xms.components.display.xms_display_message import XmsDisplayMessage
from xms.core.filesystem.filesystem import temp_filename
from xms.guipy.data.target_type import TargetType
from xms.guipy.dialogs import message_box
from xms.guipy.dialogs.xms_parent_dlg import get_xms_icon

# 4. Local modules
from xms.srh.data.bc_data import BcData
from xms.srh.data.par.bc_data_param import BcDataParam
from xms.srh.gui.bc_data_dialog import BcDialog


class BcDataDialogRunner:
    """Helper class to run the BcDataDialog."""
    def __init__(self, bc_component=None, sel_arc_ids=None, win_cont=None, query=None, cov_uuid=None):
        """Initializes the state used to set up, show, and apply the BC dialog.

        Args:
            bc_component (:obj:`BcComponent`): component for bc coverage
            sel_arc_ids (:obj:`list`): ids of selected arcs
            win_cont (:obj:`QWidget`): Parent window
            query (:obj:`xms.api.dmi.Query`): Object for communicating with XMS
            cov_uuid (:obj:`str`): id of the coverage
        """
        self.dlg = None
        self.hy8_exe = ''  # Filled in by run_dlg() from the 'Hy8Exec' named executable
        self.bc_component = bc_component
        self.sel_arc_ids = sel_arc_ids
        self.win_cont = win_cont
        self.query = query
        if self.query is not None:
            self.query.xms_agent.set_timeout(90000)  # Large scatter can take awhile. Set to 90 seconds.
        self.sel_arc_comp_ids = set()
        self.sel_arc_bc_ids = set()
        self.bc_data_param = BcDataParam()
        self.comp_id = -1
        self.bc_id = -1
        self.new_comp_id = -1
        # Maps xms arc id -> new component id for the two arcs of a structure
        self.structure_arc_id_comp_id = dict()
        self.allow_structures = True
        self.multi_select_warning = False
        self.cov_uuid = cov_uuid
        self.bc_data_line_labels = []
        icon_str = get_xms_icon()
        self.icon = QIcon(icon_str) if icon_str else None

    def _get_selected_arc_comp_ids(self):
        """Collects the component ids and bc ids of the selected arcs.

        Arcs without a component id (None or 0) are recorded as id 0 in both sets.
        """
        for arc_id in self.sel_arc_ids:
            comp_id = self.bc_component.get_comp_id(TargetType.arc, arc_id)
            if not comp_id:
                self.sel_arc_comp_ids.add(0)
                self.sel_arc_bc_ids.add(0)
            elif comp_id > 0:
                self.sel_arc_comp_ids.add(comp_id)
                bc_id = self.bc_component.data.bc_id_from_comp_id(comp_id)
                if bc_id > 0:
                    self.sel_arc_bc_ids.add(bc_id)

    def _check_multi_select(self):
        """Sets the multi-select warning flag when the selected arcs use more than one bc."""
        if len(self.sel_arc_bc_ids) > 1:
            self.multi_select_warning = True

    def _fill_param_class(self):
        """Fills the param class with the data of one of the selected component ids.

        The component ids are held in a set, so which one is used is not tied to selection order.
        """
        self.comp_id = next(iter(self.sel_arc_comp_ids), self.comp_id)
        if self.comp_id != -1:
            self.bc_id = self.bc_component.data.bc_id_from_comp_id(self.comp_id)
            self.bc_data_param = self.bc_component.data.bc_data_param_from_id(self.bc_id)
        if not self.bc_data_param.bc_data.label and self.sel_arc_ids:
            self.bc_data_param.bc_data.label = f'bc_data_line_{self.sel_arc_ids[0]}'
        self.bc_data_param.hy8_culvert.hy8_exe = self.hy8_exe
        self.bc_data_param.hy8_culvert.hy8_input_file = self.bc_component.hy8_file
        self._fill_bc_data_lines()
        self._modify_bc_types_based_on_selection()

    def _fill_bc_data_lines(self):
        """Sets up the bc_data_lines class with appropriate labels from the arcs in the coverage."""
        line_selector_list = ['NONE'] + self.bc_data_line_labels
        bc_data_lines = self.bc_data_param.bc_data_lines
        bc_data_lines.param.upstream_line.objects = line_selector_list
        bc_data_lines.param.downstream_line.objects = line_selector_list
        # A stored label that no longer matches any 'bc_data' arc is reset to 'NONE'.
        # NOTE(review): in that case only the *_label is reset, the selector value is left as-is.
        if bc_data_lines.upstream_line_label not in self.bc_data_line_labels:
            bc_data_lines.upstream_line_label = 'NONE'
        else:
            bc_data_lines.upstream_line = bc_data_lines.upstream_line_label
        if bc_data_lines.downstream_line_label not in self.bc_data_line_labels:
            bc_data_lines.downstream_line_label = 'NONE'
        else:
            bc_data_lines.downstream_line = bc_data_lines.downstream_line_label

    def _get_bc_data_line_labels(self):
        """Gets the unique labels of all arcs displayed as 'bc_data', sorted for a stable UI order."""
        data = self.bc_component.data
        df = data.comp_ids.to_dataframe()
        bc_ids = df.loc[df['display'] == 'bc_data', 'bc_id'].tolist()
        labels = {data.bc_data_param_from_id(bc_id).bc_data.label for bc_id in bc_ids}
        # key=str keeps the sort from failing if a label is ever None
        self.bc_data_line_labels = sorted(labels, key=str)

    def _xms_id_from_comp_id(self, comp_id):
        """Helper function to get the xms feature id from a component id.

        Args:
           comp_id (:obj:`int`): component id

        Returns:
            (:obj:`int`): the xms arc id when exactly one arc uses the component id, otherwise -1
        """
        ids = self.bc_component.get_xms_ids(TargetType.arc, comp_id)
        if isinstance(ids, list) and len(ids) == 1:
            return ids[0]
        return -1

    def _modify_bc_types_based_on_selection(self):
        """Adjusts the bc type, structure arcs and allow_structures flag for the current selection.

        Structures need exactly two arcs. An existing structure whose arcs cannot be found, or a structure
        shown for a mixed multi-selection, is downgraded to a wall.
        """
        sorted_ids = sorted(self.sel_arc_ids)
        two_arcs_selected = len(sorted_ids) == 2
        is_structure = self.bc_data_param.bc_type in BcData.structures_list
        if is_structure and not self.multi_select_warning:
            arcs = self.bc_component.data.structure_param_from_id(self.bc_id)
            # if this is a structure then get the xms arc ids from the comp ids
            arcs.arc_id_0 = self._xms_id_from_comp_id(arcs.arc_id_0)
            arcs.arc_id_1 = self._xms_id_from_comp_id(arcs.arc_id_1)
            self.bc_data_param.arcs = arcs
            if arcs.arc_id_0 < 0 or arcs.arc_id_1 < 0:
                msg = (
                    f'Unable to find all arcs that define the structure of type: '
                    f'{self.bc_data_param.bc_type}.\nThis structure will be removed.'
                )
                app_name = os.environ.get('XMS_PYTHON_APP_NAME')
                message_box.message_with_ok(
                    parent=self.win_cont, message=msg, app_name=app_name, icon='Warning', win_icon=self.icon
                )
                self.bc_data_param.bc_type = 'Wall (no-slip boundary)'
                if two_arcs_selected:
                    self.bc_data_param.arcs.arc_id_0 = sorted_ids[0]
                    self.bc_data_param.arcs.arc_id_1 = sorted_ids[1]
                else:
                    self.allow_structures = False
        elif two_arcs_selected:
            self.bc_data_param.arcs.arc_id_0 = sorted_ids[0]
            self.bc_data_param.arcs.arc_id_1 = sorted_ids[1]
        elif not is_structure:
            self.allow_structures = False
        else:
            # Only reachable for a structure with the multi-select warning set
            self.bc_data_param.bc_type = 'Wall (no-slip boundary)'

    @staticmethod
    def _next_id(ids):
        """Returns one more than the largest id in a pandas Series.

        Args:
            ids (:obj:`pandas.Series`): existing ids

        Returns:
            (:obj:`int`): the next free id; 1 when there are no ids (0 is used for the default wall)
        """
        ids = ids.dropna()
        if ids.empty:
            return 1
        return int(ids.max()) + 1

    def _update_component_data(self):
        """Creates a new component id if the user clicked OK on the dialog and saves the data.

        A plain wall (no extra roughness) is stored as component id 0 and adds no new data.
        """
        data = self.bc_data_param
        data.label = data.label.strip()

        if data.bc_type == 'Wall (no-slip boundary)' and not data.wall.extra_wall_roughness:
            comp_id = 0
        else:
            comp_data = self.bc_component.data
            comp_id = self._next_id(comp_data.comp_ids.to_dataframe()['id'])
            bc_id = self._next_id(comp_data.bc_data.to_dataframe()['id'])
            disp_category = BcData.display_list[data.bc_type_index()]

            is_structure = data.bc_type in BcData.structures_list
            if not is_structure:
                comp_data.set_bc_id_display_with_comp_id(bc_id, disp_category, comp_id)
            else:
                # A structure uses two consecutive comp ids (one per arc) that share one bc id
                disp_dict = {'Upstream': 0, 'Downstream': 1, 'Inflow monitor': 2, 'Outflow monitor': 3}
                arcs = self.bc_data_param.arcs
                self.structure_arc_id_comp_id[arcs.arc_id_0] = comp_id
                self.structure_arc_id_comp_id[arcs.arc_id_1] = comp_id + 1
                arc_list = [arcs.arc_id_0, arcs.arc_id_1]
                self.check_arcs_for_structure_and_remove(arc_list)

                disp_0 = disp_category[disp_dict[arcs.arc_option_0]]
                comp_data.set_bc_id_display_with_comp_id(bc_id, disp_0, comp_id)
                disp_1 = disp_category[disp_dict[arcs.arc_option_1]]
                comp_data.set_bc_id_display_with_comp_id(bc_id, disp_1, comp_id + 1)
                # make a structure (the stored structure references comp ids, not xms arc ids)
                struct = self.bc_data_param.arcs
                struct.arc_id_0 = comp_id
                struct.arc_id_1 = comp_id + 1
                comp_data.set_structure_with_bc_id(struct, bc_id)
            comp_data.set_bc_data_with_id(data, bc_id)
            comp_data.commit()
        self.new_comp_id = comp_id

    def check_arcs_for_structure_and_remove(self, arc_ids):
        """Checks if the arcs were associated with a structure.

        If so, then we remove that bc and make sure to assign to the other
        arc in the structure to be the default bc (wall).

        Args:
            arc_ids (:obj:`list[int]`): arc ids
        """
        bc_data = self.bc_component.data
        for arc_id in arc_ids:
            comp_id = self.bc_component.get_comp_id(TargetType.arc, arc_id)
            bc_id = bc_data.bc_id_from_comp_id(comp_id)
            bc_par = bc_data.bc_data_param_from_id(bc_id)
            if bc_par.bc_type not in BcData.structures_list:
                continue
            struct = bc_data.structure_param_from_id(bc_id)
            for struct_comp_id in (struct.arc_id_0, struct.arc_id_1):
                if struct_comp_id < 0:
                    continue
                bc_data.set_bc_id_display_with_comp_id(0, 'wall', struct_comp_id)
                xms_arc_ids = self.bc_component.get_xms_ids(TargetType.arc, struct_comp_id)
                if isinstance(xms_arc_ids, int):
                    # An int (not a list) means no arcs use this comp id (see _clean_up_unused_comp_ids)
                    continue
                for xms_arc_id in xms_arc_ids:
                    self.bc_component.update_component_id(TargetType.arc, xms_arc_id, 0)

    def _update_xms_comp_ids(self):
        """Update the component id on the selected arcs if the user clicked OK in the dialog."""
        if self.structure_arc_id_comp_id:
            for arc_id, comp_id in self.structure_arc_id_comp_id.items():
                self.bc_component.update_component_id(TargetType.arc, arc_id, int(comp_id))
        else:
            for arc_id in self.sel_arc_ids:
                self.check_arcs_for_structure_and_remove([arc_id])
                self.bc_component.update_component_id(TargetType.arc, arc_id, int(self.new_comp_id))

    def _update_component_display(self):
        """Update display stuff on the component if the user clicked OK in the dialog."""
        self.bc_component.update_id_files()
        self.bc_component.display_option_list.append(
            XmsDisplayMessage(file=self.bc_component.disp_opts_file, edit_uuid=self.bc_component.cov_uuid)
        )

    def _clean_up_unused_comp_ids(self):
        """Removes comp ids (>= 1) that no arc in XMS references anymore from the bc data."""
        comp_ids = self.bc_component.data.comp_ids.to_dataframe()['id'].to_list()
        # get_xms_ids returns an int (rather than a list of arc ids) when nothing uses the comp id
        delete_comp_ids = {
            cid for cid in comp_ids
            if cid >= 1 and isinstance(self.bc_component.get_xms_ids(TargetType.arc, cid), int)
        }
        self.bc_component.data.remove_comp_ids(delete_comp_ids)

    def run_dlg(self):
        """Runs the assign BC dialog and applies the changes if the user clicks OK.

        The 'temp' folder next to the component main file is always removed afterwards, even on error.
        """
        self.hy8_exe = self.query.named_executable_path('Hy8Exec')
        temp_dir = os.path.join(os.path.dirname(self.bc_component.main_file), 'temp')
        try:
            self.query.load_component_ids(self.bc_component, arcs=True)
            self._get_selected_arc_comp_ids()
            self._check_multi_select()
            self._get_bc_data_line_labels()
            self._fill_param_class()
            self._clean_up_unused_comp_ids()
            self._read_sediment_table()

            self.dlg = BcDialog(
                self.win_cont, self.bc_data_param, self.allow_structures, self.multi_select_warning, self.query,
                self.cov_uuid, self.sel_arc_ids
            )
            if self.dlg.exec() == QDialog.Accepted:
                self._write_sediment_table()
                self._update_component_data()
                self._update_xms_comp_ids()
                self._update_component_display()
        finally:
            # Delete the id files dumped by XMS.
            shutil.rmtree(temp_dir, ignore_errors=True)

    def _write_sediment_table(self) -> None:
        """Write the sediment table to a file.

        We do this instead of saving the table in the param class because the number of columns can
        change and that is not currently supported (we would have to define a default dataframe with
        some high number of columns that would be the max). Also, the data would get duplicated each
        time the dialog is opened and closed.
        """
        sediment = self.bc_data_param.sediment_inflow
        if sediment.sediment_table is None:
            return
        # Get the full file path using the component main_file dir but only store the file name
        file_name = sediment.sediment_table_file_name
        comp_dir = Path(self.bc_component.main_file).parent
        if not file_name:
            file_path = temp_filename(comp_dir, suffix='.csv')
            sediment.sediment_table_file_name = Path(file_path).name
        else:
            file_path = comp_dir / file_name
        sediment.sediment_table.to_csv(file_path, sep=' ', index=False)
        # Keep only an empty placeholder on the param so the table is not persisted twice
        sediment.sediment_table = pd.DataFrame({'Time(hr)': [], 'Qs1(cfs)': []})

    def _read_sediment_table(self) -> None:
        """Read the sediment table file into a dataframe and store it in the param.

        See the comments in _write_sediment_table for why we do this. If the referenced file no longer
        exists, the table already on the param is left unchanged.
        """
        sediment = self.bc_data_param.sediment_inflow
        file_name = sediment.sediment_table_file_name
        if not file_name:
            return
        # Build the full file path using the component main_file dir
        file_path = Path(self.bc_component.main_file).parent / file_name
        if not file_path.is_file():
            return
        sediment.sediment_table = pd.read_csv(file_path, sep=' ')