"""XmsData class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util
from xms.constraint import Grid, read_grid_from_file
from xms.data_objects.parameters import UGrid as DoGrid
from xms.datasets import dataset_io
from xms.datasets.dataset_metadata import DatasetMetadata

# 4. Local modules
from xms.mf6.components import dmi_util
from xms.mf6.file_io import io_util
from xms.mf6.misc import log_util


class XmsData:
    """Holds XMS data and the Query, caching grids per model.

    Grids are cached by the uuid of the model tree node they are linked to so that repeated requests
    do not have to go back to XMS (which has to save the grid to disk first).
    """
    def __init__(self, query: Query | None):
        """Initializer.

        Args:
            query: Object for communicating with GMS, or None if there is no connection to XMS.
        """
        self._query = query
        self._dogrids: dict[str, DoGrid] = {}  # str is uuid of model tree node
        self._cogrids: dict[str, Grid] = {}  # str is uuid of model tree node

    @property
    def query(self) -> Query | None:
        """Return the query.

        Returns:
            The Query object, or None if one was not provided.
        """
        return self._query

    @query.setter
    def query(self, query: Query | None) -> None:
        """Set the query.

        Args:
            query: Object for communicating with GMS, or None.
        """
        self._query = query

    def get_dogrid(self, model_uuid: str) -> DoGrid | None:
        """Return the data_objects grid associated with the model.

        The cached grid is returned if there is one; otherwise the grid is requested from XMS via the
        query (if there is a query) and cached.

        Args:
            model_uuid: Uuid of the model that the grid is linked to.

        Returns:
            The data_objects grid, or None if model_uuid is empty, nothing is cached and there is no
            query, or the query returned nothing.
        """
        if not model_uuid:
            return None

        # First see if we have it cached. Compare against None instead of relying on truthiness so a
        # cached grid object is always reused.
        dogrid = self._dogrids.get(model_uuid)
        if dogrid is not None or self._query is None:
            return dogrid

        # Get it from the query, which takes time as it has to be saved to disk from XMS
        model_node = tree_util.find_tree_node_by_uuid(self._query.project_tree, model_uuid)
        ugrid_uuid = dmi_util.ugrid_uuid_from_model_node(model_node)
        dogrid = self._query.item_with_uuid(ugrid_uuid)  # Returns a DoGrid
        if dogrid is not None:  # Only cache an actual result so a later call can retry
            self._dogrids[model_uuid] = dogrid
        return dogrid

    def set_dogrid(self, model_uuid: str, dogrid: DoGrid) -> None:
        """Set the data_objects grid associated with the model.

        Args:
            model_uuid: Uuid of the model that the grid is linked to.
            dogrid: The data_objects grid.
        """
        self._dogrids[model_uuid] = dogrid

    def get_cogrid(self, model_uuid: str) -> Grid | None:
        """Return the constrained grid associated with the model.

        The cached grid is returned if there is one; otherwise it is read from the cogrid file of the
        model's data_objects grid (see get_dogrid) and cached.

        Args:
            model_uuid: Uuid of the model that the grid is linked to.

        Returns:
            The constrained grid, or None if model_uuid is empty or no data_objects grid is available.
        """
        if not model_uuid:
            return None

        # First see if we have it cached (explicit None check, not truthiness)
        cogrid = self._cogrids.get(model_uuid)
        if cogrid is not None:
            return cogrid

        # Get the dogrid, which will either be cached, or obtained from XMS via the query
        dogrid = self.get_dogrid(model_uuid)
        if dogrid is None:
            return None

        cogrid = read_grid_from_file(dogrid.cogrid_file)
        if cogrid is not None:
            self._cogrids[model_uuid] = cogrid
        return cogrid

    def set_cogrid(self, model_uuid: str, cogrid: Grid) -> None:
        """Set the cogrid associated with the model.

        Args:
            model_uuid: Uuid of the model that the grid is linked to.
            cogrid: The constrained grid.
        """
        self._cogrids[model_uuid] = cogrid


def update_ugrid_values(ugrid_uuid, tops, bottoms, idomain, query):
    """Updates the UGrid top and bottom elevations from the MODFLOW array data.

    Each non-empty array is written to its own temporary file. The files are listed in a json
    metadata file ('cell.elevs.json' in the current working directory). If a query is given, that
    json file is passed to XMS.

    Args:
        ugrid_uuid (str): uuid of the UGrid. If empty, an error is logged and nothing is done.
        tops (list of float, optional): Top elevations. Skipped if None or empty.
        bottoms (list of float, optional): Bottom elevations. Skipped if None or empty.
        idomain (list of int or float, optional): IDOMAIN array. Every value is converted to float.
            Skipped if None or empty.
        query (xms.api.dmi.Query, optional): Object for communicating with GMS. If None, the files
            are still written but nothing is sent to XMS.
    """
    if not ugrid_uuid:
        log_util.get_logger().error('No UGrid. Updating UGrid values aborted.')
        return

    # Set the UGrid UUID of the geometry we want to edit and the list of dataset files to read.
    metadata = DatasetMetadata()
    metadata.geom_uuid = ugrid_uuid
    metadata.value_files = []

    if _has_values(tops):
        _add_value_file(metadata, 'top', tops)
    if _has_values(bottoms):
        _add_value_file(metadata, 'bottom', bottoms)
    if _has_values(idomain):
        # Convert every value rather than inspecting only the first, so mixed int/float input works
        _add_value_file(metadata, 'model_on_off', [float(x) for x in idomain])

    # Write the json metadata file
    json_file = os.path.join(os.getcwd(), 'cell.elevs.json')
    dataset_io.write_dset_metadata_to_json(json_file, metadata)

    # Add the json filename to the Query
    if query is not None:
        query.edit_ugrid_top_bottom_elevations(json_file)


def _has_values(values) -> bool:
    """Return True if values is not None and not empty (works for lists and numpy arrays alike)."""
    return values is not None and len(values) > 0


def _add_value_file(metadata, value_type: str, values) -> None:
    """Write values to a new temp file and register it in metadata.value_files under value_type."""
    filename = io_util.get_temp_filename()
    dataset_io.write_dimension_values(filename, values)
    metadata.value_files.append({'type': value_type, 'file': filename})
