"""Stores the data for the boundary conditions."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import importlib.metadata
import os

# 2. Third party modules
from adhparam.boundary_conditions import BoundaryConditions
import h5py
import numpy
import pandas as pd
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase

# 4. Local modules
from xms.adh.data.json_export import (filtered_info_attrs, include_dataframe_if_not_empty, include_if_xarray_not_empty,
                                      make_json_serializable, parameterized_to_dict)
from xms.adh.data.lazy_time_series import LazyTimeSeries
from xms.adh.data.param_h5_io import read_params_recursive, write_params_recursive


class BcIO(XarrayBase):
    """A class that handles saving and loading boundary condition data.

    Tabular data (component id maps, snapping, flux, transport/sediment assignments, sediment
    diversions, NB OUT links) is held in pandas DataFrames and stored as separate netCDF groups
    of the main file. The ``BoundaryConditions`` parameters are written with
    ``write_params_recursive`` and loaded lazily through :attr:`bc`.
    """

    def __init__(self, main_file):
        """Initializes the data class.

        Args:
            main_file: The main file associated with this component. When it is an existing file
                its tables are read immediately; the boundary conditions are loaded lazily.
        """
        super().__init__(main_file)
        self.main_file = main_file
        attrs = self.info.attrs
        defaults = self.default_data()
        file_exists = os.path.isfile(main_file)
        if not file_exists:
            # No file yet: start from a fresh info dataset seeded with the defaults, and point
            # `attrs` at it so the fill-in below updates the dataset that is actually kept.
            self._info = xr.Dataset(attrs=defaults)
            attrs = self._info.attrs
        self.arc_display = xr.Dataset()
        self.pt_display = xr.Dataset()
        self.series_ids = xr.Dataset()
        self.comp_id_columns = ['COMP_ID', 'FRICTION_ID', 'FLUX_ID', 'TRANSPORT_ID', 'DIVERSION_ID', 'BC_ID']
        self.comp_id_to_ids = pd.DataFrame([], columns=self.comp_id_columns)
        self.snap_types = pd.DataFrame([], columns=['ID', 'SNAP'])
        self.flux = pd.DataFrame([], columns=['ID', 'IS_FLUX', 'EDGESTRING', 'MIDSTRING'])
        self.transport_assignments = pd.DataFrame(
            [], columns=['TRAN_ID', 'CONSTITUENT_ID', 'NAME', 'TYPE', 'SERIES_ID', 'SNAPPING']
        )
        self.uses_transport = pd.DataFrame([], columns=['TRAN_ID', 'USES_TRANSPORT'])
        self.sediment_assignments = pd.DataFrame(
            [], columns=['TRAN_ID', 'CONSTITUENT_ID', 'NAME', 'TYPE', 'SERIES_ID', 'SNAPPING']
        )
        self.uses_sediment = pd.DataFrame([], columns=['TRAN_ID', 'USES_SEDIMENT'])
        self.sediment_diversions = pd.DataFrame([], columns=['DIV_ID', 'SNAPPING', 'TOP', 'BOTTOM', 'BOTTOM_MAIN'])
        self.nb_out = pd.DataFrame([], columns=['BC_ID', 'OUT_COMP_ID', 'IN_COMP_ID', 'SERIES_ID'])

        # Older files may be missing newer attributes; fill them in from the defaults.
        for key, value in defaults.items():
            if key not in attrs:
                attrs[key] = value
        attrs['FILE_TYPE'] = 'ADH_BC_DATA'
        self._bc = None
        self._main_file_read = False
        if file_exists:
            self._main_file_read = True
            self.read_from_h5_file(main_file, 'ADH_BC_DATA')

    @property
    def bc(self) -> BoundaryConditions:
        """Lazily loads the BoundaryConditions object.

        Returns:
            The boundary conditions object. It is read from the main file when one was read at
            construction time, otherwise a default ``BoundaryConditions()`` is created.
        """
        if self._bc is None:
            if self._main_file_read:
                self._read_bc_data()
            else:
                self._bc = BoundaryConditions()
        return self._bc

    @staticmethod
    def _load_dataframe(filename, group):
        """Loads one netCDF group from a file and returns it as a DataFrame.

        Args:
            filename (str): The file to read.
            group (str): The netCDF group path to load.

        Returns:
            pd.DataFrame: The group's contents.
        """
        return xr.load_dataset(filename, group=group).to_dataframe()

    def read_from_h5_file(self, filename, file_type):
        """Reads this class from an H5 file.

        Args:
            filename (str): file name
            file_type (str): The type of file as written in the netCDF file.

        Raises:
            IOError: If the ``info`` group has no ``FILE_TYPE`` attribute, or it differs from ``file_type``.
        """
        # The info group is loaded only to validate the file type.
        info = xr.load_dataset(filename, group='info')
        if 'FILE_TYPE' not in info.attrs:
            raise IOError(f'File attributes do not include file_type in file : {filename}')
        ftype = info.attrs['FILE_TYPE']
        if ftype != file_type:
            raise IOError(f'Unrecognized file_type "{ftype}" attribute in file: {filename}')

        self.read_name_and_display_from_h5(filename)
        self.read_time_series_from_h5(filename)
        self.snap_types = self._load_dataframe(filename, '/snap_types')
        self.flux = self._load_dataframe(filename, '/flux')
        self.nb_out = self._load_dataframe(filename, '/nb_out')
        self.read_transport_assignments_from_h5(filename)
        self.read_sediment_assignments_from_h5(filename)
        self.read_sediment_diversions_from_h5(filename)
        self.comp_id_to_ids = self._load_dataframe(filename, '/comp_id_to_ids')

    def read_name_and_display_from_h5(self, filename):
        """Reads the arc and point display datasets from a file.

        Args:
            filename (str): The name of the file to read from.
        """
        # name, red, green, blue, alpha, style, size
        self.arc_display = xr.load_dataset(filename, group='/display/arc')
        self.pt_display = xr.load_dataset(filename, group='/display/point')

    def read_transport_assignments_from_h5(self, filename):
        """Reads the transport assignments and uses-transport table from the h5 file.

        Args:
            filename (str): The h5 file to read.
        """
        self.transport_assignments = self._load_dataframe(filename, '/transport_constituent_assignments')
        self.uses_transport = self._load_dataframe(filename, '/uses_transport')

    def read_sediment_assignments_from_h5(self, filename):
        """Reads the sediment assignments and uses-sediment table from the h5 file.

        Args:
            filename (str): The h5 file to read.
        """
        self.sediment_assignments = self._load_dataframe(filename, '/sediment_constituent_assignments')
        self.uses_sediment = self._load_dataframe(filename, '/uses_sediment')

    def read_sediment_diversions_from_h5(self, filename):
        """Reads the sediment diversions from the h5 file.

        Args:
            filename (str): The h5 file to read.
        """
        self.sediment_diversions = self._load_dataframe(filename, '/sediment_diversions')

    def read_time_series_from_h5(self, filename):
        """Reads the time series id table from the netCDF file.

        The series themselves are loaded later (lazily) when :attr:`bc` is first accessed.

        Args:
            filename (str): The name of the netCDF file to read time series from.
        """
        self.series_ids = xr.load_dataset(filename, group='/time_series/ids')

    def default_data(self):
        """Gets the default data for this class.

        Returns:
            A dictionary of default values that will go into the info dataset attrs.
        """
        version = importlib.metadata.version('xmsadh')
        return {
            'FILE_TYPE': 'ADH_BC_DATA',
            'VERSION': version,
            'arc_display_uuid': '',
            'point_display_uuid': '',
            'transport_uuid': '',
            'sediment_uuid': '',
            'cov_uuid': '',
            'next_comp_id': 0,
            'next_flux_id': 0,
            'next_friction_id': 0,
            'next_transport_id': 0,
            'next_diversion_id': 0,
            'next_bc_id': 0
        }

    @property
    def arc_display_uuid(self) -> str:
        """Arc display UUID getter.

        Returns:
            The arc display UUID.
        """
        return self.info.attrs["arc_display_uuid"]

    @arc_display_uuid.setter
    def arc_display_uuid(self, value: str):
        """Arc display UUID setter.

        Args:
            value (str): The arc display UUID.
        """
        self.info.attrs["arc_display_uuid"] = value

    @property
    def point_display_uuid(self) -> str:
        """Point display UUID getter.

        Returns:
            The point display UUID.
        """
        return self.info.attrs["point_display_uuid"]

    @point_display_uuid.setter
    def point_display_uuid(self, value: str):
        """Point display UUID setter.

        Args:
            value (str): The point display UUID.
        """
        self.info.attrs["point_display_uuid"] = value

    @property
    def transport_uuid(self) -> str:
        """Transport UUID getter.

        Returns:
            The transport UUID.
        """
        return self.info.attrs["transport_uuid"]

    @transport_uuid.setter
    def transport_uuid(self, value: str):
        """Transport UUID setter.

        Args:
            value (str): The transport UUID.
        """
        self.info.attrs["transport_uuid"] = value

    @property
    def sediment_uuid(self) -> str:
        """Sediment UUID getter.

        Returns:
            The sediment UUID.
        """
        return self.info.attrs["sediment_uuid"]

    @sediment_uuid.setter
    def sediment_uuid(self, value: str):
        """Sediment UUID setter.

        Args:
            value (str): The sediment UUID.
        """
        self.info.attrs["sediment_uuid"] = value

    @property
    def cov_uuid(self) -> str:
        """Coverage UUID getter.

        Returns:
            The coverage UUID.
        """
        return self.info.attrs["cov_uuid"]

    @cov_uuid.setter
    def cov_uuid(self, value: str):
        """Coverage UUID setter.

        Args:
            value (str): The new UUID value to be set for the `cov_uuid`.
        """
        self.info.attrs["cov_uuid"] = value

    def _write_dataframe(self, data_frame, group):
        """Writes a DataFrame to a netCDF group of the main file.

        The ``index`` coordinate is replaced with 0..n-1 before writing.

        Args:
            data_frame (pd.DataFrame): The table to write.
            group (str): The netCDF group path to write (the file is opened in append mode).
        """
        data_set = data_frame.to_xarray()
        data_set = data_set.assign_coords({'index': list(range(len(data_set.index)))})
        data_set.to_netcdf(self.main_file, group=group, mode='a')

    def commit(self):
        """Saves data to the main file."""
        self.info.attrs['VERSION'] = importlib.metadata.version('xmsadh')
        super().commit()
        write_params_recursive(self.main_file, group_name='/', param_class=self.bc)
        self.arc_display.to_netcdf(self.main_file, group='/display/arc', mode='a')
        self.pt_display.to_netcdf(self.main_file, group='/display/point', mode='a')

        # Remove previously written tables so each is rewritten from scratch below.
        self.clean(
            [
                '/comp_id_to_ids', '/snap_types', '/flux', '/time_series/ids', '/transport_constituent_assignments',
                '/sediment_constituent_assignments', '/uses_transport', '/uses_sediment', '/sediment_diversions',
                '/nb_out'
            ]
        )

        self._write_dataframe(self.comp_id_to_ids, '/comp_id_to_ids')
        self._write_dataframe(self.snap_types, '/snap_types')
        self._write_dataframe(self.flux, '/flux')
        self._write_dataframe(self.nb_out, '/nb_out')

        self._commit_transport()
        self._commit_time_series()

    def _commit_transport(self):
        """Commits transport related datasets.

        Writes the transport and sediment constituent assignments, the sediment diversions, and
        the uses-transport / uses-sediment tables, each to its own group.
        """
        self._write_dataframe(self.transport_assignments, '/transport_constituent_assignments')
        self._write_dataframe(self.sediment_assignments, '/sediment_constituent_assignments')
        self._write_dataframe(self.sediment_diversions, '/sediment_diversions')
        self._write_dataframe(self.uses_transport, '/uses_transport')
        self._write_dataframe(self.uses_sediment, '/uses_sediment')

    def _commit_time_series(self):
        """Commits all of the time series.

        Deletes ``/time_series/<id>`` groups whose id is no longer a key of ``self.bc.time_series``,
        rewrites the ``/time_series/ids`` table, then writes each series whose id and value are truthy.
        """
        series_ids = list(self.bc.time_series.keys())
        live_ids = set(series_ids)

        # Clean out unused time series
        with h5py.File(self.main_file, 'a') as f:
            if '/time_series' in f:
                series_group = f['/time_series']
                # Snapshot the member names: deleting links while iterating a live h5py group view
                # can skip members on some h5py versions.
                for group_name in list(series_group.keys()):
                    if group_name == 'ids':
                        continue
                    if int(group_name) not in live_ids:
                        try:
                            # Delete by the actual member name rather than a rebuilt path.
                            del series_group[group_name]
                        except KeyError:
                            pass  # Already gone; nothing to clean.

        # Write the time series ids (one id per row; an empty list yields an empty 'ids' column).
        self.series_ids = pd.DataFrame([[series_id] for series_id in series_ids], columns=['ids']).to_xarray()
        self.series_ids.to_netcdf(self.main_file, group='/time_series/ids', mode='a')

        # Write all of the time series
        for series_id, series in self.bc.time_series.items():
            if series_id and series:
                write_params_recursive(self.main_file, group_name=f'/time_series/{series_id}/', param_class=series)

    def _increment_id(self, attr_name):
        """Increments the integer counter stored in ``info.attrs[attr_name]``.

        Args:
            attr_name (str): Name of the counter attribute, e.g. ``'next_comp_id'``.

        Returns:
            The incremented counter as an integer.
        """
        self.info.attrs[attr_name] += 1
        return int(self.info.attrs[attr_name])

    def get_next_comp_id(self):
        """Gets the next component id by incrementing the variable.

        Returns:
            The next component id as an integer.
        """
        return self._increment_id('next_comp_id')

    def get_next_flux_id(self):
        """Gets the next flux id by incrementing the variable.

        Returns:
            The next flux id as an integer.
        """
        return self._increment_id('next_flux_id')

    def get_next_friction_id(self):
        """Gets the next friction id by incrementing the variable.

        Returns:
            The next friction id as an integer.
        """
        return self._increment_id('next_friction_id')

    def get_next_transport_id(self):
        """Gets the next transport id by incrementing the variable.

        Returns:
            The next transport id as an integer.
        """
        return self._increment_id('next_transport_id')

    def get_next_diversion_id(self):
        """Gets the next sediment diversion id by incrementing the variable.

        Returns:
            The next sediment diversion id as an integer.
        """
        return self._increment_id('next_diversion_id')

    def get_next_bc_id(self):
        """Gets the next boundary condition id by incrementing the variable.

        Returns:
            The next boundary condition id as an integer.
        """
        return self._increment_id('next_bc_id')

    def clean(self, group_names):
        """Cleans out groups from the main file.

        Groups that are not present in the file are skipped.

        Args:
            group_names (list): A list of string group names.
        """
        with h5py.File(self.main_file, 'a') as f:
            for group_name in group_names:
                # Missing groups are expected (e.g. before the first commit); delete only what
                # exists instead of swallowing every exception.
                if group_name in f:
                    del f[group_name]

    def _fix_old_bc_cards(self):
        """Fixes obsolete BC card names.

        Rewrites ``CARD_2`` from ``VEC`` to ``OVL`` on rows whose ``CARD`` is ``NB``, in place in
        ``self._bc.solution_controls``.
        """
        df = self._bc.solution_controls
        # Update "NB VEC" cards to "NB OVL"
        df.loc[(df['CARD'] == 'NB') & (df['CARD_2'] == 'VEC'), 'CARD_2'] = 'OVL'

    def _read_bc_data(self):
        """Reads the boundary condition data from the main file.

        Ids listed in ``series_ids`` that were not loaded by ``read_params_recursive`` are registered
        as ``LazyTimeSeries`` entries; obsolete card names are then fixed.
        """
        self._bc = BoundaryConditions()
        read_params_recursive(self.main_file, group_name='/', param_class=self._bc)
        for stored_id in self.series_ids['ids']:
            # Ids that are zero/falsy, non-positive, or NaN are ignored.
            if stored_id and stored_id > 0 and not numpy.isnan(stored_id):
                series_id = int(stored_id)
                if series_id not in self._bc.time_series:
                    self._bc.time_series[series_id] = LazyTimeSeries(self.main_file, series_id)
        self._fix_old_bc_cards()

    def as_dict(self) -> dict:
        """
        Converts the object's attributes and related data into a JSON serializable dictionary format.

        Returns:
            dict: A dictionary containing the serialized and filtered attributes of
            the object.

        Raises:
            TypeError: If processing the attributes results in an invalid data type.
        """
        boundary_conditions = parameterized_to_dict(self.bc)

        # Filter info_attrs using the helper function
        info_attrs = filtered_info_attrs(self.info.attrs)

        # Use the helper functions to simplify the logic for DataFrames and Xarrays
        output = {
            "info_attrs": info_attrs,
            "arc_display": include_if_xarray_not_empty(self.arc_display),
            "pt_display": include_if_xarray_not_empty(self.pt_display),
            "series_ids": include_if_xarray_not_empty(self.series_ids),
            "comp_id_to_ids": include_dataframe_if_not_empty(self.comp_id_to_ids),
            "snap_types": include_dataframe_if_not_empty(self.snap_types),
            "flux": include_dataframe_if_not_empty(self.flux),
            "transport_assignments": include_dataframe_if_not_empty(self.transport_assignments),
            "uses_transport": include_dataframe_if_not_empty(self.uses_transport),
            "sediment_assignments": include_dataframe_if_not_empty(self.sediment_assignments),
            "uses_sediment": include_dataframe_if_not_empty(self.uses_sediment),
            "sediment_diversions": include_dataframe_if_not_empty(self.sediment_diversions),
            "nb_out": include_dataframe_if_not_empty(self.nb_out),
            "boundary_conditions": boundary_conditions,
        }

        return make_json_serializable(output)
