"""SfrPackageBuilder class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import collections

# 2. Third party modules
import pandas as pd
import shapefile  # From pyshp

# 3. Aquaveo modules

# 4. Local modules
from xms.mf6.data.grid_info import DisEnum
from xms.mf6.mapping.package_builder_base import (
    CELLGRP, PackageBuilderBase, PackageBuilderInputs, save_periods, should_skip_record
)


def _ensure_sfr_fields_in_shapefile(reader):
    """Make sure the fields required to build SFR connectivity are in the shapefile.

    Args:
        reader (shapefile.Reader): shapefile reader class. Only its ``fields`` attribute is used; the first item of
            each field entry is taken as the field name.

    Raises:
        RuntimeError: If any of the "ID", "Node_1_ID", "Node_2_ID" or "Segment_Nu" fields are missing.
    """
    field_set = {field[0] for field in reader.fields}
    req_fields = {'ID', 'Node_1_ID', 'Node_2_ID', 'Segment_Nu'}
    if not req_fields.issubset(field_set):
        raise RuntimeError('Missing "ID", "Node_1_ID", "Node_2_ID", or "Segment_Nu" fields in shapefile.')


class SfrPackageBuilder(PackageBuilderBase):
    """Builds an SFR package from an arc coverage during map from coverage.

    Each arc is split into reaches, one per distinct active grid cell it intersects. Arcs are connected to each other
    through their shared nodes (Node_1_ID is treated as the upstream node, Node_2_ID as the downstream node), and arcs
    flagged with the 'Diversion_' attribute become rows in the DIVERSIONS block.
    """
    def __init__(self, inputs: PackageBuilderInputs):
        """Initializes the class.

        Args:
            inputs: Everything needed to build the package.
        """
        super().__init__(inputs)
        self._arc_cxns = {}  # [arc_id][0 or 1][list of arcs]. Connectivity of arcs. 0 = upstream, 1 = downstream
        self._arc_reaches = {}  # [arc_id][list of reaches]. Arcs and their list of reach numbers from up to down
        self._arc_reach_cellids = {}  # [arc_id][list of (cellid, length)]. Cellid and length for each reach
        # [arc_id][list of int]. For each reach, the index of its cell in the arc's intersection data (ix_recs and
        # t_values). Needed because cells may be skipped (inactive or already used), so reach j is not always cell j.
        self._arc_reach_ix_indices = {}
        self._cxn_data = {}  # [rno][ic_list]. CONNECTIONDATA table - reach number and list of connected reaches
        self._diversion_arcs = {}  # [arc_id][record]. Arcs whose 'Diversion_' attribute is set
        self._diversions = {}  # DIVERSIONS table - reach (from), idv (diversion #), iconr (reach to)
        self._package_data = {}  # PACKAGEDATA table
        self._reader = None  # Shapefile reader
        self._records = None  # Shapefile records
        self._period_rows = {}  # dict of period -> sfrsetting rows
        self._layer_range_exists = False
        self._disu_layer_warning = False
        self._ugrid = None
        self._checked_cells = set()
        self._feature_type = 'arcs'

    def build(self):
        """Builds the package.

        Reads the arc shapefile, builds connectivity, reaches, PACKAGEDATA, CONNECTIONDATA, DIVERSIONS and PERIOD data,
        then writes them to temporary files on the package.
        """
        feature_type = 'arcs'
        self._set_up_for_current_coverage(self._cov_map_info_list[0])
        shapefile_name, att_type, ix_recs, map_info = self._get_initial_stuff(feature_type)
        self.trans_data = self._read_transient_data(list(map_info.keys()), shapefile_name)
        t_values = self._ix_info.arc_t_vals
        if self.grid_info.dis_enum == DisEnum.DISU:
            self._ugrid = self.cogrid.ugrid  # Only needed to find cells above/below on unstructured grids

        with shapefile.Reader(shapefile_name) as self._reader:
            _ensure_sfr_fields_in_shapefile(self._reader)
            self._layer_range_exists = self._does_layer_range_exist(self._reader)
            self._records = self._reader.records()

            self._build_arc_connections(att_type)
            self._build_arc_reaches(att_type, ix_recs)
            self._build_cxn_data()
            self._build_package_data(att_type, t_values)
            self._build_diversions()
            self._adjust_ustrf()
            self._build_periods()
            self._write_temporarily()
            self._save_groups_with_package()

    def _build_arc_connections(self, att_type):
        """Initializes self._arc_cxns which has the stream connectivity of the arcs.

        _arc_cxns{arc ID}[0 or 1][list of arcs]. 0 = upstream arcs, 1 = downstream arcs

        Args:
            att_type (str): String used in att table for the attribute Type column (i.e. 'well', 'drain', 'river' etc)
        """
        # Get arcs attached to nodes (node_arcs) and nodes of arcs (arc_nodes)
        node_arcs = {}  # Dict of node ID -> list of arc IDs (arcs attached to each node)
        arc_nodes = {}  # Dict of arc ID -> list of 2 nodes with upstream first, downstream second
        for record in self._records:
            if record['Type'] != att_type:
                continue

            node0 = record['Node_1_ID']
            node1 = record['Node_2_ID']
            arc_id = record['ID']
            if record['Diversion_']:
                self._diversion_arcs[arc_id] = record

            # Fill node_arcs
            self._add_arc_to_node(arc_id, node0, node_arcs)
            self._add_arc_to_node(arc_id, node1, node_arcs)

            # Fill arc_nodes
            arc_nodes[arc_id] = [node0, node1]

        # Build self._arc_cxns
        for arc_id in arc_nodes:
            self._arc_cxns.setdefault(arc_id, [[], []])

            # Append upstream and downstream arcs
            self._add_arcs_connected_to_node(arc_id, arc_nodes, node_arcs, upstream=True)
            self._add_arcs_connected_to_node(arc_id, arc_nodes, node_arcs, upstream=False)

    def _add_arcs_connected_to_node(self, arc_id, arc_nodes, node_arcs, upstream):
        """Adds to self._arc_cxns either the upstream or the downstream arcs of arc_id.

        An arc is upstream of arc_id if its downstream node is arc_id's upstream node, and vice versa.

        Args:
            arc_id (int): ID of the arc.
            arc_nodes (dict): Dict of arc ID -> list of 2 nodes with upstream first, downstream second.
            node_arcs (dict): Dict of node ID -> list of arc IDs (arcs attached to each node).
            upstream (bool): True if doing upstream, False if downstream.
        """
        node_idx = 0 if upstream else 1
        adj_node_idx = 1 if upstream else 0
        node = arc_nodes[arc_id][node_idx]
        for node_arc in node_arcs[node]:
            if node_arc != arc_id and arc_nodes[node_arc][adj_node_idx] == node:
                self._arc_cxns[arc_id][node_idx].append(node_arc)

    def _add_arc_to_node(self, arc_id, node_id, node_arcs):
        """Adds the arc_id to the list of arcs attached to the node.

        Args:
            arc_id (int): ID of arc.
            node_id (int): ID of node.
            node_arcs (dict): Dict of node ID -> list of arc IDs (arcs attached to each node)
        """
        node_arcs.setdefault(node_id, []).append(arc_id)

    def _disu_active_cells_above_below(self, idx):
        """Return the active cells directly above and below a cell of the DISU ugrid.

        Neighbors are found through the cell's 3D faces with top and bottom orientation.

        Args:
            idx (int): current cell index

        Returns:
            (tuple(list, list)): The active cells above the cell, and the active cells below the cell.
        """
        fc = self._ugrid.get_cell_3d_face_count(idx)
        top = self._ugrid.face_orientation_enum.ORIENTATION_TOP
        bot = self._ugrid.face_orientation_enum.ORIENTATION_BOTTOM
        cells_above = []
        cells_below = []
        for fidx in range(fc):
            face_orient = self._ugrid.get_cell_3d_face_orientation(idx, fidx)
            if face_orient == top:
                top_idx = self._ugrid.get_cell_3d_face_adjacent_cell(idx, fidx)
                if top_idx > -1 and self._cell_is_active(top_idx):  # -1 means no adjacent cell
                    cells_above.append(top_idx)
            elif face_orient == bot:
                bot_idx = self._ugrid.get_cell_3d_face_adjacent_cell(idx, fidx)
                if bot_idx > -1 and self._cell_is_active(bot_idx):
                    cells_below.append(bot_idx)
        return cells_above, cells_below

    def _cell_is_active(self, idx):
        """Return true if the cell is active.

        A cell is inactive only if an idomain exists and its value for the cell is 0.

        Args:
            idx (int): current cell index

        Returns:
            (bool): True if the cell is active
        """
        if self._idomain and self._idomain[idx] == 0:
            return False
        return True

    def _update_cell_idx(self, idx):
        """Update the index of the cell if using the DISU package.

        On DISU grids the cell is moved vertically: up through active cells above it, or down past it when it is
        inactive, so that the reach ends up in an active cell with no active cell above it when one exists.
        Cells already assigned to the current arc (self._checked_cells) are preferred when found above.

        Args:
            idx (int): current cell index

        Returns:
            (int): the updated cell index (unchanged when not DISU)
        """
        if self._ugrid is None:
            return idx
        cell_active = self._cell_is_active(idx)
        cells_above, cells_below = self._disu_active_cells_above_below(idx)
        if cell_active and not cells_above:  # cell is active and no active cells above
            return idx
        elif not cells_above and not cells_below:  # no active cells above or below, cell may not be active
            return idx
        elif not cells_above:  # no active cells above then use the first cell below
            return self._update_cell_idx(cells_below[0])
        else:
            # is any above cell in self._checked_cells
            for cidx in cells_above:
                if cidx in self._checked_cells:
                    return cidx
            # have to check all cells_above, one of them may have another cell above
            cells_above_2 = [self._update_cell_idx(cidx) for cidx in cells_above]
            for idx1, idx2 in zip(cells_above, cells_above_2):
                if idx1 != idx2:
                    return self._update_cell_idx(idx2)
            return cells_above[0]

    def _build_arc_reaches(self, att_type, ix_recs):
        """Builds self._arc_reaches (and the per-reach cell data), the dict of arcs and their list of reach numbers.

        Reach numbers are assigned consecutively, starting at 1, in record order and from upstream to downstream
        along each arc. A cell that is inactive, or that the arc already has a reach in, gets no reach.

        Args:
            att_type (str): String used in att table for the attribute Type column (i.e. 'well', 'drain', 'river' etc)
            ix_recs (list(recarray)): The intersection recarrays.
        """
        reach_number = 1

        for i, record in enumerate(self._records):
            if should_skip_record(ix_recs[i], record, att_type):
                continue  # Shape intersected no cells or the feature att type doesn't match. Skip it.
            cellids_0_based = ix_recs[i].cellids.tolist()
            lengths = ix_recs[i].lengths.tolist()

            arc_id = record['ID']

            from_lay = to_lay = 1
            if self._layer_range_exists:
                from_lay, to_lay = self.get_layer_range(record, self._layer_range_exists)
            if from_lay != to_lay:
                msg = (
                    f'Arc id: {arc_id} is assigned to multiple layers. SFR arcs may only be assigned '
                    f'to a single layer. The arc will be assigned to layer {from_lay}.'
                )
                self.log.warning(msg)

            self._checked_cells = set()  # Cells the current arc already has reaches in
            reaches = self._arc_reaches[arc_id] = []
            reach_cellids = self._arc_reach_cellids[arc_id] = []
            ix_indices = self._arc_reach_ix_indices[arc_id] = []
            for j, cellid_0_based in enumerate(cellids_0_based):
                cellid = self.grid_info.fix_cellid(cellid_0_based, from_lay)
                original_idx = self.grid_info.cell_index_from_modflow_cellid(cellid)
                idx = self._update_cell_idx(original_idx)  # may need to change cell if DISU
                if idx in self._checked_cells:
                    continue

                if idx != original_idx:  # if the index changed with DISU then update the cellid (1-based)
                    cellid = idx + 1

                if not self._cell_is_active(idx):
                    continue
                reach_cellids.append((cellid, lengths[j]))
                ix_indices.append(j)  # Remember which intersected cell this reach came from (for t values)
                reaches.append(reach_number)
                reach_number += 1
                self._checked_cells.add(idx)

    def _first_reach(self, arc_id):
        """Return the first (most upstream) reach number of the arc, or None if the arc has no reaches."""
        reaches = self._arc_reaches.get(arc_id)
        return reaches[0] if reaches else None

    def _last_reach(self, arc_id):
        """Return the last (most downstream) reach number of the arc, or None if the arc has no reaches."""
        reaches = self._arc_reaches.get(arc_id)
        return reaches[-1] if reaches else None

    def _build_cxn_data(self):
        """Builds self._cxn_data, the CONNECTIONDATA table.

        Upstream connections are positive reach numbers and downstream connections are negative. Upstream arcs that
        have no reaches (outside the grid or entirely in inactive cells) are not connected.

        self._cxn_data won't be sorted by reach yet.
        """
        for arc, reaches in self._arc_reaches.items():
            previous_reach = None
            for i, reach in enumerate(reaches):
                self._cxn_data.setdefault(reach, [])

                if previous_reach is not None:
                    self._cxn_data[reach].append(previous_reach)
                    self._cxn_data[previous_reach].append(-reach)

                if i == 0:
                    # First reach in list. Connect with upstream arcs
                    for upstream_arc in self._arc_cxns.get(arc, [[], []])[0]:
                        upstream_arc_last_reach = self._last_reach(upstream_arc)
                        if upstream_arc_last_reach is None:
                            continue  # Upstream arc has no reaches, so there is nothing to connect to

                        # Append upstream arc's last reach to current reaches list
                        self._cxn_data[reach].append(upstream_arc_last_reach)

                        # Append this reach to upstream arc's reach list as a negative number (downstream reach)
                        self._cxn_data.setdefault(upstream_arc_last_reach, []).append(-reach)

                previous_reach = reach

    def _build_package_data(self, att_type, t_values):
        """Builds self._package_data, the PACKAGEDATA table.

        <rno> <cellid(ncelldim)> <rlen> <rwid> <rgrd> <rtp> <rbth> <rhk> <man> <ncon> <ustrf> <ndv>

        Width, elevation, thickness and hydraulic conductivity are interpolated linearly along the arc from the
        arc's 1/UP values to its 2/DN values using the t values of each reach's intersected cell.

        Args:
            att_type (str): String used in att table for the attribute Type column (i.e. 'well', 'drain', 'river' etc)
            t_values (list(float)): t values. 3D. 1st = # arcs. 2nd = # intersected cells. 3rd = # intersections
        """
        # Loop through each arc
        for i, record in enumerate(self._records):
            if record['Type'] != att_type:
                continue

            self._shape = self._reader.shape(i)
            arc_id = record['ID']
            reaches = self._arc_reaches.get(arc_id, [])
            if not reaches:
                continue  # Arc intersected no active cells, so it has no reaches

            # Get arc atts. Assumes no XY series
            width_diff = record['WIDTH2'] - record['WIDTH1']
            elev_diff = record['ELEVDN'] - record['ELEVUP']
            thick_diff = record['THICKM2'] - record['THICKM1']
            k_diff = record['HCOND2'] - record['HCOND1']
            man = record['ROUGHCH']
            group = self.group(arc_id)
            boundname = record['Name']
            reach_cellids = self._arc_reach_cellids[arc_id]
            ix_indices = self._arc_reach_ix_indices[arc_id]

            # Loop through each reach
            for j, reach in enumerate(reaches):
                cellid, length = reach_cellids[j]
                # t values of the intersected cell this reach came from. Reach j is not necessarily intersected cell j
                # because cells can be skipped. Assumes t_values[i] is aligned with ix_recs[i].cellids - TODO confirm.
                cell_t_values = t_values[i][ix_indices[j]]
                t_up = cell_t_values[0]
                t_down = cell_t_values[-1]
                t_value_avg = (t_up + t_down) * 0.5  # Avg t_value from 1st and last
                rwid = record['WIDTH1'] + (t_value_avg * width_diff)  # width
                elev_up = record['ELEVUP'] + (t_up * elev_diff)
                elev_down = record['ELEVUP'] + (t_down * elev_diff)  # 'ELEVUP' is correct here
                rgrd = (elev_up - elev_down) / length  # gradient. Assumes length > 0
                rtp = (elev_up + elev_down) * 0.5  # top elevation of streambed
                rbth = record['THICKM1'] + (t_value_avg * thick_diff)  # streambed thickness
                rhk = record['HCOND1'] + (t_value_avg * k_diff)  # hydraulic conductivity of the streambed
                ustrf = 1.0  # For now, but may be changed in self._adjust_ustrf()
                self._package_data[reach] = {
                    'cellid': cellid,
                    'rlen': length,
                    'rwid': rwid,
                    'rgrd': rgrd,
                    'rtp': rtp,
                    'rbth': rbth,
                    'rhk': rhk,
                    'man': man,
                    'ncon': len(self._cxn_data[reach]),
                    'ustrf': ustrf,
                    'ndv': 0,
                    'cellgrp': group,
                    'boundname': boundname
                }

    def _build_diversions(self):
        """Builds self._diversions, the DIVERSIONS table.

        A diversion arc diverts water from the last reach of each of its upstream arcs into its own first reach.

        From the USGS documentation, it's not clear what ustrf should be for a diversion, but the one and only stream
        example that the USGS provides which has a diversion has ustrf of 0.0, so that's what we do. The actual flow
        to the diversion is specified in the PERIOD block using the sfrsetting "DIVERSION <idv> <divflow>".
        """
        # From GMS it comes out as 0,1,2,3 from the combo box, not 0,-1,-2,-3
        # iprior_to_cprior = {0: 'UPTO', -1: 'THRESHOLD', -2: 'FRACTION', -3: 'EXCESS'}
        iprior_to_cprior = {0: 'UPTO', 1: 'THRESHOLD', 2: 'FRACTION', 3: 'EXCESS'}
        for arc, record in self._diversion_arcs.items():
            iconr = self._first_reach(arc)  # iconr is reach where water is diverted to. 1st reach of arc.
            if iconr is None:
                continue  # Diversion arc has no reaches, so there is nothing to divert to
            for upstream_arc in self._arc_cxns[arc][0]:  # What happens if there's more than 1 upstream arc?
                reach = self._last_reach(upstream_arc)  # Last reach of the upstream arc
                if reach is None:
                    continue  # Upstream arc has no reaches, so there is nothing to divert from
                cprior = iprior_to_cprior[record['IPRIOR']]
                self._diversions.setdefault(reach, []).append({'iconr': iconr, 'cprior': cprior, 'div_arc': arc})
                self._package_data[reach]['ndv'] += 1
                self._package_data[iconr]['ustrf'] = 0.0  # See docstring for why this is 0.0

    def _adjust_ustrf(self) -> None:
        """Adjust ustrf values, which requires self._cxn_data and self._diversions."""
        for reach, values in self._package_data.items():
            values['ustrf'] = self._compute_ustrf(reach)

    def _build_periods(self):
        """Builds the PERIOD blocks (self._period_rows).

        A sfrsetting row is written only when its value differs from the value in the previous period that has
        transient data (or from 0.0 for the first period), because MODFLOW carries settings forward between periods.
        """
        periods = sorted(self.trans_data.keys())
        keywords_and_fields = self._sfr_setting_keywords_and_fields()
        previous_data = None  # Transient data of the previous period, or None for the first period
        for period in periods:
            period_data = self.trans_data[period]
            rows = []
            for arc_id, reaches in self._arc_reaches.items():
                if not reaches:
                    continue  # Arc is entirely in inactive cells, so it has no reaches to write settings for
                for keyword, field in keywords_and_fields.items():
                    last_value = 0.0
                    if field and previous_data is not None:
                        last_value = previous_data[arc_id][field]

                    if keyword == 'INFLOW':
                        # Inflow goes to headwater reaches: the first reach of an arc when it has no upstream reach
                        # (upstream connections are the positive numbers in the connection list).
                        reach = reaches[0]
                        if not any(ic > 0 for ic in self._cxn_data[reach]):
                            value = period_data[arc_id][field]
                            if value != last_value:
                                rows.append(f'{reach} {keyword} {value}')
                    elif keyword == 'DIVERSION':
                        reach = reaches[-1]  # The diversion will be the last reach of arc
                        for idv, item in enumerate(self._diversions.get(reach, []), start=1):
                            div_arc = item['div_arc']
                            value = period_data[div_arc]['FLOW']
                            last_value = 0.0 if previous_data is None else previous_data[div_arc]['FLOW']
                            if value != last_value:
                                rows.append(f'{reach} {keyword} {idv} {value}')
                    elif field:
                        if field == 'ROUGHCH' and period < 2:
                            continue  # Period 1 Manning's n is already in the PACKAGEDATA block
                        value = period_data[arc_id][field]
                        if value != last_value:
                            rows.extend(f'{reach} {keyword} {value}' for reach in reaches)

            if rows:
                self._period_rows[period] = rows
            previous_data = period_data

    def _compute_ustrf(self, reach: int) -> float:
        """Compute and return ustrf for this reach.

        "ustrf - real value that defines the fraction of upstream flow from each upstream reach that is applied as
         upstream inflow to the reach. The sum of all USTRF values for all reaches connected to the same upstream
         reach must be equal to one and USTRF must be greater than or equal to zero."

        From the USGS documentation, it's not clear what ustrf should be for a diversion, but the one and only stream
        example that the USGS provides which has a diversion has ustrf of 0.0, so that's what we do. The actual flow
        to the diversion is specified in the PERIOD block using the sfrsetting "DIVERSION <idv> <divflow>".

        Args:
            reach: The reach number (RNO) from the PACKAGEDATA block (should be positive).

        Returns:
            See description.
        """
        # Get all upstream reaches of this reach
        up_reaches = [r for r in self._cxn_data[reach] if r > 0]

        # Count total downstream reaches of all upstream reaches, not including diversions
        down_reaches: set[int] = set()
        i_am_a_diversion = False
        for up_reach in up_reaches:
            for r in self._cxn_data[up_reach]:
                if r < 0 and r not in down_reaches:  # A new downstream reach
                    if not self._diverting_from_to(up_reach, abs(r)):
                        down_reaches.add(r)
                    elif abs(r) == reach:
                        i_am_a_diversion = True
                        break

        if i_am_a_diversion:
            return 0.0  # See docstring for why diversions get a ustrf of 0.0

        # Compute ustrf by dividing 1.0 by number of downstream reaches (each downstream reach will be the same)
        down_reaches_count = len(down_reaches)
        if down_reaches_count == 0:
            return 1.0
        return 1.0 / down_reaches_count

    def _diverting_from_to(self, from_reach: int, to_reach: int) -> bool:
        """Return True if from_reach is diverting water to to_reach as a diversion.

        Args:
            from_reach: The reach that is possibly diverting water to to_reach.
            to_reach: The reach that is possibly getting diverted water from from_reach.

        Returns:
            See description.
        """
        return any(abs(diversion['iconr']) == abs(to_reach) for diversion in self._diversions.get(from_reach, []))

    def _sfr_setting_keywords_and_fields(self):
        """Returns a dict of sfrsetting -> field.

        A field of None means the sfrsetting has no per-arc transient field (DIVERSION is handled separately).

        Returns:
            (dict): sfrsetting -> field.
        """
        return {
            'STATUS': None,
            'MANNING': 'ROUGHCH',
            'STAGE': None,
            'INFLOW': 'FLOW',
            'RAINFALL': 'PPTSW',
            'EVAPORATION': 'ETSW',
            'RUNOFF': 'RUNOFF',
            'DIVERSION': None,
            'UPSTREAM_FRACTION': None,
            'AUXILIARY': None,
        }

    def _dataframe_from_package_data(self):
        """Creates and returns a Pandas dataframe of the PACKAGEDATA table, sorted by reach number.

        Returns:
            (Pandas.DataFrame): The dataframe
        """
        sorted_items = sorted(self._package_data.items())
        items = [item for _, item in sorted_items]
        data = {'RNO': [reach for reach, _ in sorted_items]}

        if self.grid_info.dis_enum == DisEnum.DIS:
            data['CELLID'] = [f"{item['cellid'][0]} {item['cellid'][1]} {item['cellid'][2]}" for item in items]
        elif self.grid_info.dis_enum == DisEnum.DISV:
            data['CELLID'] = [f"{item['cellid'][0]} {item['cellid'][1]}" for item in items]
        else:
            data['CELLID'] = [item['cellid'] for item in items]

        for column, key in (
            ('RLEN', 'rlen'),
            ('RWID', 'rwid'),
            ('RGRD', 'rgrd'),
            ('RTP', 'rtp'),
            ('RBTH', 'rbth'),
            ('RHK', 'rhk'),
            ('MAN', 'man'),
            ('NCON', 'ncon'),
            ('USTRF', 'ustrf'),
            ('NDV', 'ndv'),
            (CELLGRP, 'cellgrp'),
            ('BOUNDNAME', 'boundname'),
        ):
            data[column] = [item[key] for item in items]
        return pd.DataFrame(data)

    def _dataframe_from_cxn_data(self):
        """Creates and returns a Pandas dataframe of the CONNECTIONDATA table, sorted by reach number.

        Returns:
            (Pandas.DataFrame): The dataframe
        """
        sorted_dict = collections.OrderedDict(sorted(self._cxn_data.items()))
        data = {
            'RNO': list(sorted_dict.keys()),
            'IC': [' '.join(map(str, ic_list)) for ic_list in sorted_dict.values()],
        }
        return pd.DataFrame(data)

    def _dataframe_from_diversions(self):
        """Creates and returns a Pandas dataframe of the DIVERSIONS table, sorted by reach number.

        IDV is numbered from 1 within each reach, in the order the diversions were added.

        Returns:
            (Pandas.DataFrame): The dataframe
        """
        data = {'RNO': [], 'IDV': [], 'ICONR': [], 'CPRIOR': []}
        for reach in sorted(self._diversions):
            for idv, item in enumerate(self._diversions[reach], start=1):
                data['RNO'].append(reach)
                data['IDV'].append(idv)
                data['ICONR'].append(item['iconr'])
                data['CPRIOR'].append(item['cprior'])
        return pd.DataFrame(data)

    def _write_temporarily(self):
        """Writes the data to disk because we don't (yet) have any in-memory way to store it.

        The data will get rewritten later in PackageBuilder.
        """
        block_builders = (
            ('PACKAGEDATA', self._dataframe_from_package_data),
            ('CONNECTIONDATA', self._dataframe_from_cxn_data),
            ('DIVERSIONS', self._dataframe_from_diversions),
        )
        for block, make_dataframe in block_builders:
            temp_filename = self._package.dataframe_to_temp_file(block, make_dataframe())
            self._package.list_blocks[block] = temp_filename

        save_periods(self._period_rows, self._package)