"""Stores the data for the weirs, flap gates, and sluice gates."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import importlib.metadata
import os

# 2. Third party modules
from adhparam.boundary_conditions import BoundaryConditions
from adhparam.time_series import TimeSeries
import h5py
import numpy
import pandas as pd
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase

# 4. Local modules
from xms.adh.data.param_h5_io import read_params_recursive, write_params_recursive


class StructuresIO(XarrayBase):
    """A class that handles saving and loading structure (weir, flap gate, sluice gate) data."""
    def __init__(self, main_file):
        """Sets up empty in-memory containers and loads the main file when it already exists.

        Args:
            main_file: The main file associated with this component.
        """
        super().__init__(main_file)
        self.main_file = main_file
        # Grab the attribute dict before the info dataset is (possibly) replaced below; this mirrors the
        # original statement order.  NOTE(review): `info` / `_info` come from XarrayBase, whose loading
        # behavior is not visible in this file.
        attrs = self.info.attrs
        defaults = self.default_data()
        if not os.path.isfile(main_file):
            # No regular file on disk yet: start from a dataset that carries only the default attributes.
            self._info = xr.Dataset(attrs=defaults)

        # Display options and time series ids start out empty.
        self.arc_display = xr.Dataset()
        self.pt_display = xr.Dataset()
        self.series_ids = xr.Dataset()

        # Table relating a component id to the ids stored in the other columns.
        self.comp_id_columns = ['COMP_ID', 'FRICTION_ID', 'FLUX_ID', 'TRANSPORT_ID', 'DIVERSION_ID']
        self.comp_id_to_ids = pd.DataFrame([], columns=self.comp_id_columns)

        # One ID/NAME table per structure type.
        self.weir_names = pd.DataFrame([], columns=['ID', 'NAME'])
        self.flap_names = pd.DataFrame([], columns=['ID', 'NAME'])
        self.sluice_names = pd.DataFrame([], columns=['ID', 'NAME'])

        # Fill in any default attribute that is missing, then force the file type.
        for key, value in defaults.items():
            attrs.setdefault(key, value)
        attrs['FILE_TYPE'] = 'ADH_STRUCTURE_DATA'

        self.bc = BoundaryConditions()
        if os.path.exists(main_file):
            self.read_from_h5_file(main_file, 'ADH_STRUCTURE_DATA')

    def read_from_h5_file(self, filename, file_type):
        """Loads the whole component from an h5/netCDF file after validating its file type.

        Args:
            filename (str): file name
            file_type (str): The type of file as written in the netCDF file.

        Raises:
            IOError: If the 'info' group has no FILE_TYPE attribute, or the attribute differs from file_type.
        """
        file_attrs = xr.load_dataset(filename, group='info').attrs
        if 'FILE_TYPE' not in file_attrs:
            raise IOError(f'File attributes do not include file_type in file : {filename}')
        ftype = file_attrs['FILE_TYPE']
        if ftype != file_type:
            raise IOError(f'Unrecognized file_type "{ftype}" attribute in file: {filename}')

        # Boundary condition parameters live at the file root; everything else has its own group.
        read_params_recursive(filename, group_name='/', param_class=self.bc)
        self.read_name_and_display_from_h5(filename)
        self.read_structure_names_from_h5(filename)
        self.read_time_series_from_h5(filename)
        id_table = xr.load_dataset(filename, group='/comp_id_to_ids')
        self.comp_id_to_ids = id_table.to_dataframe()

    def read_name_and_display_from_h5(self, filename):
        """Loads the arc and point display datasets from a file.

        Args:
            filename (str): The name of the file to read from.
        """
        # Per the original note, each dataset is expected to hold: name, red, green, blue, alpha, style, size
        # (not verified by this method).
        load = xr.load_dataset
        self.arc_display = load(filename, group='/display/arc')
        self.pt_display = load(filename, group='/display/point')

    def read_structure_names_from_h5(self, filename):
        """Loads the weir, flap gate, and sluice gate name tables from a file.

        Each group is read as a dataset and stored on this object as a pandas DataFrame.

        Args:
            filename (str): the netCDF file to read names from.
        """
        # Attribute on this object -> group it is stored under, in the order they are read.
        group_by_attribute = {
            'weir_names': '/structures/weirs',
            'flap_names': '/structures/flap_gates',
            'sluice_names': '/structures/sluice_gates',
        }
        for attribute, group in group_by_attribute.items():
            setattr(self, attribute, xr.load_dataset(filename, group=group).to_dataframe())

    def read_time_series_from_h5(self, filename):
        """Loads the list of time series ids and then every listed time series.

        Args:
            filename (str): The name of the netCDF file to read time series from.
        """
        self.series_ids = xr.load_dataset(filename, group='/time_series/ids')
        all_series = self.bc.time_series
        for raw_id in self.series_ids['ids']:
            # Only positive, non-NaN ids refer to a stored series; skip everything else.
            if not (raw_id and raw_id > 0 and not numpy.isnan(raw_id)):
                continue
            series_key = int(raw_id)
            if series_key not in all_series:
                all_series[series_key] = TimeSeries()
            read_params_recursive(
                filename, group_name=f'/time_series/{series_key}/', param_class=all_series[series_key]
            )

    @staticmethod
    def default_data(self):
        """Gets the default data for the info attributes."""
        version = importlib.metadata.version('xmsadh')
        return {
            'FILE_TYPE': 'ADH_STRUCTURE_DATA',
            'VERSION': version,
            'arc_display_uuid': '',
            'point_display_uuid': '',
            'cov_uuid': '',
            'next_comp_id': 0
        }

    def commit(self):
        """Writes everything held by this object to the main file.

        The version attribute is refreshed, the base class commit runs, and then the boundary condition
        parameters, display datasets, component id table, structure names, and time series are written.
        """
        self.info.attrs['VERSION'] = importlib.metadata.version('xmsadh')
        super().commit()
        write_params_recursive(self.main_file, group_name='/', param_class=self.bc)
        self.arc_display.to_netcdf(self.main_file, group='/display/arc', mode='a')
        self.pt_display.to_netcdf(self.main_file, group='/display/point', mode='a')

        # These groups are removed up front and rewritten below.
        groups_to_rewrite = [
            '/comp_id_to_ids',
            '/time_series/ids',
            '/structures/weirs',
            '/structures/flap_gates',
            '/structures/sluice_gates',
        ]
        self.clean(groups_to_rewrite)

        # Replace the 'index' coordinate with a plain 0..n-1 sequence before writing.
        id_table = self.comp_id_to_ids.to_xarray()
        row_count = len(id_table.index)
        id_table = id_table.assign_coords({'index': list(range(row_count))})
        id_table.to_netcdf(self.main_file, group='/comp_id_to_ids', mode='a')

        self._commit_structures()
        self._commit_time_series()

    def _commit_structures(self):
        """Writes the weir, flap gate, and sluice gate name tables to their groups in the main file."""
        groups = ('/structures/weirs', '/structures/flap_gates', '/structures/sluice_gates')
        frames = (self.weir_names, self.flap_names, self.sluice_names)

        # Convert all three tables first, then renumber, then write -- the same ordering as three separate
        # statements per step.
        converted = [frame.to_xarray() for frame in frames]
        renumbered = [
            data_set.assign_coords({'index': list(range(len(data_set.index)))}) for data_set in converted
        ]
        for data_set, group in zip(renumbered, groups):
            data_set.to_netcdf(self.main_file, group=group, mode='a')

    def _commit_time_series(self):
        """Commits all of the time series.

        Three steps, in order:
            1. Groups under '/time_series' whose integer name is no longer a key of `self.bc.time_series`
               are deleted from the main file (the 'ids' group is left alone).
            2. The current ids are written to '/time_series/ids' and kept in `self.series_ids`.
            3. Every time series with a truthy id and a truthy value is written to '/time_series/<id>/'.
        """
        all_series = self.bc.time_series
        current_ids = set(all_series)

        # Clean out unused time series
        with h5py.File(self.main_file, 'a') as f:
            if '/time_series' in f:
                series_root = f['/time_series']
                # Take a snapshot of the member names before deleting anything: removing links from a group
                # while iterating over its live keys view is not safe.
                for group_name in list(series_root.keys()):
                    if group_name == 'ids':
                        continue
                    if int(group_name) not in current_ids:
                        # Delete by the actual member name.  It was just listed, so a failure here is a real
                        # error and is allowed to propagate.
                        del series_root[group_name]

        # Write the time series ids (one row per id; an empty list yields an empty 'ids' column)
        id_rows = [[series_id] for series_id in all_series]
        self.series_ids = pd.DataFrame(id_rows, columns=['ids']).to_xarray()
        self.series_ids.to_netcdf(self.main_file, group='/time_series/ids', mode='a')

        # Write all of the time series; falsy ids (such as 0) and falsy series objects are skipped
        for series_id, series in all_series.items():
            if series_id and series:
                write_params_recursive(
                    self.main_file, group_name=f'/time_series/{series_id}/', param_class=series
                )

    def get_next_comp_id(self):
        """Gets the next component id by incrementing the variable.

        Returns:
            The next component id as an integer.
        """
        self.info.attrs['next_comp_id'] += 1
        return int(self.info.attrs['next_comp_id'])

    def clean(self, group_names):
        """Cleans out groups from the main file.

        Groups that are not present in the file are skipped.

        Args:
            group_names (list): A list of string group names.
        """
        with h5py.File(self.main_file, 'a') as f:
            for group_name in group_names:
                # A group that was never written is simply skipped.  Any other failure while deleting is a
                # real error and is allowed to propagate instead of being silently swallowed.
                if group_name in f:
                    del f[group_name]

    @staticmethod
    def _clean_structure_sets(weir_up, weir_down, flap_up, flap_down, sluice_up, sluice_down):
        """Cleans out ids used in multiple structures from the structure sets.

        It adds the found multiples into multiple.

        Args:
            weir_up (:obj:`set` of int): The ids in the weir upstream category.
            weir_down (:obj:`set` of int): The ids in the weir downstream category.
            flap_up (:obj:`set` of int): The ids in the flap gate upstream category.
            flap_down (:obj:`set` of int): The ids in the flap gate downstream category.
            sluice_up (:obj:`set` of int): The ids in the sluice gate upstream category.
            sluice_down (:obj:`set` of int): The ids in the sluice gate downstream category.

        Returns:
            multiple (:obj:`set` of int): The ids used in multiple structures.
        """
        multiple = set()
        multiple.update(weir_up & weir_down)
        multiple.update(flap_up & flap_down)
        multiple.update(sluice_up & sluice_down)
        all_weir = weir_up | weir_down
        all_flap = flap_up | flap_down
        all_sluice = sluice_up | sluice_down
        multiple.update(all_weir & all_flap)
        multiple.update(all_weir & all_sluice)
        multiple.update(all_flap & all_sluice)
        weir_up.difference_update(multiple)
        weir_down.difference_update(multiple)
        flap_up.difference_update(multiple)
        flap_down.difference_update(multiple)
        sluice_up.difference_update(multiple)
        sluice_down.difference_update(multiple)
        return multiple

    @staticmethod
    def get_structure_string_ids(df, upstream_arc_col, downstream_arc_col):
        """Get the string ids for the structure.

        Args:
            df (pd.DataFrame): The structure dataframe.
            upstream_arc_col (str): The column text for the upstream arc column.
            downstream_arc_col (str): The column text for the downstream arc column.

        Returns:
            A tuple containing:
                - up_arc (:obj:`set` of int): Upstream arc string/component ids.
                - down_arc (:obj:`set` of int): Downstream arc string/component ids.
                - up_point (:obj:`set` of int): Upstream point string/component ids.
                - down_point (:obj:`set` of int): Downstream point string/component ids.
        """
        up_arc = set()
        down_arc = set()
        up_point = set()
        down_point = set()
        for row_data in df.itertuples():
            up_point.add(int(row_data.S_UPSTREAM))
            down_point.add(int(row_data.S_DOWNSTREAM))
            up_arc.add(int(getattr(row_data, upstream_arc_col)))
            down_arc.add(int(getattr(row_data, downstream_arc_col)))
        return up_arc, down_arc, up_point, down_point
