"""DisComponentBase class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy
import os

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util, TreeNode

# 4. Local modules
from xms.mf6.components import dmi_util
from xms.mf6.components.package_component_base import PackageComponentBase
from xms.mf6.data import data_util
from xms.mf6.data.base_file_data import BaseFileData
from xms.mf6.data.dis_data_base import DisDataBase
from xms.mf6.file_io.writer_options import WriterOptions


class DisComponentBase(PackageComponentBase):
    """Dynamic Model Interface (DMI) component for DIS package data."""
    def __init__(self, main_file):
        """Constructs the component.

        Args:
            main_file: The main file associated with this component. Leading and trailing single or double quote
                characters are stripped before it is handed to the base class.
        """
        unquoted_main_file = main_file.strip('\'"')
        super().__init__(unquoted_main_file)

    def delete_event(self, lock_state):
        """Called when the component is deleted.

        Args:
            lock_state (bool): True if the component is locked for editing. Do not change or delete the files if
                locked.

        Returns:
            (tuple): tuple containing:
                - messages (list of tuple of str): List of tuples with the first element of the
                  tuple being the message level (DEBUG, ERROR, WARNING, INFO) and the second element being the message
                  text.
                - action_requests (list of xmsapi.dmi.ActionRequest): List of actions for XMS to perform.
        """
        # This override intentionally does nothing: it replaces the base class behavior and reports
        # no messages and no action requests.
        messages, action_requests = [], []
        return messages, action_requests


def update_dis_packages(from_dis: DisDataBase, ugrid_uuid: str, query: Query):
    """Propagates the grid arrays of one DIS package to the other DIS packages linked to the same UGrid.

    Every UGrid pointer in the project tree that refers to ``ugrid_uuid`` and sits under a MODFLOW 6 model
    component is visited. The DIS package of that model (unless it is ``from_dis`` itself or is locked) gets its
    TOP, BOT, BOTM and IDOMAIN arrays replaced by copies from ``from_dis`` and is written back to disk.

    Args:
        from_dis: The dis package to copy from.
        ugrid_uuid: uuid of the UGrid.
        query: GMS communicator
    """
    source_uuid = from_dis.tree_node.uuid

    # Walk every linked UGrid pointer in the project
    ugrid_links = tree_util.descendants_of_type(query.project_tree, xms_types=['TI_UGRID_PTR'], allow_pointers=True)
    for ugrid_link in ugrid_links:
        if ugrid_link.uuid != ugrid_uuid:
            continue  # Points at a different UGrid

        model_node = ugrid_link.parent
        if model_node.item_typename != 'TI_COMPONENT':
            continue  # Parent is not a component
        if not (model_node.model_name == 'MODFLOW 6' and model_node.unique_name in data_util.model_ftypes()):
            continue  # Parent component is not a MODFLOW 6 model

        target_node = _find_dis_node(model_node)
        # Skip models without a DIS package, the package being copied from, and locked packages
        if not target_node or target_node.uuid == source_uuid or dmi_util.is_locked(target_node, query):
            continue

        # Load the target package and replace its grid arrays
        to_dis = BaseFileData.from_file(target_node.main_file, ftype=target_node.unique_name)
        _copy_griddata_arrays(from_dis, to_dis)

        # Save the modified package
        simulation_dir = os.path.dirname(model_node.parent.main_file)  # Parent of the model is the simulation
        components_dir = os.path.dirname(os.path.dirname(target_node.main_file))
        to_dis.write(WriterOptions(mfsim_dir=simulation_dir, use_open_close=True, dmi_sim_dir=components_dir))


def _copy_griddata_arrays(from_dis, to_dis):
    """Replaces the GRIDDATA arrays of one DIS package with deep copies taken from another.

    Each of TOP, BOT, BOTM and IDOMAIN is first deleted from ``to_dis`` and then added back only when ``from_dis``
    supplies a truthy array of that name.

    Args:
        from_dis: The dis package to copy from.
        to_dis: The dis package to copy to.
    """
    for name in ('TOP', 'BOT', 'BOTM', 'IDOMAIN'):
        to_dis.block('GRIDDATA').delete_array(name)
        source_array = from_dis.block('GRIDDATA').array(name)
        if not source_array:
            continue
        to_dis.block('GRIDDATA').add_array(copy.deepcopy(source_array))


def _find_dis_node(model_node: TreeNode) -> TreeNode | None:
    """Find the DIS package TreeNode of a model.

    The descendants of ``model_node`` are searched for a MODFLOW 6 component whose unique name is 'DIS6', 'DISV6'
    or 'DISU6', tried in that order; the first truthy search result is returned.

    Args:
        model_node: The model tree node (parent of the DIS package).

    Returns:
        (TreeNode | None): The DIS package tree node, or None if not found.
    """
    xms_types = ['TI_COMPONENT']
    for ftype in ('DIS6', 'DISV6', 'DISU6'):
        dis_node = tree_util.descendants_of_type(
            model_node, xms_types=xms_types, unique_name=ftype, model_name='MODFLOW 6', only_first=True
        )
        if dis_node:
            return dis_node
    return None
