"""CbcDatasetCreatorBase class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
from pathlib import Path

# 2. Third party modules

# 3. Aquaveo modules
from xms.api._xmsapi.dmi import SimulationItem
from xms.api.tree import tree_util, TreeNode
from xms.datasets.dataset_writer import DatasetWriter
from xms.guipy.testing import testing_tools

# 4. Local modules
from xms.mf6.components import dmi_util
from xms.mf6.data.grid_info import DisEnum
from xms.mf6.file_io import io_util, tdis_reader
from xms.mf6.gui import units_util


class CbcDatasetCreatorBase:
    """Base class for creating datasets from a MODFLOW 6 cell-by-cell budget (cbc) file.

    Holds the state shared by dataset creators: the related project tree nodes (cbc item, DIS package,
    simulation), the grid cell count, the dataset time units, the UGrid uuid, the folder containing the
    cbc file, and the project tree folder name. Call ``_base_setup()`` to populate that state.
    """
    def __init__(self, cbc_filepath: Path | str, model_node: TreeNode, query):
        """Initializes the class.

        Args:
            cbc_filepath (Path | str): File path of cbc file.
            model_node (TreeNode): Model tree node.
            query (xmsapi.dmi.Query): Object for communicating with GMS
        """
        self._cbc_filepath = cbc_filepath
        self._model_node = model_node
        self._query = query

        # Populated by _base_setup().
        self._cbc_node = None
        self._dis_node = None
        self._sim_node = None
        self._cell_count = 0
        self._ugrid_uuid = ''
        self._time_units = ''
        self._cbc_folder = ''
        self._tree_folder = ''

    def _base_setup(self):
        """Populates the member variables that describe the model, grid, and output locations.

        The order matters: the cell count needs the DIS node, and the time units and tree folder name
        need the simulation node.

        Raises:
            RuntimeError: If the DIS package, TDIS package, or UGrid cannot be found.
        """
        self._find_cbc_node()
        self._find_dis_node()
        self._find_sim_node()
        self._find_cell_count()  # Requires self._dis_node
        self._find_time_units()  # Requires self._sim_node
        self._find_ugrid_uuid()
        self._find_cbc_folder()
        self._find_tree_folder()  # Requires self._sim_node and self._model_node

    def _find_cbc_node(self) -> None:
        """Finds the tree node of the query's current item and stores it in a member variable."""
        self._cbc_node = tree_util.find_tree_node_by_uuid(self._query.project_tree, self._query.current_item_uuid())

    def _find_dis_node(self) -> None:
        """Finds the DIS* package (DIS6, DISV6 or DISU6) under the model node."""
        self._dis_node = tree_util.child_with_unique_name_in_list(self._model_node, ['DIS6', 'DISV6', 'DISU6'])

    def _find_sim_node(self) -> None:
        """Finds the simulation tree item."""
        self._sim_node = tree_util.ancestor_of_type(self._model_node, SimulationItem)

    def _find_cell_count(self) -> None:
        """Finds the number of cells in the grid and stores it in a member variable."""
        self._cell_count = self._get_cell_count()

    def _find_time_units(self) -> None:
        """Finds the dataset time units from the TDIS package and stores them in a member variable."""
        self._time_units = self._get_time_units()

    def _find_ugrid_uuid(self) -> None:
        """Finds the UGrid uuid and stores it in a member variable.

        Raises:
            RuntimeError: If no UGrid uuid is found for the model.
        """
        self._ugrid_uuid = dmi_util.ugrid_uuid_from_model_node(self._model_node)
        if not self._ugrid_uuid:
            raise RuntimeError('Could not find UGrid. No datasets created.')

    def _find_cbc_folder(self) -> None:
        """Finds the directory of the .cbc file."""
        self._cbc_folder = find_cbc_folder(self._cbc_filepath)

    def _find_tree_folder(self) -> None:
        """Finds the name of the project tree folder for the datasets and stores it in a member variable."""
        self._tree_folder = self._tree_folder_name()

    def _get_cell_count(self) -> int:
        """Returns the number of cells in the grid.

        Returns:
            (int): See description.

        Raises:
            RuntimeError: If the DIS node was not found, or its file does not describe a DIS, DISV or DISU grid.
        """
        # _find_dis_node() may not have found a package (presumably the lookup returns None then);
        # report that with the same error as an invalid package instead of an AttributeError on None.
        if self._dis_node is None:
            raise RuntimeError('Unable to find valid DIS package. Aborting.')
        grid_info = io_util.grid_info_from_dis_file(self._dis_node.main_file)
        if not grid_info or grid_info.dis_enum not in [DisEnum.DIS, DisEnum.DISV, DisEnum.DISU]:
            raise RuntimeError('Unable to find valid DIS package. Aborting.')
        return grid_info.cell_count()

    def _get_time_units(self):
        """Returns the time units as defined in the TDIS file, converted to dataset time units.

        Returns:
            See description.

        Raises:
            RuntimeError: If no TDIS6 package is found under the simulation node.
        """
        tdis_node = tree_util.child_with_unique_name_in_list(self._sim_node, ['TDIS6'])
        # Guard against a missing TDIS package rather than failing with an AttributeError on None.
        if tdis_node is None:
            raise RuntimeError('Unable to find TDIS package. Aborting.')
        time_units, _ = tdis_reader.time_units_and_start_date_time(tdis_node.main_file)
        return units_util.dataset_time_units_from_tdis_time_units(time_units)

    def _init_dataset_writer(self, name: str, num_components: int) -> DatasetWriter:
        """Initializes an HDF5 dataset file named '<name>.h5' in the same folder as the cbc file.

        Args:
            name (str): Name of the dataset.
            num_components (int): 3 for vector, 1 for scalar.

        Returns:
            (DatasetWriter): The dataset writer.
        """
        filename = os.path.join(self._cbc_folder, f'{name}.h5')
        # NOTE(review): uuid comes from xms.guipy.testing, presumably so tests can make it deterministic.
        dset_uuid = testing_tools.new_uuid()
        return DatasetWriter(
            h5_filename=filename,
            name=name,
            geom_uuid=self._ugrid_uuid,
            num_components=num_components,
            time_units=self._time_units,
            location='cells',
            use_activity_as_null=True,
            dset_uuid=dset_uuid
        )

    def _tree_folder_name(self) -> str:
        """Returns the name of the folder where the datasets will be created.

        The name has the form '<simulation name> (MODFLOW 6) GMS/<model name>'.
        """
        return f'{self._sim_node.name} (MODFLOW 6) GMS/{self._model_node.name}'


def find_cbc_folder(cbc_filepath: Path | str) -> str:
    """Return the cbc folder given the file path.

    This exists so it can be mocked in tests.

    Returns:
        See description.
    """
    return os.path.dirname(cbc_filepath)
