"""This handles saving and loading transport constituents."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import importlib.metadata
import os

# 2. Third party modules
from adhparam.constituent_properties import ConstituentProperties
from adhparam.time_series import TimeSeries
import h5py
import numpy
import pandas
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase

# 4. Local modules
from xms.adh.data import param_h5_io


class TransportConstituentsIO(XarrayBase):
    """A class that handles saving and loading transport constituent data.

    The data is persisted in the component's main file: the constituent properties, the domain UUID,
    the user defined constituents, and a set of time series keyed by integer id.
    """
    def __init__(self, main_file):
        """Initializes the data class.

        Missing info attrs are filled in from default_data(). If main_file already exists on disk, the
        constituent properties, domain UUID, time series, and user constituents are loaded from it.

        Args:
            main_file: The main file associated with this component.
        """
        super().__init__(main_file)
        self.main_file = main_file
        file_type = 'ADH_TRANSPORT_CONSTITUENTS'

        # Only fill in attrs that are missing so values already stored in the file are preserved.
        attrs = self.info.attrs
        for key, value in self.default_data().items():
            if key not in attrs:
                attrs[key] = value
        attrs['FILE_TYPE'] = file_type

        # In-memory defaults, used as-is when there is no file to read from yet.
        self.param_control = ConstituentProperties()
        self.user_constituents = pandas.DataFrame(data=[], columns=['ID', 'NAME', 'CONC']).to_xarray()
        self.series_ids = xr.Dataset()
        # Two time series always exist, under these fixed ids.
        self.short_wave_series_id = 1
        self.dew_point_series_id = 2
        self.time_series = {self.short_wave_series_id: TimeSeries(), self.dew_point_series_id: TimeSeries()}
        self.domain_uuid = ''

        if os.path.exists(main_file):
            param_h5_io.read_from_h5_file(main_file, self.param_control, file_type)
            uuid_info = xr.load_dataset(self.main_file, group='uuids')
            self.domain_uuid = uuid_info.attrs['domain_uuid']
            self.read_time_series_from_h5(main_file)
            self.user_constituents = xr.load_dataset(main_file, group='user_constituents')

    @staticmethod
    def default_data():
        """Gets the default data for this class.

        Returns:
            A dictionary of default values that will go into the info dataset attrs.
        """
        version = importlib.metadata.version('xmsadh')
        return {
            'FILE_TYPE': 'ADH_TRANSPORT_CONSTITUENTS',
            'VERSION': version,
            'next_constituent_id': 4,  # constituents before this are: salinity - 1, temperature - 2, vorticity - 3
        }

    def read_time_series_from_h5(self, filename):
        """Reads the time series from the file.

        The list of stored ids is read from the '/time_series/ids' group, then each series with a valid
        (non-zero, positive, non-NaN) id is read from its own '/time_series/<id>/' group. Series that
        are not yet in self.time_series are created before being read into.

        Args:
            filename (str): The file to read time series data from.
        """
        self.series_ids = xr.load_dataset(filename, group='/time_series/ids')
        for t_id in self.series_ids['ids']:
            # numpy.isnan() returns a NumPy/xarray boolean, never the Python ``False`` singleton, so it
            # must be tested by truth value; an ``is False`` identity test would reject every id.
            # The truthiness and ``> 0`` checks stay first so placeholder ids short-circuit before isnan.
            if not t_id or not (t_id > 0) or numpy.isnan(t_id):
                continue
            int_id = int(t_id)
            if int_id not in self.time_series:
                self.time_series[int_id] = TimeSeries()
            param_h5_io.read_params_recursive(
                filename, group_name=f'/time_series/{int_id}/', param_class=self.time_series[int_id]
            )

    def _remove_group(self, group_name):
        """Deletes a group from the main file if it is present.

        Used so that a dataset can be rewritten from scratch instead of appended onto stale contents.

        Args:
            group_name (str): Path of the group to delete from the main file.
        """
        with h5py.File(self.main_file, 'a') as f:
            # A missing group is expected the first time the file is written; that is not an error.
            if group_name in f:
                del f[group_name]

    def commit(self):
        """Saves the component data to a netCDF file."""
        self.info.attrs['VERSION'] = importlib.metadata.version('xmsadh')
        super().commit()
        param_h5_io.write_params_recursive(self.main_file, '/', self.param_control)

        uuid_info = xr.Dataset()
        uuid_info.attrs['domain_uuid'] = self.domain_uuid
        uuid_info.to_netcdf(self.main_file, group='uuids', mode='a')

        # Remove the previous contents before appending the current user constituents.
        self._remove_group('user_constituents')
        self.user_constituents.to_netcdf(self.main_file, group='user_constituents', mode='a')

        # One row per time series id; a single None row is written when there are no time series.
        id_rows = [[series_id] for series_id in self.time_series] or [[None]]
        self.series_ids = pandas.DataFrame(id_rows, columns=['ids']).to_xarray()
        self._remove_group('/time_series/ids')
        self.series_ids.to_netcdf(self.main_file, group='/time_series/ids', mode='a')

        for t_id, series in self.time_series.items():
            if t_id and series:
                param_h5_io.write_params_recursive(
                    self.main_file, group_name=f'/time_series/{t_id}/', param_class=series
                )
