"""This module for writes datasets for CMS-Flow."""

# 1. Standard Python modules
import os
from typing import Any

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util
from xms.datasets.dataset_writer import DatasetWriter

# 4. Local modules
from xms.cmsflow.data.simulation_data import SimulationData


class CMSFlowDatasetsExporter:
    """Exporter for the CMS-Flow _datasets.h5 file.

    """
    def __init__(self):
        """Constructor for class."""
        # uuid -> (time_units, num_components, times, values)
        self.datasets: dict[str, tuple[Any, Any, Any, Any]] = {}
        self._pe_tree = None

    def _add_dataset(self, query: Query, uuids: str | list[str], invert: bool = False):
        """Query SMS for a dataset by the datasets uuid and add to map using dataset name as key if it exists.

        Args:
            query: Query object for communicating with SMS.
            uuids: UUIDs of the datasets to add.
            invert: Whether to invert the dataset values before writing. SMS and CMS-Flow disagree in some cases whether
                a dataset should contain depths or elevations. This parameter may be used to convert them.
        """
        if not isinstance(uuids, list):
            uuids = [uuids]

        for dset_uuid in uuids:
            if not isinstance(dset_uuid, str):
                continue
            dataset = query.item_with_uuid(dset_uuid)
            if not dataset:
                continue

            dset_name = dataset.name
            if self._pe_tree:
                # The reader has a dataset name on it, but it isn't always reliable. If you duplicate a dataset in SMS,
                # the name for it on the reader is still the name of the original dataset. The tree node's name seems to
                # be more reliable, so we try to use that instead.
                dset_node = tree_util.find_tree_node_by_uuid(self._pe_tree, dset_uuid)
                # The dataset exists, and all datasets have tree nodes, so its tree node must exist too.
                dset_name = dset_node.name

            values = np.array(dataset.values, dtype=float)

            if invert:
                mask = values != -999
                np.multiply(-1, values, out=values, where=mask)

            self.datasets[dset_name] = (dataset.time_units, dataset.num_components, dataset.times, values)

    def export_datasets(self, query=None):
        """Write the CMS-Flow _datasets.h5 file.

        Args:
            query (Query): The interprocess communication object. If not supplied, will create connection.
        """
        query = query if query else Query()
        proj_name = os.path.splitext(os.path.basename(query.xms_project_path))[0]
        self._pe_tree = query.project_tree

        # get a dump of the Model Control widgets.
        sim_item = tree_util.find_tree_node_by_uuid(self._pe_tree, query.current_item_uuid())
        sim_comp = query.item_with_uuid(sim_item.uuid, model_name='CMS-Flow', unique_name='Simulation_Component')
        sim_data = SimulationData(sim_comp.main_file)

        # get the Quadtree geometry.
        quad_item = tree_util.descendants_of_type(
            sim_item, xms_types=['TI_UGRID_PTR', 'TI_CGRID2D_PTR'], recurse=False, allow_pointers=True, only_first=True
        )
        if not quad_item:
            return [('ERROR', 'Could not find the geometry for the simulation.')], []
        quad_uuid = quad_item.uuid

        # Build up a map of unique datasets to be written to the file.
        self._add_dataset(query, sim_data.flow.attrs['BOTTOM_ROUGHNESS_DSET'])

        if sim_data.salinity.attrs['SALINITY_CONCENTRATION'] == 'Spatially varied':
            self._add_dataset(query, sim_data.salinity.attrs['SALINITY_INITIAL_CONCENTRATION'])
        if sim_data.salinity.attrs['INITIAL_TEMPERATURE_TYPE'] == 'Spatially varied':
            self._add_dataset(query, sim_data.salinity.attrs['INITIAL_TEMPERATURE_DATASET'])

        if sim_data.wave.attrs['WAVE_INFO'] == 'Single wave condition':
            self._add_dataset(query, sim_data.wave.attrs['WAVE_HEIGHT'])
            self._add_dataset(query, sim_data.wave.attrs['PEAK_PERIOD'])
            self._add_dataset(query, sim_data.wave.attrs['MEAN_WAVE_DIR'])
            self._add_dataset(query, sim_data.wave.attrs['WAVE_BREAKING'])
            self._add_dataset(query, sim_data.wave.attrs['WAVE_RADIATION'])
        # self.add_dataset(sim_data.general.attrs['dsetSurfaceRoller'])  # experimental

        if sim_data.sediment.attrs['CALCULATE_SEDIMENT'] == 1:
            if sim_data.sediment.attrs['USE_HARD_BOTTOM'] == 1:
                self._add_dataset(query, sim_data.sediment.attrs['HARD_BOTTOM'], invert=True)

            if sim_data.sediment.attrs['FORMULATION_UNITS'] == 'Nonequilibrium total load':
                if sim_data.sediment.attrs['ENABLE_SIMPLIFIED_MULTI_GRAIN_SIZE'] == 1:
                    if sim_data.sediment.attrs['BED_COMPOSITION_INPUT'] == 'D50 Sigma':
                        self._add_dataset(query, sim_data.sediment.attrs['MULTI_D50'])
                    elif sim_data.sediment.attrs['BED_COMPOSITION_INPUT'] == 'D16 D50 D84':
                        self._add_dataset(query, sim_data.sediment.attrs['MULTI_D16'])
                        self._add_dataset(query, sim_data.sediment.attrs['MULTI_D50'])
                        self._add_dataset(query, sim_data.sediment.attrs['MULTI_D84'])
                    elif sim_data.sediment.attrs['BED_COMPOSITION_INPUT'] == 'D35 D50 D90':
                        self._add_dataset(query, sim_data.sediment.attrs['MULTI_D35'])
                        self._add_dataset(query, sim_data.sediment.attrs['MULTI_D50'])
                        self._add_dataset(query, sim_data.sediment.attrs['MULTI_D90'])
                else:
                    # table datasets
                    self._add_dataset(query, sim_data.bed_layer_table.layer_thickness.data.tolist())
                    self._add_dataset(query, sim_data.bed_layer_table.d05.data.tolist())
                    self._add_dataset(query, sim_data.bed_layer_table.d10.data.tolist())
                    self._add_dataset(query, sim_data.bed_layer_table.d16.data.tolist())
                    self._add_dataset(query, sim_data.bed_layer_table.d20.data.tolist())
                    self._add_dataset(query, sim_data.bed_layer_table.d30.data.tolist())
                    self._add_dataset(query, sim_data.bed_layer_table.d35.data.tolist())
                    self._add_dataset(query, sim_data.bed_layer_table.d50.data.tolist())
                    self._add_dataset(query, sim_data.bed_layer_table.d65.data.tolist())
                    self._add_dataset(query, sim_data.bed_layer_table.d84.data.tolist())
                    self._add_dataset(query, sim_data.bed_layer_table.d90.data.tolist())
                    self._add_dataset(query, sim_data.bed_layer_table.d95.data.tolist())

        self._write_datasets_file('_datasets', quad_uuid, proj_name)

        # write the dredge and placement dataset files
        if sim_data.dredge.attrs['ENABLE_DREDGE'] == 1:
            self.datasets = {}
            if quad_uuid:
                self._add_dataset(query, sim_data.dredge.attrs['DREDGE_DATASET'])
            self._write_datasets_file('_DredgeArea', quad_uuid, proj_name)

            self.datasets = {}
            if quad_uuid:
                if sim_data.dredge_placement.attrs['DEFINE_PLACEMENT_1'] == 1:
                    self._add_dataset(query, sim_data.dredge_placement.attrs['PLACEMENT_1_DATASET'])
                if sim_data.dredge_placement.attrs['DEFINE_PLACEMENT_2'] == 1:
                    self._add_dataset(query, sim_data.dredge_placement.attrs['PLACEMENT_2_DATASET'])
                if sim_data.dredge_placement.attrs['DEFINE_PLACEMENT_3'] == 1:
                    self._add_dataset(query, sim_data.dredge_placement.attrs['PLACEMENT_3_DATASET'])
            self._write_datasets_file('_PlacementArea', quad_uuid, proj_name)

    def _write_datasets_file(self, file_name, quad_uuid, proj_name):
        """Writes the dataset file.

        Args:
            file_name (str): The name of the file to write.
            quad_uuid (str): The UUID of the quadtree geometry.
            proj_name (str): The name of the project this belongs to.
        """
        if quad_uuid:
            # Write the unique datasets to the file.
            overwrite = True  # overwrite the file with the first dataset
            for dset_name, (time_units, num_components, times, values) in self.datasets.items():
                writer = DatasetWriter(
                    f'{proj_name}{file_name}.h5',
                    name=dset_name,
                    geom_uuid=quad_uuid,
                    overwrite=overwrite,
                    null_value=-999.0,
                    time_units=time_units,
                    location='cells',
                    num_components=num_components
                )
                overwrite = False
                writer.write_xmdf_dataset(times=times, data=values)
