"""NewSimDialog class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
from PySide2.QtCore import QEvent, QPoint, Qt
from PySide2.QtWidgets import QListWidgetItem, QSizePolicy, QToolTip

# 3. Aquaveo modules
from xms.api._xmsapi.dmi import UGridItem
from xms.api.tree import tree_util, TreeNode
from xms.guipy.dialogs import message_box
from xms.guipy.dialogs.xms_parent_dlg import XmsDlg
from xms.guipy.widgets import widget_builder
from xms.guipy.widgets.qx_project_explorer_selector_widget import QxProjectExplorerSelectorWidget

# 4. Local modules
from xms.mf6.gui import gui_util
from xms.mf6.gui.new_sim_dialog_ui import Ui_NewSimDialog
from xms.mf6.misc import util


class NewSimDialog(XmsDlg):
    """A dialog that appears when creating a new simulation, or a new model within an existing simulation."""
    # Row of each model type in the models list (and page in the stacked widget before any pages are removed).
    # '' means a new simulation and maps to the trailing entry of _model_pnames and _window_titles.
    _model_idxs = {'GWF': 0, 'GWT': 1, 'GWE': 2, 'PRT': 3, '': 4}
    # Base names used to build a unique default name, indexed like _model_idxs.
    _model_pnames = ['flow', 'trans', 'energy', 'track', 'sim']
    # Window titles, indexed like _model_idxs.
    _window_titles = [
        'New Groundwater Flow (GWF) Model', 'New Groundwater Transport (GWT) Model',
        'New Groundwater Energy Transport (GWE) Model', 'New Particle Tracking (PRT) Model', 'New MODFLOW 6 Simulation'
    ]
    _bad_chars = '/\\!*?"<>| '  # Project Explorer names can't have these
    _bad_chars_msg = '/ \\ ! * ? " < > |'  # Displayable form of _bad_chars (the space is described in words)

    def __init__(self, project_tree, new_uuid: str, parent=None, model_str='', help_id: str = ''):
        """Initializes the class, sets up the ui.

        Args:
            project_tree: Project Explorer tree.
            new_uuid: Uuid of the new item being created.
            parent (Something derived from QWidget): The parent window.
            model_str (str): 'GWF', 'GWT', 'GWE' or 'PRT' when adding a single model (e.g. from Add Package > GWF),
                or '' when creating a new simulation.
            help_id: Help identifier passed to gui_util.help_getter.
        """
        super().__init__(parent, f'xms.mf6.gui.new_sim_dialog-{model_str}')

        self._project_tree = project_tree
        self._new_uuid = new_uuid
        self._model_str = model_str
        self.data = {}  # Stores the dialog state for use after dialog is closed on OK.

        self.help_getter = gui_util.help_getter(help_id)

        self.ui = Ui_NewSimDialog()
        self.ui.setupUi(self)

        self._setup_ugrid_tree_widget(project_tree)
        self._setup_models_list_widget()
        self._setup_stacked_models_widget()
        self.ui.edt_simulation_name.installEventFilter(self)

        # Default to a name that doesn't clash with the new item's siblings
        model_idx = NewSimDialog._model_idxs[self._model_str]
        new_node = tree_util.find_tree_node_by_uuid(self._project_tree, self._new_uuid)
        name = _find_good_tree_name(new_node, NewSimDialog._model_pnames[model_idx])
        self.ui.edt_simulation_name.setText(name)

        if self._model_str:  # Doing new model, not new sim
            self.ui.txt_simulation_name.setText('Model name:')
            self.ui.txt_auto_sim_packages.setVisible(False)
        self.setWindowTitle(NewSimDialog._window_titles[model_idx])

        if not util.PRT_INTERFACE:
            self._remove_prt_model()

        # Signals
        self.ui.lst_models.currentItemChanged.connect(self._on_lst_models_changed)
        self.ui.lst_models.itemChanged.connect(self._on_lst_models_item_change)

        self._set_current_model()
        self.adjustSize()

        self.ui.buttonBox.helpRequested.connect(self.help_requested)

    def _setup_ugrid_tree_widget(self, project_tree: TreeNode) -> None:
        """Set up the UGrid tree widget.

        Args:
            project_tree: Project Explorer tree.
        """
        # If there's only one ugrid, select it by default
        items = tree_util.descendants_of_type(project_tree, UGridItem)
        previous_selection = ''
        if items and len(items) == 1:
            previous_selection = items[0].uuid

        # Swap out the generic QTreeWidget for the QxProjectExplorerSelectorWidget
        self.ui_tree_ugrids = QxProjectExplorerSelectorWidget(
            root_node=project_tree, selectable_item_type=UGridItem, parent=self, previous_selection=previous_selection
        )
        widget_builder.replace_widgets(self.ui.vlay_left.layout(), self.ui.tree_ugrids, self.ui_tree_ugrids)
        self.ui_tree_ugrids.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Minimum)

    def _setup_models_list_widget(self) -> None:
        """Set up the models list widget by making the items checkable, or hide it when adding a single model."""
        # Hide the models list if the model is being provided
        if self._model_str:  # Doing new model, not new sim
            self.ui.txt_models.setVisible(False)
            self.ui.lst_models.setVisible(False)
            return

        # Make the list items checkable. The items are defined in the .ui file
        for i in range(self.ui.lst_models.count()):
            list_item = self.ui.lst_models.item(i)
            list_item.setFlags(list_item.flags() | Qt.ItemIsUserCheckable)
            list_item.setCheckState(Qt.Unchecked)

    def _setup_stacked_models_widget(self) -> None:
        """Set up the stacked models widget."""
        # If doing new model, remove all the other pages so dialog can be shrunk (GWF is long). The list of models does
        # not change.
        if self._model_str:  # Doing new model, not new sim
            model_idx = NewSimDialog._model_idxs[self._model_str]
            # Iterate backwards since we're removing things
            for i in range(self.ui.stacked_models.count() - 1, -1, -1):
                if i != model_idx:
                    model_page = self.ui.stacked_models.widget(i)
                    self.ui.stacked_models.removeWidget(model_page)

        # If doing new sim, disable all the model pages to start out
        else:
            for i in range(self.ui.stacked_models.count()):
                self.ui.stacked_models.widget(i).setEnabled(False)

    def _remove_prt_model(self) -> None:
        """Remove the PRT page and hide the PRT list item (used when the PRT interface is turned off)."""
        if self._model_str:
            # Only the requested model's page remains in the stacked widget, so there is no PRT page to remove.
            return
        prt_idx = NewSimDialog._model_idxs['PRT']
        prt_page = self.ui.stacked_models.widget(prt_idx)
        if prt_page is not None:
            self.ui.stacked_models.removeWidget(prt_page)
        # Hide the list item too, otherwise it would refer to a page that no longer exists
        prt_item = self.ui.lst_models.item(prt_idx)
        if prt_item is not None:
            prt_item.setHidden(True)

    def _set_current_model(self) -> None:
        """Set the current model list item.

        Should be called after setting up the signals so the stacked widget gets updated.
        """
        if self._model_str:  # Doing new model, not new sim
            model_idx = NewSimDialog._model_idxs[self._model_str]
        else:
            model_idx = NewSimDialog._model_idxs['GWF']  # Turn on GWF by default for new simulations
        self.ui.lst_models.setCurrentRow(model_idx)
        self.ui.lst_models.item(model_idx).setCheckState(Qt.Checked)

    def _on_lst_models_changed(self, current_item, previous_item) -> None:
        """Called when the current item in the list is changed; shows the matching stacked widget page.

        Args:
            current_item: The current item.
            previous_item: The previous item.
        """
        current_index = self.ui.lst_models.currentRow()
        self.ui.stacked_models.setCurrentIndex(current_index)

    def _on_lst_models_item_change(self, item: QListWidgetItem) -> None:
        """Called when the current item has been changed, which, in our case, means checked or unchecked.

        Args:
            item: The item that was changed.
        """
        # Get the model page
        if self._model_str:  # Doing new model, not new sim
            model_page = self.ui.stacked_models.widget(0)  # Only page left (see _setup_stacked_models_widget)
        else:
            model_index = self.ui.lst_models.indexFromItem(item)
            current_index = model_index.row()
            model_page = self.ui.stacked_models.widget(current_index)

            # Set the current item to the item that was checked
            self.ui.lst_models.setCurrentRow(current_index)

        # Enable/disable the stacked widget model page for the associated model list item
        if model_page:
            model_page.setEnabled(item.checkState() == Qt.Checked)

    def _model_checked(self, model: str) -> bool:
        """Return True if the model is checked in the models list."""
        model_idx = NewSimDialog._model_idxs[model]
        return self.ui.lst_models.item(model_idx).checkState() == Qt.Checked

    def _is_model_requested(self, model: str) -> bool:
        """Return True if the dialog should create a model of the given type.

        When adding a single model, that model is the only one; for a new simulation it's the checked models.
        """
        if self._model_str:
            return self._model_str == model
        return self._model_checked(model)

    @staticmethod
    def _bad_chars_message() -> str:
        """Return the message explaining which characters are not allowed in the name."""
        msg = 'The name cannot contain spaces or any of the following characters:'
        return f'{msg}\n\t{NewSimDialog._bad_chars_msg}'

    def eventFilter(self, obj, event):  # noqa N802 (function name 'eventFilter' should be lowercase)
        """Prevent user from entering forbidden characters into the sim name.

        Args:
            obj: Qt object associated with the event.
            event: The QEvent object.

        Returns:
            (bool): True if the event was handled.
        """
        if obj == self.ui.edt_simulation_name and event.type() == QEvent.KeyPress:
            letter = event.text()
            if letter and letter in NewSimDialog._bad_chars:
                QToolTip.showText(self.ui.edt_simulation_name.mapToGlobal(QPoint()), self._bad_chars_message())
                return True
        return super().eventFilter(obj, event)

    def add_package(self, packages, checkbox, ftype):
        """Helper function to add a package to the set if the checkbox is checked.

        Args:
            packages (set(str)): Set of package ftypes, modified in place.
            checkbox (QCheckBox): The QCheckbox.
            ftype (`str`): The MODFLOW ftype.
        """
        if checkbox.isChecked():
            packages.add(ftype)

    def _name(self) -> str:
        """Return the text of the name field with all whitespace removed."""
        return ''.join(self.ui.edt_simulation_name.text().split())

    def update_data(self) -> dict:
        """Store the dialog state in self.data: the name, the selected UGrid uuid, and the package ftypes per model.

        Returns:
            (dict): The dialog data (the same object as self.data).
        """
        self.data = {
            'name': self._name(),
            'ugrid_uuid': self.ui_tree_ugrids.get_selected_item_uuid(),
        }
        if self._is_model_requested('GWF'):
            packages = {'GWF6', 'DIS6', 'IC6', 'NPF6', 'OC6'}  # Required
            self.add_package(packages, self.ui.chk_buy, 'BUY6')
            self.add_package(packages, self.ui.chk_chd, 'CHD6')
            self.add_package(packages, self.ui.chk_csub, 'CSUB6')
            self.add_package(packages, self.ui.chk_drn, 'DRN6')
            self.add_package(packages, self.ui.chk_evt, 'EVT6')
            self.add_package(packages, self.ui.chk_evta, 'EVTA6')
            self.add_package(packages, self.ui.chk_ghb, 'GHB6')
            self.add_package(packages, self.ui.chk_gnc, 'GNC6')
            self.add_package(packages, self.ui.chk_hfb, 'HFB6')
            self.add_package(packages, self.ui.chk_lak, 'LAK6')
            self.add_package(packages, self.ui.chk_maw, 'MAW6')
            self.add_package(packages, self.ui.chk_mvr, 'MVR6')
            self.add_package(packages, self.ui.chk_obs, 'OBS6')
            self.add_package(packages, self.ui.chk_rch, 'RCH6')
            self.add_package(packages, self.ui.chk_rcha, 'RCHA6')
            self.add_package(packages, self.ui.chk_riv, 'RIV6')
            self.add_package(packages, self.ui.chk_sfr, 'SFR6')
            self.add_package(packages, self.ui.chk_sto, 'STO6')
            self.add_package(packages, self.ui.chk_swi, 'SWI6')
            self.add_package(packages, self.ui.chk_uzf, 'UZF6')
            self.add_package(packages, self.ui.chk_vsc, 'VSC6')
            self.add_package(packages, self.ui.chk_wel, 'WEL6')
            self.add_package(packages, self.ui.chk_pobs, 'POBS6')
            self.add_package(packages, self.ui.chk_zones, 'ZONE6')
            self.data['gwf_ftypes'] = packages

        if self._is_model_requested('GWT'):
            packages = {'GWT6', 'DIS6', 'IC6', 'MST6', 'OC6'}  # Required
            self.add_package(packages, self.ui.chk_adv, 'ADV6')
            self.add_package(packages, self.ui.chk_cnc, 'CNC6')
            self.add_package(packages, self.ui.chk_dsp, 'DSP6')
            self.add_package(packages, self.ui.chk_fmi, 'FMI6')
            self.add_package(packages, self.ui.chk_ist, 'IST6')
            self.add_package(packages, self.ui.chk_lkt, 'LKT6')
            self.add_package(packages, self.ui.chk_mdt, 'MDT6')
            self.add_package(packages, self.ui.chk_mvt, 'MVT6')
            self.add_package(packages, self.ui.chk_mwt, 'MWT6')
            self.add_package(packages, self.ui.chk_obs_gwt, 'OBS6')
            self.add_package(packages, self.ui.chk_sft, 'SFT6')
            self.add_package(packages, self.ui.chk_src, 'SRC6')
            self.add_package(packages, self.ui.chk_ssm, 'SSM6')
            self.add_package(packages, self.ui.chk_uzt, 'UZT6')
            self.add_package(packages, self.ui.chk_pobs_gwt, 'POBS6')
            self.data['gwt_ftypes'] = packages

        if self._is_model_requested('GWE'):
            packages = {'GWE6', 'DIS6', 'IC6', 'OC6'}  # Required
            self.add_package(packages, self.ui.chk_adv_gwe, 'ADV6')
            self.add_package(packages, self.ui.chk_cnd, 'CND6')
            self.add_package(packages, self.ui.chk_ctp, 'CTP6')
            self.add_package(packages, self.ui.chk_esl, 'ESL6')
            self.add_package(packages, self.ui.chk_est, 'EST6')
            self.add_package(packages, self.ui.chk_fmi_gwe, 'FMI6')
            self.add_package(packages, self.ui.chk_lke, 'LKE6')
            self.add_package(packages, self.ui.chk_mve, 'MVE6')
            self.add_package(packages, self.ui.chk_mwe, 'MWE6')
            self.add_package(packages, self.ui.chk_obs_gwe, 'OBS6')
            self.add_package(packages, self.ui.chk_sfe, 'SFE6')
            self.add_package(packages, self.ui.chk_ssm_gwe, 'SSM6')
            self.add_package(packages, self.ui.chk_uze, 'UZE6')
            self.add_package(packages, self.ui.chk_pobs_gwe, 'POBS6')
            self.data['gwe_ftypes'] = packages

        if util.PRT_INTERFACE and self._is_model_requested('PRT'):
            packages = {'PRT6', 'DIS6', 'MIP6'}  # Required
            self.add_package(packages, self.ui.chk_prp, 'PRP6')
            self.add_package(packages, self.ui.chk_oc_prt, 'OC6')
            # NOTE(review): this reads the GWE page's FMI checkbox (chk_fmi_gwe). It was probably meant to be a
            # PRT-specific checkbox; confirm against the .ui file.
            self.add_package(packages, self.ui.chk_fmi_gwe, 'FMI6')
            self.data['prt_ftypes'] = packages

        return self.data

    def _name_warning(self) -> str:
        """Return an error message if the name is empty or contains forbidden characters, else ''.

        The key-press event filter does not catch text that is pasted in, so the name is checked again here.

        Returns:
            See description.
        """
        if not self._name():
            return 'Please provide a name.'
        text = self.ui.edt_simulation_name.text()
        if any(char in NewSimDialog._bad_chars for char in text):
            return self._bad_chars_message()
        return ''

    def _non_unique_name_warning(self) -> str:
        """Return a warning message if the name is not unique among its siblings.

        Returns:
            See description.
        """
        name = self._name()
        node = tree_util.find_tree_node_by_uuid(self._project_tree, self._new_uuid)
        if node is None or node.parent is None:
            return ''  # No siblings to clash with
        if any(child != node and child.name == name for child in node.parent.children):
            if self._model_str:
                return f'Name "{name}" matches an existing model. Please provide a unique name.'
            return f'Name "{name}" matches an existing simulation. Please provide a unique name.'
        return ''

    def _grid_missing_tops_or_bottoms_warning(self) -> str:
        """Return an error message if the grid lacks 'Cell Top Z' or 'Cell Bottom Z' datasets, else return ''.

        Assumes a UGrid is selected (checked first in _validate).

        Returns:
            See description.
        """
        uuid = self.ui_tree_ugrids.get_selected_item_uuid()
        grid_node = tree_util.find_tree_node_by_uuid(self._project_tree, uuid)
        cell_tops_node = tree_util.first_descendant_with_name(grid_node, 'Cell Top Z')
        cell_bottoms_node = tree_util.first_descendant_with_name(grid_node, 'Cell Bottom Z')
        if not cell_tops_node or not cell_bottoms_node:
            return f'Grid "{grid_node.name}" cannot be used because it is missing "Cell Top Z" or "Cell Bottom Z" datasets.'
        return ''

    def _validate(self) -> str:
        """Check for errors and return an error message if there's a problem, or '' if all is OK.

        Returns:
            See description.
        """
        # No UGrid selected
        if not self.ui_tree_ugrids.get_selected_item_uuid():
            return 'Please select a UGrid to create a new simulation.'

        # Name is empty, contains forbidden characters, or matches a sibling; UGrid missing tops/bottoms datasets
        checks = (self._name_warning, self._non_unique_name_warning, self._grid_missing_tops_or_bottoms_warning)
        for check in checks:
            msg = check()
            if msg:
                return msg
        return ''

    def accept(self):
        """Called when the OK button is clicked. Validates the input, then stores the dialog state in self.data."""
        msg = self._validate()
        if msg:
            message_box.message_with_ok(parent=self, message=msg)
            return

        self.update_data()
        super().accept()


def _find_good_tree_name(node: TreeNode, starting_name: str) -> str:
    """Return a tree name that is unique amongst it's siblings, e.g. 'flow_2', 'sim_2' etc.

    Args:
        node: Tree node of the node whose name we are dealing with.
        starting_name: The name we build on.

    Returns:
        See description.
    """
    if not node:
        return ''
    new_pattern = f'{starting_name}_$*$'
    candidate = starting_name
    i = 2
    done = False
    while not done:
        done = True
        for child in node.parent.children:
            if child == node:
                continue
            if child.name == candidate:
                candidate = new_pattern.replace('_$*$', f'_{i}')
                i += 1
                done = False
                break
    return candidate
