"""BcData class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
import orjson
import pandas as pd
import pkg_resources

# 3. Aquaveo modules

# 4. Local modules
from xms.srh.data.par.bc_data_param import BcDataParam
from xms.srh.data.par.par_util import orjson_dict_to_param_cls, param_cls_to_orjson_dict
from xms.srh.data.srh_coverage_data import SrhCoverageData


class BcData(SrhCoverageData):
    """Class for storing the SRH BC properties.

    Besides the dataset managed by the base class (exposed here as ``structures``), this class keeps three
    datasets that are read on first access through their properties: ``comp_ids`` (component id to bc id
    table), ``bc_data`` (bc properties) and ``time_series``.
    """
    # Structure names. NOTE(review): presumably the bc_type values that are treated as structures - confirm.
    structures_list = ['Culvert HY-8', 'Culvert', 'Weir', 'Pressure', 'Gate', 'Link']
    # Display category keys. A tuple groups the upstream/downstream categories of one structure kind.
    display_list = [
        'inlet_q', 'exit_h', 'exit_q', 'inlet_sc', 'exit_ex', 'wall', 'symmetry', 'internal_sink',
        ('hy8_culvert_upstream', 'hy8_culvert_downstream'), ('culvert_upstream', 'culvert_downstream'),
        ('weir_upstream', 'weir_downstream'), ('pressure_upstream', 'pressure_downstream'),
        ('gate_upstream', 'gate_downstream'), ('link_upstream', 'link_downstream'), 'bc_data'
    ]
    # Labels for display_list, entry for entry (same order and same tuple grouping).
    display_labels = [
        'Inlet-Q', 'Exit-H', 'Exit-Q', 'Inlet-SC', 'Exit-EX', 'Wall', 'Symmetry', 'Internal sink',
        ('HY-8 culvert upstream', 'HY-8 culvert downstream'), ('Culvert upstream', 'Culvert downstream'),
        ('Weir upstream', 'Weir downstream'), ('Pressure upstream', 'Pressure downstream'),
        ('Gate upstream', 'Gate downstream'), ('Link upstream', 'Link downstream'), 'Bc Data'
    ]

    def __init__(self, filename):
        """Resets the dataset caches and initializes the base class.

        Args:
            filename (:obj:`str`): file name
        """
        # The three caches start out empty and must exist before the base constructor runs:
        #   _component_id_bc_id: table of component ids to bc ids. Can't just use bc ids because structures
        #                        are groups that share a bc
        #   _bc_data:            bc properties dataset
        #   _time_series:        time series dataset
        self._component_id_bc_id = self._bc_data = self._time_series = None
        super().__init__(filename, 'SRH_BC_DATA', 'structures')

    def _on_load_all(self):
        """Loads all datasets from the file."""
        _ = self.comp_ids
        _ = self.bc_data
        _ = self.time_series

    @property
    def comp_ids(self):
        """Get the _component_id_bc_id dataset.

        Returns:
            (:obj:`xarray.Dataset`): The _component_id_bc_id dataset
        """
        if self._component_id_bc_id is None:
            self._component_id_bc_id = self.get_dataset('component_id_bc_id', False)
            if self._component_id_bc_id is None:
                self._component_id_bc_id = self._default_component_id_bc_id()
        return self._component_id_bc_id

    def _default_component_id_bc_id(self):
        """Creates a default _component_id_bc_id data set.

        Returns:
            (:obj:`xarray.Dataset`): The _component_id_bc_id dataset
        """
        df = pd.DataFrame({'id': [0], 'bc_id': [0], 'display': ['wall']})
        return df.to_xarray()

    def set_bc_id_display_with_comp_id(self, bc_id, display, comp_id):
        """Sets the bc_id associate with comp_id.

        Args:
             bc_id (:obj:`int`): bc id (id in the bc_data dataset)
             display (:obj:`str`): display list category
             comp_id (:obj:`int`): component id

        """
        df = self.comp_ids.to_dataframe()
        row = len(df)
        df.loc[row, 'id'] = comp_id
        df.loc[row, 'bc_id'] = bc_id
        df.loc[row, 'display'] = display
        self._component_id_bc_id = df.to_xarray()

    @property
    def structures(self):
        """Get the _structures dataset.

        Returns:
            (:obj:`xarray.Dataset`): The _structures dataset
        """
        return self.data

    def set_structure_with_bc_id(self, structure, bc_id):
        """Sets a record with an id. If the id is < 1 then do nothing.

        Args:
            structure (:obj:`BcDataArcs`): instance of BcDataArcs class
            bc_id (:obj:`int`): bc_id

        """
        if bc_id < 1:
            return
        pdict = param_cls_to_orjson_dict(cls=structure, df_handler=None, skip_negative_precedence=False)
        json_txt = orjson.dumps(pdict).decode()
        df = self._data.to_dataframe()
        row = len(df)
        df.loc[row, 'id'] = bc_id
        df.loc[row, 'json'] = json_txt
        self._data = df.to_xarray()

    def structure_param_from_id(self, bc_id):
        """Looks up the structure record for bc_id via ``param_from_id``.

        Args:
            bc_id (:obj:`int`): bc_id

        Returns:
            (:obj:`BcDataArcs`): Whatever ``param_from_id`` returns for this id, given the ``arcs`` member of
            a fresh BcDataParam
        """
        arcs_template = BcDataParam().arcs
        return self.param_from_id(bc_id, arcs_template)

    @property
    def bc_data(self):
        """Get the bc dataset.

        Returns:
            (:obj:`xarray.Dataset`): The bc_data list dataset
        """
        if self._bc_data is None:
            # may need to migrate BC Data
            self._bc_data = self.get_dataset('bc_data', False)
            if self._bc_data is None:
                self._bc_data = self._default_bc_data()
            else:
                self._migrate_bc_data()
        return self._bc_data

    def _default_bc_data(self):
        """Creates a default bc data set.

        Returns:
            (:obj:`xarray.Dataset`): The bc dataset
        """
        default_data = {'id': [0], 'bc_type': 'Wall (no-slip boundary)', 'bc_json': ''}
        return pd.DataFrame(default_data).to_xarray()

    def bc_id_from_comp_id(self, c_id):
        """Gets a record from an id. If the id is not in the bc_data then returns None.

        Args:
            c_id (:obj:`int`): component id

        Returns:
            (:obj:`int`): the bc_id
        """
        df = self.comp_ids.to_dataframe()
        bc_id = -1
        rec = df.loc[df['id'] == c_id].reset_index(drop=True).to_dict()
        if 'bc_id' in rec and 0 in rec['bc_id']:
            bc_id = rec['bc_id'][0]
        return int(bc_id)

    def bc_data_param_from_id(self, c_id):
        """Builds a BcDataParam from the bc_data record that has the given id.

        Args:
            c_id (:obj:`int`): id of the record in the bc_data dataset. NOTE(review): this was documented as
                a component id, but remove_comp_ids filters the same column by bc ids - confirm which it is.

        Returns:
            (:obj:`BcDataParam`): The record from the bc dataframe as a BcDataParam class. When no record has
            this id a default constructed BcDataParam is returned (not None).
        """
        record = self._bc_record_from_id(c_id)
        bc = BcDataParam()
        if record is None:
            return bc
        bc.bc_type = record['bc_type'][0]
        orjson_txt = record['bc_json'][0]
        if orjson_txt:
            # Projects have been found with garbage characters after the JSON object (the migration code
            # strips them the same way), so cut the text after the last closing brace before parsing.
            end = orjson_txt.rfind('}')
            if end >= 0:
                orjson_txt = orjson_txt[:end + 1]
            pdict = orjson.loads(orjson_txt.encode())
            orjson_dict_to_param_cls(pdict, bc, df_handler=self._time_series_dataframe_from_id)
            bc.update_bc_type()  # update the precedence after reading the data
        return bc

    def set_bc_data_with_id(self, bc_data, c_id):
        """Sets a record with an id. If the id is < 1 then do nothing.

        Args:
            bc_data (:obj:`BcDataParam`): instance of bc data
            c_id (:obj:`int`): component id

        """
        if c_id < 1:
            return
        df = self.bc_data.to_dataframe()
        series = df[df['id'] == c_id]
        if len(series) < 1:
            row = len(df)
        else:
            row = series.index[0]
        df.loc[row, 'id'] = c_id
        df.loc[row, 'bc_type'] = bc_data.bc_type

        # don't store arcs here
        self._prepare_bc_data_for_json_dump(bc_data)
        arcs_precendence = bc_data.param.arcs.precedence
        bc_data.param.arcs.precedence = -1

        pdict = param_cls_to_orjson_dict(cls=bc_data, df_handler=self._set_time_series_from_dataframe)

        bc_data.param.arcs.precedence = arcs_precendence
        json_bc = orjson.dumps(pdict).decode()
        df.loc[row, 'bc_json'] = json_bc
        self._bc_data = df.to_xarray()

    def _prepare_bc_data_for_json_dump(self, bc_data):
        """Changes the precedence on certain data members for saving to a json string.

        Args:
            bc_data (:obj:`BcDataParam`): instance of bc data

        """
        type_to_enable = {
            'Inlet-Q (subcritical inflow)': bc_data.inlet_q,
            'Exit-H (subcritical outflow)': bc_data.exit_h,
            'Pressure': bc_data.pressure,
            'Internal sink': bc_data.internal_sink,
            'Link': bc_data.link,
        }
        if bc_data.bc_type in type_to_enable:
            type_to_enable[bc_data.bc_type].enabler._restore_all_precedence()
        if bc_data.bc_type == 'Exit-H (subcritical outflow)':
            p = bc_data.exit_h.param.manning_n_calculator_inputs.precedence
            bc_data.exit_h.param.manning_n_calculator_inputs.precedence = abs(p)
        if bc_data.bc_type == 'Culvert HY-8':
            bc_data.hy8_culvert.param.crossing_selector.precedence = -5
            bc_data.hy8_culvert.param.launch_hy8.precedence = -6
            bc_data.hy8_culvert.param.hy8_crossing_guid.precedence = 7
        # update BCDATA lines
        if bc_data.param.bc_data_lines.precedence > 0:
            bc_data.bc_data_lines.param.upstream_line.precedence = -3
            if bc_data.bc_data_lines.upstream_line:
                bc_data.bc_data_lines.param.upstream_line_label.precedence = 2
            bc_data.bc_data_lines.param.downstream_line.precedence = -6
            if bc_data.bc_data_lines.downstream_line:
                bc_data.bc_data_lines.param.downstream_line_label.precedence = 5

    def _bc_record_from_id(self, c_id):
        """Gets a record from an id. If the id is not in the bc_data then returns None.

        Args:
            c_id (:obj:`int`): component id

        Returns:
            (:obj:`dict`): The record in the dataframe
        """
        df = self.bc_data.to_dataframe()
        r = df.loc[df['id'] == c_id]
        if len(r) > 0:
            record = r.reset_index(drop=True).to_dict()
            return record
        return None

    @property
    def time_series(self):
        """Get the time series dataset.

        Returns:
            (:obj:`xarray.Dataset`): The time series dataset

        """
        if self._time_series is None:
            self._time_series = self.get_dataset('time_series', False)
            if self._time_series is None:
                self._time_series = self._default_time_series()
        return self._time_series

    def _default_time_series(self):
        """Creates a default time series data set.

        Returns:
            (:obj:`xarray.Dataset`): The time series dataset
        """
        default_data = {'id': [0], 'x': [0.0], 'y': [0.0]}
        return pd.DataFrame(default_data).to_xarray()

    def _time_series_dataframe_from_id(self, c_id):
        """Creates a dataframe by extracting from the time_series dataset.

        Args:
            c_id (:obj:`int`): series id

        Returns:
            (:obj:`DataFrame`): resulting dataframe
        """
        if c_id < 0:
            return None
        df = self.time_series.to_dataframe()
        df1 = df.loc[df['id'] == c_id].copy()
        if len(df1) > 0:
            df1.drop(['id'], axis=1, inplace=True)
        return df1

    def _set_time_series_from_dataframe(self, df):
        """Append df to the time_series dataset and give it a new id.

        Args:
            df (:obj:`pandas.DataFrame`): x, y series

        Returns:
            (:obj:`int`): id of the time series
        """
        if len(df) < 1:
            return -1
        ts_df = self.time_series.to_dataframe()
        new_id = max(ts_df['id']) + 1
        df['id'] = new_id
        df.rename(columns={df.columns[0]: 'x', df.columns[1]: 'y'}, inplace=True)
        ts_df = pd.concat([ts_df, df], ignore_index=True, sort=False)
        self._time_series = ts_df.to_xarray()
        return new_id

    def commit(self):
        """Save in memory datasets to the NetCDF file.

        After the base class commit, the three groups owned by this class are dropped from the file and every
        dataset that is currently loaded is appended back under its group name.
        """
        super().commit()

        group_names = ('component_id_bc_id', 'bc_data', 'time_series')
        self._drop_h5_groups(list(group_names))
        # write
        datasets = (self._component_id_bc_id, self._bc_data, self._time_series)
        for group_name, dataset in zip(group_names, datasets):
            if dataset is not None:
                dataset.to_netcdf(self._filename, group=group_name, mode='a')

    def close(self):
        """Closes the H5 file and does not write any data that is in memory."""
        super().close()
        # close only the datasets that were actually loaded
        for dataset in (self._component_id_bc_id, self._bc_data, self._time_series):
            if dataset is not None:
                dataset.close()

    def remove_comp_ids(self, comp_ids):
        """Removes comp ids from the class that are not used any more."""
        if comp_ids is None:
            return
        if len(comp_ids) < 1:
            return
        comp_df = self._component_id_bc_id.to_dataframe()
        # get bc ids to delete
        df = comp_df[comp_df['id'].isin(comp_ids)]
        bc_ids = set(df['bc_id'].to_list())

        df = comp_df[~comp_df['id'].isin(comp_ids)].reset_index(drop=True)
        self._component_id_bc_id = df.to_xarray()

        # structures
        df = self._data.to_dataframe()
        df = df[~df['id'].isin(bc_ids)].reset_index(drop=True)
        self._data = df.to_xarray()

        df = self._bc_data.to_dataframe()
        df = df[~df['id'].isin(bc_ids)].reset_index(drop=True)
        self._bc_data = df.to_xarray()

    def _migrate_bc_data(self):
        """Migrate the bc data from old versions.

        Only files whose VERSION attribute is older than 1.1.0 are touched. In those, every 'Link' bc with a
        'Weir' inflow type that still stores a 'coeff' value is rebuilt with a 'User' weir type and saved
        again.
        """
        parse = pkg_resources.parse_version
        needs_migration = parse(self.info.attrs['VERSION']) < parse('1.1.0')
        if not needs_migration:
            return

        # the definition of a link weir changed and we need to migrate any old ones
        pending = []
        for _, record in self._bc_data.to_dataframe().iterrows():
            bc_id = record['id']
            json_txt = record['bc_json']
            if not json_txt:
                continue
            # Found projects with garbage characters at the end.
            # This strips off the bad characters before loading.
            json_txt = json_txt[:json_txt.rfind('}') + 1]
            pdict = orjson.loads(json_txt.encode())
            is_link_weir = pdict['bc_type'] == 'Link' and pdict['link']['inflow_type'] == 'Weir'
            if not (is_link_weir and 'coeff' in pdict['link']['weir']):
                continue
            bc = BcDataParam()
            orjson_dict_to_param_cls(pdict, bc, df_handler=self._time_series_dataframe_from_id)
            weir_type = bc.link.weir.type
            # set weirtype to 'User'. This is how srh_pre reads old files
            weir_type.type = 'User'
            weir_type.cw = pdict['link']['weir']['coeff']
            # set a and b to '16.4' and '0.432' values come from paved.
            weir_type.a = 16.4
            weir_type.b = 0.432
            pending.append((bc_id, bc))

        # save only after the scan is done, because saving replaces self._bc_data
        for bc_id, bc in pending:
            self.set_bc_data_with_id(bc, bc_id)
