"""Writes the AdH simulation input files by first converting into an adhparam object."""

__copyright__ = '(C) Copyright Aquaveo 2025'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import collections
import copy
import logging
import math
from typing import Optional, TYPE_CHECKING

# 2. Third party modules
from adhparam.hot_start_data_set import HotStartDataSet
from adhparam.material_transport_properties import MaterialTransportProperties
from adhparam.simulation import AdhSimulation
import pandas as pd
import rtree

# 3. Aquaveo modules
from xms.api.dmi import XmsEnvironment as XmEnv
from xms.constraint import Grid
from xms.constraint.ugrid_builder import UGridBuilder
from xms.data_objects.parameters import Coverage, FilterLocation
from xms.grid.ugrid import UGrid
from xms.guipy.data.target_type import TargetType
from xms.snap.snap_exterior_arc import SnapExteriorArc
from xms.snap.snap_interior_arc import SnapInteriorArc
from xms.snap.snap_point import SnapPoint
from xms.snap.snap_polygon import SnapPolygon

# 4. Local modules
from xms.adh.components.bc_conceptual_component import BcConceptualComponent
from xms.adh.components.material_conceptual_component import MaterialConceptualComponent
from xms.adh.components.sediment_material_conceptual_component import SedimentMaterialConceptualComponent
from xms.adh.components.transport_constituents_component import TransportConstituentsComponent
from xms.adh.data.model_control import ModelControl
from xms.adh.data.sediment_materials_io import SedimentMaterialsIO
from xms.adh.data.xms_test_data import save_xms_data
from xms.adh.mapping.snap_node_string import SnapNodeString

if TYPE_CHECKING:
    from xms.adh.data.xms_data import XmsData


class XmsDataToAdhParam:
    """Used to convert a simulation to an AdH param object."""

    def __init__(self, xms_data: 'XmsData', logger: Optional[logging.Logger] = None):
        """Initializes the class."""
        self.sim = AdhSimulation()
        self.sim_name = xms_data.sim_name
        self._grid = xms_data.adh_data.co_grid
        self._sim_comp = xms_data.sim_component
        self._bc_cov = xms_data.bc_coverage
        self._bc_comp = xms_data.bc_component
        self._mat_cov = xms_data.material_coverage
        self._mat_comp = xms_data.material_component
        self._output_cov = xms_data.output_coverage
        self._sed_mat_cov = xms_data.sediment_material_coverage
        self._sed_mat_comp = xms_data.sediment_material_component
        self._trans_comp = xms_data.transport_component
        self._sed_trans_comp = xms_data.sediment_constituents_component
        self._hotstart_datsets = {}
        self.hotstarts = []
        self._hot_start_names = {}
        self.general_constituents = None
        self.num_transport = 0
        self.off_material_string_id = 0
        self.flux_cards = []
        self.xms_data = xms_data
        self._logger = logger
        if self._logger is None:
            self._logger = logging.getLogger('xms.adh')

    @property
    def sim_data(self) -> ModelControl:
        """Gets the simulation component."""
        return self._sim_comp.data

    @property
    def bc_comp(self) -> BcConceptualComponent | None:
        """Gets the boundary conditions component."""
        return self._bc_comp

    @property
    def mat_comp(self) -> MaterialConceptualComponent | None:
        """Gets the material component."""
        return self._mat_comp

    @property
    def sed_mat_comp(self) -> SedimentMaterialConceptualComponent | None:
        """Gets the sediment material component."""
        return self._sed_mat_comp

    @property
    def trans_comp(self) -> TransportConstituentsComponent | None:
        """Gets the transport constituents component."""
        return self._trans_comp

    @property
    def sed_trans_comp(self):
        """Gets the sediment transport constituents component."""
        return self._sed_trans_comp

    def _get_hotstart(self):
        """Gets the hotstart datasets."""
        hot_starts = self.sim_data.get_hot_starts()
        for key, item in hot_starts.items():
            if key in self._hot_start_names:
                name = self._hot_start_names[key]
            else:
                name = key
            self._add_dataset(item[0], name)

    def _add_dataset(self, dset_uuid: str, name: str):
        """Query SMS for a dataset by the datasets uuid and add to map using dataset name as the key if it exists.

        Args:
            dset_uuid: UUID of the dataset to add.
            name: The name of the dataset to add.
        """
        dataset = self.xms_data.get_dataset_from_uuid(dset_uuid)
        if dataset:
            self._hotstart_datsets[name] = dataset

    def valid_inputs(self) -> bool:
        """Check for valid inputs for saving data.

        Returns:
            False if invalid inputs to save data.
        """
        valid = True
        if self._grid is None:
            valid = False
            self._logger.error('No grid provided.')
        if self._bc_cov is None or self._bc_comp is None:
            valid = False
            self._logger.error('No boundary conditions coverage provided.')
        return valid

    def get_data(self):
        """Populates a Param AdH model object with data retrieved from SMS."""
        try:
            self._get_hotstart()
        except Exception:
            self._logger.error('Unable to retrieve AdH data from XMS for exporting.')
            raise
        finally:
            if self.xms_data and XmEnv.xms_environ_running_tests() == 'TRUE':
                self.xms_data.send()

    def get_adh_param(self):
        """Gets the adhparam object with appropriate data.

        Snaps the material, sediment material, and boundary condition coverages to the grid. Each snapped
        feature is given a new string id, starting at 1 and increasing by one. The method then builds the
        boundary string cards (EGS, MDS, NDS) and fills in the adhparam simulation's mesh, properties,
        diversions, sediment, hot start datasets, and nodal output.

        Returns:
            self.sim (adhparam.AdhSimulation): The simulation to write out.
            self.sim_name (str): The name of the simulation.
            self.hotstarts (list): A list of adhparam.HotStartDataSet.
        """
        # Maps from component / old ids to the new AdH string and constituent ids assigned during export.
        transport_constituent_old_id_to_new_id = {}
        sediment_constituent_old_id_to_new_id = {}
        string_cards = []
        next_string_id = 1
        mat_id_to_string_id = {}
        sed_mat_id_to_string_id = {}
        bc_id_to_string_id = {}
        comp_id_to_bc_id = {}
        bc_fric_arc_id_to_string_id = {}
        bc_transport_id_to_string_id = collections.defaultdict(lambda: collections.defaultdict(list))
        bc_sediment_transport_id_to_string_id = collections.defaultdict(lambda: collections.defaultdict(list))
        bc_diversion_id_to_string_id = collections.defaultdict(list)
        mat_id_to_mat_sed: dict[int, tuple[int, int]] = {}
        mat_cell, next_string_id = snap_materials(
            self._grid, self.mat_comp, self.sed_mat_comp, self._mat_cov, self._sed_mat_cov, mat_id_to_string_id,
            sed_mat_id_to_string_id, next_string_id, string_cards, mat_id_to_mat_sed
        )
        self.sim.boundary_conditions.mat_id_to_mat_sed = mat_id_to_mat_sed

        # A material id of 0 in mat_cell marks cells assigned to the OFF material.
        if 0 in mat_cell:
            # Make a smaller mesh with material off zones removed
            mat_cell, next_string_id, active_grid = self._create_active_grid(
                mat_cell, mat_id_to_string_id, next_string_id, string_cards
            )
        else:
            active_grid = self._grid

        # Get component to arc/point ids
        if self.bc_comp and self.bc_comp.cov_uuid in self.bc_comp.comp_to_xms:
            if TargetType.arc in self.bc_comp.comp_to_xms[self.bc_comp.cov_uuid]:
                arc_comp_ids = self.bc_comp.comp_to_xms[self.bc_comp.cov_uuid][TargetType.arc]
            else:
                arc_comp_ids = {}
            if TargetType.point in self.bc_comp.comp_to_xms[self.bc_comp.cov_uuid]:
                pt_comp_ids = self.bc_comp.comp_to_xms[self.bc_comp.cov_uuid][TargetType.point]
            else:
                pt_comp_ids = {}
        else:
            arc_comp_ids = {}
            pt_comp_ids = {}

        # Index the BC coverage arcs and disjoint points by their feature ids.
        arcs = self._bc_cov.arcs if self._bc_cov else []
        arc_id_to_arc = {}
        for arc in arcs:
            arc_id_to_arc[arc.id] = arc
        points = self._bc_cov.get_points(FilterLocation.PT_LOC_DISJOINT) if self._bc_cov else []
        point_id_to_point = {}
        for pt in points:
            point_id_to_point[pt.id] = pt

        self._map_old_to_new_transport_constituent_ids(
            transport_constituent_old_id_to_new_id, sediment_constituent_old_id_to_new_id
        )

        # Figure out which arcs are edgestrings, midstrings, nodestrings, or some combination thereof
        arc_id_to_edgestring_id = {}
        arc_id_to_midstring_id = {}
        arc_id_to_point_nodestring_id = {}
        arc_snap_to_string_dict = {
            'Point snap': arc_id_to_point_nodestring_id,
            'Edgestring snap': arc_id_to_edgestring_id,
            'Midstring snap': arc_id_to_midstring_id
        }
        if self.bc_comp:
            off_df = self.bc_comp.data.bc.solution_controls[self.bc_comp.data.bc.solution_controls.CARD.isin(
                ['OFF', '']
            )]
            # NOTE(review): mid_ids is always empty, so `is_mid` below is always False — confirm intended.
            mid_ids = []
            db_snap_ids = self.bc_comp.data.snap_types['ID'].tolist()
            db_snap_types = self.bc_comp.data.snap_types['SNAP'].tolist()
            # NOTE(review): off_arcs is only referenced by the commented-out check below, so it is unused.
            off_arcs = off_df['STRING_ID'].tolist()
            off_arcs = [int(off_id) for off_id in off_arcs]
        else:
            mid_ids = []
            db_snap_ids = []
            db_snap_types = []
            off_arcs = []
        # Assign a string id to every arc of every BC, using the snap type stored for that BC if there is one.
        for comp_id, arc_ids in arc_comp_ids.items():
            comp_df = \
                self.bc_comp.data.comp_id_to_ids.loc[self.bc_comp.data.comp_id_to_ids['COMP_ID'] == comp_id]
            if comp_df.empty:
                continue
            bc_id = comp_df['BC_ID'].iloc[0]
            # if bc_id in off_arcs:
            #     continue
            comp_id_to_bc_id[comp_id] = bc_id
            # Despite its name, bc_id_to_string_id is keyed by the component id.
            bc_id_to_string_id[comp_id] = []
            is_mid = comp_id in mid_ids
            is_db = bc_id in db_snap_ids
            for arc_id in arc_ids:
                if is_mid:
                    next_string_id, string_id = self._get_string_id(arc_id, arc_id_to_midstring_id, next_string_id)
                    bc_id_to_string_id[comp_id].append(string_id)
                if is_db:
                    snap_type = db_snap_types[db_snap_ids.index(bc_id)]
                    next_string_id, string_id = self._get_string_id(
                        arc_id, arc_snap_to_string_dict[snap_type], next_string_id
                    )
                    bc_id_to_string_id[comp_id].append(string_id)
                if not is_mid and not is_db:  # Make sure it gets added
                    next_string_id, string_id = self._get_string_id(arc_id, arc_id_to_edgestring_id, next_string_id)
                    bc_id_to_string_id[comp_id].append(string_id)

        # Friction controls: REAL_05 is treated as a midstring flag. Negative or NaN values count as 0, and
        # other values are rounded to the nearest integer and compared with 1.
        if self.bc_comp:
            fric_ids = self.bc_comp.data.bc.friction_controls['STRING_ID'].tolist()
            fric_is_mids = self.bc_comp.data.bc.friction_controls['REAL_05'].tolist()
            for index, _ in enumerate(fric_is_mids):
                if fric_is_mids[index] < 0.0 or math.isnan(fric_is_mids[index]):
                    fric_is_mids[index] = 0.0

            fric_is_mids = [int(float(mid) + 0.5) == 1 for mid in fric_is_mids]
        else:
            fric_ids = []
            fric_is_mids = []
        # Assign friction strings and flux (FLX) cards per arc.
        for comp_id, arc_ids in arc_comp_ids.items():
            bc_fric_arc_id_to_string_id[comp_id] = []
            comp_df = \
                self.bc_comp.data.comp_id_to_ids.loc[self.bc_comp.data.comp_id_to_ids['COMP_ID'] == comp_id]
            if comp_df.empty:
                continue
            fric_id = comp_df['FRICTION_ID'].iloc[0]
            flux_id = comp_df['FLUX_ID'].iloc[0]
            if fric_id in fric_ids:
                for arc_id in arc_ids:
                    idx = fric_ids.index(fric_id)
                    if fric_is_mids[idx]:
                        next_string_id, string_id = self._get_string_id(arc_id, arc_id_to_midstring_id, next_string_id)
                    else:
                        next_string_id, string_id = self._get_string_id(arc_id, arc_id_to_edgestring_id, next_string_id)
                    bc_fric_arc_id_to_string_id[comp_id].append(string_id)

            # Arcs flagged as flux get an FLX card on their midstring and/or edgestring.
            flux_df = self.bc_comp.data.flux.loc[self.bc_comp.data.flux['ID'] == flux_id]
            if not flux_df.empty:
                is_flux = flux_df['IS_FLUX'].iloc[0]
                if is_flux:
                    is_edge = flux_df['EDGESTRING'].iloc[0]
                    is_mid = flux_df['MIDSTRING'].iloc[0]
                    for arc_id in arc_ids:
                        if is_mid:
                            next_string_id, string_id = self._get_string_id(
                                arc_id, arc_id_to_midstring_id, next_string_id
                            )
                            self.flux_cards.append(['FLX', string_id])
                        if is_edge:
                            next_string_id, string_id = self._get_string_id(
                                arc_id, arc_id_to_edgestring_id, next_string_id
                            )
                            self.flux_cards.append(['FLX', string_id])

        # Gather the transport / sediment transport assignment tables (empty when there is no BC component).
        if self.bc_comp:
            transport_ids = self.bc_comp.data.transport_assignments['TRAN_ID'].unique().tolist()
            use_transport_ids = self.bc_comp.data.uses_transport['TRAN_ID'].tolist()
            use_transport = self.bc_comp.data.uses_transport['USES_TRANSPORT'].tolist()

            sediment_ids = self.bc_comp.data.sediment_assignments['TRAN_ID'].unique().tolist()
            use_sediment_ids = self.bc_comp.data.uses_sediment['TRAN_ID'].tolist()
            use_sediment = self.bc_comp.data.uses_sediment['USES_SEDIMENT'].tolist()
            comp_id_df = self.bc_comp.data.comp_id_to_ids
        else:
            transport_ids = []
            use_transport_ids = []
            use_transport = []

            sediment_ids = []
            use_sediment_ids = []
            use_sediment = []
            comp_id_df = \
                pd.DataFrame([], columns=['COMP_ID', 'FRICTION_ID', 'FLUX_ID', 'TRANSPORT_ID', 'DIVERSION_ID', 'BC_ID'])

        # Arcs: assign strings for transport, sediment transport, and sediment diversion assignments.
        for comp_id, arc_ids in arc_comp_ids.items():
            comp_id_row = comp_id_df.loc[comp_id_df['COMP_ID'] == comp_id]
            if comp_id_row.empty:
                continue
            transport_id = comp_id_row.TRANSPORT_ID.iloc[0]
            diversion_id = comp_id_row.DIVERSION_ID.iloc[0]
            next_string_id = self._add_transport_arc_snapping(
                transport_id, arc_ids, transport_ids, use_transport_ids, use_transport,
                self.bc_comp.data.transport_assignments, arc_snap_to_string_dict, next_string_id,
                bc_transport_id_to_string_id
            )
            next_string_id = self._add_transport_arc_snapping(
                transport_id, arc_ids, sediment_ids, use_sediment_ids, use_sediment,
                self.bc_comp.data.sediment_assignments, arc_snap_to_string_dict, next_string_id,
                bc_sediment_transport_id_to_string_id
            )
            next_string_id = self._add_diversion_arc_snapping(
                diversion_id, arc_ids, arc_snap_to_string_dict, next_string_id, bc_diversion_id_to_string_id
            )

        # Points: one shared nodestring id per transport id (see _add_transport_point_snapping).
        for comp_id in pt_comp_ids.keys():
            comp_id_row = comp_id_df.loc[comp_id_df['COMP_ID'] == comp_id]
            if comp_id_row.empty:
                continue
            transport_id = comp_id_row.TRANSPORT_ID.iloc[0]
            next_string_id = self._add_transport_point_snapping(
                transport_id, transport_ids, use_transport_ids, use_transport, self.bc_comp.data.transport_assignments,
                next_string_id, bc_transport_id_to_string_id
            )
            next_string_id = self._add_transport_point_snapping(
                transport_id, sediment_ids, use_sediment_ids, use_sediment, self.bc_comp.data.sediment_assignments,
                next_string_id, bc_sediment_transport_id_to_string_id
            )

        # Sort the dictionaries by string id so that we export from the lowest string id to the highest.
        snap_key_list = list(arc_snap_to_string_dict.keys())
        for dict_key in snap_key_list:
            arc_snap_to_string_dict[dict_key] = \
                {k: v for k, v in sorted(arc_snap_to_string_dict[dict_key].items(), key=lambda item: item[1])}

        # Snap nodestring points to original mesh
        snap_pt_arc = SnapNodeString()
        snap_pt_arc.set_grid(grid=self._grid, target_cells=False)
        # Snap edgestring exterior arcs to active mesh
        snap_ex_arc = SnapExteriorArc()
        snap_ex_arc.set_grid(grid=active_grid, target_cells=False)
        # Snap midstring interior arcs to original mesh
        snap_in_arc = SnapInteriorArc()
        snap_in_arc.set_grid(grid=self._grid, target_cells=False)
        # Snap points to original mesh
        snap_pt = SnapPoint()
        snap_pt.set_grid(grid=self._grid, target_cells=False)
        # Snap edgestring exterior arcs to active mesh
        for arc_id, string_id in arc_id_to_edgestring_id.items():
            arc = arc_id_to_arc[arc_id]
            result = snap_ex_arc.get_snapped_points(arc)
            ids = result['id']  # 0-based ids
            # Do we need to change the ids from small grid ids to original ids?
            # Probably not, because none of the points were removed from the small grid.
            # One EGS card per consecutive pair of snapped nodes (written 1-based).
            for idx in range(1, len(ids)):
                _add_string_card(string_cards, 'EGS', string_id, ids[idx - 1] + 1, ids[idx] + 1)

        # Snap midstring interior arcs to original mesh
        for arc_id, string_id in arc_id_to_midstring_id.items():
            arc = arc_id_to_arc[arc_id]
            result = snap_in_arc.get_snapped_points(arc)
            ids = result['id']  # 0 based ids
            for idx in range(1, len(ids)):
                _add_string_card(string_cards, 'MDS', string_id, ids[idx - 1] + 1, ids[idx] + 1)

        # Snap nodestring points to original mesh
        for comp_id, pt_ids in pt_comp_ids.items():
            # Nodestrings from the transport constituents
            string_id = None
            # NOTE(review): the transport dicts are keyed by TRANSPORT_ID (see _add_transport_point_snapping),
            # but they are tested with comp_id here, and both must contain it. This presumably relies on the
            # transport id matching the component id — confirm.
            if comp_id in bc_transport_id_to_string_id and comp_id in bc_sediment_transport_id_to_string_id:
                for constituent_id in bc_transport_id_to_string_id[comp_id].keys():
                    if bc_transport_id_to_string_id[comp_id][constituent_id]:
                        string_id = bc_transport_id_to_string_id[comp_id][constituent_id][0]
                        break
                if string_id is None:
                    for constituent_id in bc_sediment_transport_id_to_string_id[comp_id].keys():
                        if bc_sediment_transport_id_to_string_id[comp_id][constituent_id]:
                            string_id = bc_sediment_transport_id_to_string_id[comp_id][constituent_id][0]
                            break
            # Nodestrings from boundary conditions
            elif comp_id in comp_id_to_bc_id:
                string_id = next_string_id
                next_string_id += 1
            # Otherwise: store snapped point locations on the time series of the matching BC's solution control.
            else:
                pt_dict = self.bc_comp.comp_to_xms.get(self.bc_comp.cov_uuid, {}).get(TargetType.point, {})
                # NOTE(review): this loop rebinds comp_id and pt_ids from the enclosing loop, so the
                # `if string_id is not None` block after it uses the last point set visited here. It also
                # revisits every point in the coverage each time this branch is taken. Confirm intended.
                for comp_id, pt_ids in pt_dict.items():
                    for pt_id in pt_ids:
                        # get the string_id
                        if comp_id in bc_transport_id_to_string_id:
                            for constituent_id in bc_transport_id_to_string_id[comp_id].keys():
                                if bc_transport_id_to_string_id[comp_id][constituent_id]:
                                    string_id = bc_transport_id_to_string_id[comp_id][constituent_id][0]
                                    break
                        result = snap_pt.get_snapped_points([point_id_to_point[pt_id]])
                        loc = result['location']
                        if loc is None or loc.size < 2:
                            continue

                        # set up shorthand for the dataframe
                        df = self.bc_comp.data.bc.solution_controls
                        # Get the self.bc_comp.data.bc.solution_controls["XY_ID_0"] from the row where
                        # self.bc_comp.data.bc.solution_controls["STRING_ID"] == comp_id

                        bc_df = self.bc_comp.data.comp_id_to_ids
                        bc_id_list = bc_df.loc[bc_df['COMP_ID'] == comp_id]['BC_ID'].to_list()
                        if len(bc_id_list) > 0:
                            bc_id = bc_id_list[0]
                            # The String_ID should be the bc_id.
                            # Check if these happen to be the same, and use the row index if possible
                            ts_id_list = df.loc[df['STRING_ID'] == bc_id]['XY_ID_0'].to_list()

                            if len(ts_id_list) > 0:
                                ts_id = ts_id_list[0]
                                self.bc_comp.data.bc.time_series[ts_id].x_location = loc[0][0]
                                self.bc_comp.data.bc.time_series[ts_id].y_location = loc[0][1]
            # Emit one NDS card per snapped node, and record the string id for BC components.
            if string_id is not None:
                if comp_id in comp_id_to_bc_id:
                    bc_id_to_string_id[comp_id] = [string_id]
                pts = [point_id_to_point[pt_id] for pt_id in pt_ids]
                result = snap_pt.get_snapped_points(pts)
                ids = result['id']  # 0 based ids
                for node_id in ids:
                    _add_string_card(string_cards, 'NDS', string_id, node_id + 1)

        # Snap arc points to mesh, and save as nodestring
        for arc_id, string_id in arc_id_to_point_nodestring_id.items():
            arc = arc_id_to_arc[arc_id]
            result = snap_pt_arc.get_snapped_points(arc)
            ids = result['id']  # 0 based ids
            for node_id in ids:
                _add_string_card(string_cards, 'NDS', string_id, node_id + 1)

        # Build a string list
        labels = ['CARD', 'ID', 'ID_0', 'ID_1']
        df = pd.DataFrame.from_records(string_cards, columns=labels)
        # Nullable Int64 so that cards without an ID_1 (e.g. NDS) keep missing values instead of becoming floats.
        for x in range(1, len(labels)):
            df[labels[x]] = df[labels[x]].astype(dtype='Int64')
        self.sim.boundary_conditions.boundary_strings = df
        # Set param class mesh
        self._load_geometry(mat_cell)
        # Set properties to param class with old component string ids mapped to new string ids
        self._setup_from_data(
            mat_id_to_string_id, bc_fric_arc_id_to_string_id, bc_id_to_string_id, bc_transport_id_to_string_id,
            transport_constituent_old_id_to_new_id, bc_sediment_transport_id_to_string_id,
            sediment_constituent_old_id_to_new_id
        )
        self._setup_diversions(bc_diversion_id_to_string_id)
        self._setup_sediment()
        self._setup_sediment_materials(sed_mat_id_to_string_id, sediment_constituent_old_id_to_new_id)
        # Set param class hot start dataset
        self._set_hotstart_datasets()

        # Snap nodal output points to mesh and add to BcData
        # Get output coverage
        output_coverage = self._output_cov
        simulation = self.sim
        grid_locations = self._grid.ugrid.locations
        add_nodal_output(output_coverage, grid_locations, simulation)

        return self.sim, self.sim_name, self.hotstarts

    def _add_transport_point_snapping(
            self, transport_id, transport_ids, use_transport_ids, use_transport, transport_assignments, next_string_id,
            bc_transport_id_to_string_id
    ):
        """Adds point sets that have transport properties to the snapped strings.

        Args:
            transport_id (int): The transport id.
            transport_ids (list): A unique list of transport ids for point sets that may or may not use transport.
            use_transport_ids (list): A list of transport ids for point sets.
            use_transport (list): A list of 0 or 1 that indicates whether transport was specified for this constituent
            and point set combination. This is parallel to use_transport_ids.
            transport_assignments (DataFrame): The dataframe to get transport assignments out of.
            next_string_id (int): The id of the next string.
            bc_transport_id_to_string_id (dict): A dictionary of the transport id to a dictionary of the component's
            constituent id to the new string id.

        Returns:
            The next string id.
        """
        if transport_id not in transport_ids:
            return next_string_id
        use_idx = use_transport_ids.index(transport_id)
        if not use_transport[use_idx]:
            return next_string_id
        snap_df = transport_assignments.loc[
            transport_assignments['TRAN_ID'] == transport_id, ('CONSTITUENT_ID', 'TYPE', 'SNAPPING')]
        string_id = next_string_id
        next_string_id += 1
        for (_, constituent_id, boundary_type, _) in snap_df.itertuples():
            if boundary_type == 'None':
                continue
            bc_transport_id_to_string_id[transport_id][constituent_id] = [string_id]
        return next_string_id

    def _add_transport_arc_snapping(
            self, transport_id, arc_ids, transport_ids, use_transport_ids, use_transport, transport_assignments,
            arc_snap_to_string_dict, next_string_id, bc_transport_id_to_string_id
    ):
        """Adds arcs that have transport properties to the snapped strings.

        Args:
            transport_id (int): The transport id.
            arc_ids (list): A list of arc attribute ids that use the component id.
            transport_ids (list): A unique list of transport ids for arcs that may or may not use transport.
            use_transport_ids (list): A list of transport ids for arcs.
            use_transport (list): A list of 0 or 1 that indicates whether transport was specified for this constituent
            and arc combination. This is parallel to use_transport_ids.
            transport_assignments (DataFrame): The dataframe to get transport assignments out of.
            arc_snap_to_string_dict (dict): A dictionary of snapping types to dictionaries of arc ids to string ids.
            next_string_id (int): The id of the next string.
            bc_transport_id_to_string_id (dict): A dictionary of the transport id to a dictionary of the component's
            constituent id to the new string id.

        Returns:
            The next string id.
        """
        if transport_id not in transport_ids:
            return next_string_id
        use_idx = use_transport_ids.index(transport_id)
        if not use_transport[use_idx]:
            return next_string_id
        snap_df = transport_assignments.loc[
            transport_assignments['TRAN_ID'] == transport_id, ('CONSTITUENT_ID', 'TYPE', 'SNAPPING')]
        for (_, constituent_id, boundary_type, snap) in snap_df.itertuples():
            if boundary_type == 'None':
                continue
            elif boundary_type == 'Natural':
                snap = 'Edgestring snap'
                string_dict = arc_snap_to_string_dict[snap]
            elif boundary_type in ['Dirichlet', 'Equilibrium']:
                snap = 'Point snap'
                string_dict = arc_snap_to_string_dict[snap]
            else:
                string_dict = arc_snap_to_string_dict[snap]
            for arc_id in arc_ids:
                next_string_id, string_id = self._get_string_id(arc_id, string_dict, next_string_id)
                bc_transport_id_to_string_id[transport_id][constituent_id].append(string_id)
        return next_string_id

    def _add_diversion_arc_snapping(
            self, diversion_id, arc_ids, arc_snap_to_string_dict, next_string_id, bc_diversion_id_to_string_id
    ):
        """Adds arcs that have sediment diversion properties to the snapped strings.

        Args:
            diversion_id (int): The sediment diversion id.
            arc_ids (list): A list of arc attribute ids that use the sediment diversion id.
            arc_snap_to_string_dict (dict): A dictionary of snapping types to dictionaries of arc ids to string ids.
            next_string_id (int): The id of the next string.
            bc_diversion_id_to_string_id (dict): A dictionary of the sediment diversion id to a list of the
            new string ids.

        Returns:
            The next string id.
        """
        snap_df = self.bc_comp.data.sediment_diversions.loc[
            self.bc_comp.data.sediment_diversions['DIV_ID'] == diversion_id, ('SNAPPING',)]
        for (_, snap) in snap_df.itertuples():
            string_dict = arc_snap_to_string_dict[snap]
            for arc_id in arc_ids:
                next_string_id, string_id = self._get_string_id(arc_id, string_dict, next_string_id)
                bc_diversion_id_to_string_id[diversion_id].append(string_id)
        return next_string_id

    def _map_old_to_new_transport_constituent_ids(
            self, transport_constituent_old_id_to_new_id, sediment_constituent_old_id_to_new_id
    ):
        """Fills in the map of old transport constituent ids to new transport constituent ids.

        The new transport constituent ids should be sequential from 1 to n.

        Args:
            transport_constituent_old_id_to_new_id (dict): A dictionary of old to new ids to be filled in.
            sediment_constituent_old_id_to_new_id (dict): A dictionary of old to new ids to be filled in.
        """
        next_constituent_id = 1
        if self.trans_comp:
            # Map old transport constituent ids to new ids
            if self.trans_comp.data.param_control.salinity:
                transport_constituent_old_id_to_new_id[1] = next_constituent_id
                self._hot_start_names['icon-salinity'] = f'icon {next_constituent_id}'
                next_constituent_id += 1
            if self.trans_comp.data.param_control.temperature:
                transport_constituent_old_id_to_new_id[2] = next_constituent_id
                self._hot_start_names['icon-temperature'] = f'icon {next_constituent_id}'
                next_constituent_id += 1
            if self.trans_comp.data.param_control.vorticity:
                transport_constituent_old_id_to_new_id[3] = next_constituent_id
                self._hot_start_names['icon-vorticity'] = f'icon {next_constituent_id}'
                next_constituent_id += 1
            old_ids = self.trans_comp.data.user_constituents['ID'].values.tolist()
            concentrations = self.trans_comp.data.user_constituents['CONC'].values.tolist()
            con_names = self.trans_comp.data.user_constituents['NAME'].values.tolist()
            transport_constituent_old_id_to_new_id.update(
                {
                    old_id: idx + next_constituent_id
                    for idx, old_id in enumerate(old_ids)
                }
            )
            con_ids = [idx + next_constituent_id for idx in range(len(concentrations))]
            self.general_constituents = \
                pd.DataFrame({'ID': con_ids, 'NAME': con_names, 'CONC': concentrations})
            self.num_transport = next_constituent_id + len(con_ids) - 1
            for i in range(len(con_ids)):
                self._hot_start_names[f'icon-concentration-{i + 1}'] = f'icon {con_ids[i]}'
            next_constituent_id += len(con_ids)
        if self.sed_trans_comp:
            old_ids = self.sed_trans_comp.data.param_control.sand['ID'].values.tolist()
            for i in range(len(old_ids)):
                self._hot_start_names[f'icon-clay-{i + 1}'] = f'icon {next_constituent_id + i}'
            sediment_constituent_old_id_to_new_id.update(
                {
                    old_id: idx + next_constituent_id
                    for idx, old_id in enumerate(old_ids)
                }
            )
            self.num_transport += len(old_ids)
            next_constituent_id += len(old_ids)
            old_ids = self.sed_trans_comp.data.param_control.clay['ID'].values.tolist()
            for i in range(len(old_ids)):
                self._hot_start_names[f'icon-sand-{i + 1}'] = f'icon {next_constituent_id + i}'
            sediment_constituent_old_id_to_new_id.update(
                {
                    old_id: idx + next_constituent_id
                    for idx, old_id in enumerate(old_ids)
                }
            )
            self.num_transport += len(old_ids)

    @staticmethod
    def _get_string_id(feature_id, string_dict, next_string_id):
        """Gets the string id for this feature."""
        # check to see if string already exists
        if feature_id not in string_dict:
            string_id = next_string_id
            next_string_id += 1
            string_dict[feature_id] = string_id
        else:
            string_id = string_dict[feature_id]
        return next_string_id, string_id

    def _create_active_grid(self, mat_cell, mat_id_to_string_id, next_string_id, string_cards):
        """Creates a grid with the elements that are part of the OFF material removed.

        Cells whose material id is 0 (the OFF material) are left out of the returned grid. A new material string id
        is allocated for the OFF material: it is stored in ``mat_id_to_string_id`` under key 0, added to
        ``string_cards`` as an ``MTS`` string card, substituted for every 0 in the returned material list, and
        saved in ``self.off_material_string_id``.

        Args:
            mat_cell (list): A list of integer material ids. The list is parallel to the cell ids.
            mat_id_to_string_id (dict): A dictionary of integer material id to a list of new string ids. Modified in
                place: key 0 is set to a list holding the new OFF material string id.
            next_string_id (int): The next string id.
            string_cards (list): A list of the string cards for export. The OFF material string card is added to it.

        Returns:
            mat_cell (list): A new list of material ids with every 0 replaced by the OFF material string id. The
                list is parallel to the cell ids of the original grid.
            next_string_id (int): The next unused string id.
            active_grid (UGrid2d): The grid that has all elements under the off material removed.

        Raises:
            Exception: If any element of the grid is not a triangle.
        """
        # Validate first: every element must be a triangle so each cell occupies exactly five cellstream
        # entries (cell type, point count, and 3 point ids).
        if not self._grid.check_all_cells_are_of_type(UGrid.cell_type_enum.TRIANGLE):
            raise Exception('Not all elements are triangles. Save aborted.')
        values_per_cell = 5
        cell_stream = self._grid.ugrid.cellstream
        active_cell_stream = []
        for idx, start in enumerate(range(0, len(cell_stream), values_per_cell)):
            # Keep only the cells that are not assigned to the OFF material (id 0).
            if mat_cell[idx] != 0:
                active_cell_stream.extend(cell_stream[start:start + values_per_cell])
        xmugrid = UGrid(self._grid.ugrid.locations, active_cell_stream)
        co_builder = UGridBuilder()
        co_builder.set_is_2d()
        co_builder.set_ugrid(xmugrid)
        active_grid = co_builder.build_grid()
        # Add an OFF material string
        mat_string_id = next_string_id
        next_string_id += 1
        mat_id_to_string_id[0] = [mat_string_id]
        _add_string_card(string_cards, 'MTS', mat_string_id, mat_string_id)
        mat_cell = [mat_string_id if mat_id == 0 else mat_id for mat_id in mat_cell]
        self.off_material_string_id = mat_string_id
        return mat_cell, next_string_id, active_grid

    def _setup_from_data(
            self, mat_id_to_string_id, bc_fric_arc_id_to_string_id, bc_id_to_string_id, bc_transport_id_to_string_id,
            transport_constituent_old_id_to_new_id, bc_sediment_transport_id_to_string_id,
            sediment_constituent_old_id_to_new_id
    ):
        """Sets up the adhparams object from the components.

        Time series are copied into ``self.sim.boundary_conditions.time_series`` with new sequential ids, in this
        order:

        1. the OC or OS output series (only one, depending on the output control option);
        2. the maximum time step size series;
        3. the boundary condition series;
        4. one copy of each seasonal friction / meteorological material curve per new material string id.

        The boundary condition cards, stage-discharge data, friction cards, material properties and model control
        are then rebuilt in ``self.sim`` using the new string, constituent and time series ids.

        Args:
            mat_id_to_string_id (dict): Old material id to a list of new material string ids.
            bc_fric_arc_id_to_string_id (dict): Old id to a list of new string ids.
            bc_id_to_string_id (dict): Old id to a list of new string ids.
            bc_transport_id_to_string_id (dict): Old id to a dict of old constituent id to a list of new string ids.
            transport_constituent_old_id_to_new_id (dict): Old constituent id to new constituent id.
            bc_sediment_transport_id_to_string_id (dict): Old id to a dict of old sediment constituent id to a list
                of new string ids.
            sediment_constituent_old_id_to_new_id (dict): Old sediment constituent id to new constituent id.

        Raises:
            Exception: Any error raised while building the data is logged and re-raised.
        """
        try:
            bc_time_to_new_time = {}
            # structure_time_to_new_time = {}
            new_data = []
            # add time series
            # ``series_id`` is the running counter used to assign the new, sequential time series ids.
            series_id = 0
            out_option = self.sim_data.param_control.output_control.output_control_option
            is_oc = out_option == 'Specify output frequency (OC)'
            oc_ts_id = self.sim_data.param_control.output_control.oc_time_series_id
            max_time_step_size_ts_id = self.sim_data.param_control.time_control.max_time_step_size_time_series
            os_ts_id = self.sim_data.info.attrs['os_time_series']
            # Don't set anything in self.sim.model_control yet, the entire class will be set later.
            # Only one of the OC / OS output series is exported, depending on the output control option.
            if is_oc and oc_ts_id > 0 and oc_ts_id in self.sim_data.time_series:
                series_id += 1
                new_series = copy.deepcopy(self.sim_data.time_series[oc_ts_id])
                new_series.series_id = series_id
                self.sim.boundary_conditions.time_series[series_id] = new_series
                # Remember the new id; it is written to model control near the end of this method.
                oc_ts_id = series_id
            elif not is_oc and os_ts_id > 0 and os_ts_id in self.sim_data.time_series:
                # NOTE(review): unlike the OC branch, the new OS series id is not recorded in a local here;
                # confirm the writer locates the output series correctly.
                series_id += 1
                new_series = copy.deepcopy(self.sim_data.time_series[os_ts_id])
                new_series.series_id = series_id
                self.sim.boundary_conditions.time_series[series_id] = new_series
            if max_time_step_size_ts_id > 0 and max_time_step_size_ts_id in self.sim_data.time_series:
                series_id += 1
                new_series = copy.deepcopy(self.sim_data.time_series[max_time_step_size_ts_id])
                new_series.series_id = series_id
                self.sim.boundary_conditions.time_series[series_id] = new_series
                max_time_step_size_ts_id = series_id
            if self.bc_comp:
                # Map each boundary condition time series id to its new id for the BC cards built below.
                for bc_s_id, series in self.bc_comp.data.bc.time_series.items():
                    series_id += 1
                    new_series = copy.deepcopy(series)
                    new_series.series_id = series_id
                    bc_time_to_new_time[bc_s_id] = series_id
                    self.sim.boundary_conditions.time_series[series_id] = new_series

            # Seasonal friction curves: one copy of the curve per new material string id, plus an FR SRF card.
            for mat_id, use_curve in self.mat_comp.data.materials.friction_use_seasonal.items():
                if not use_curve:
                    continue
                # NOTE(review): the curve id is read from ``friction_use_seasonal`` (the same mapping as the flag
                # above), whereas the meteorological loop below reads a separate curve mapping. Confirm intended.
                curve_id = self.mat_comp.data.materials.friction_use_seasonal[mat_id]
                series = self.mat_comp.data.materials.time_series[curve_id]
                for new_mat_id in mat_id_to_string_id[mat_id]:
                    series_id += 1
                    new_series = copy.deepcopy(series)
                    new_series.series_id = series_id
                    self.sim.boundary_conditions.time_series[series_id] = new_series
                    new_data.append(['FR', 'SRF', new_mat_id, series_id, None, None])
            # Meteorological curves: one copy of the curve per new material string id, plus an NB OVL card.
            for mat_id, use_curve in self.mat_comp.data.materials.material_use_meteorological.items():
                if not use_curve:
                    continue
                curve_id = self.mat_comp.data.materials.material_meteorological_curve[mat_id]
                series = self.mat_comp.data.materials.time_series[curve_id]
                for new_mat_id in mat_id_to_string_id[mat_id]:
                    series_id += 1
                    new_series = copy.deepcopy(series)
                    new_series.series_id = series_id
                    self.sim.boundary_conditions.time_series[series_id] = new_series
                    new_data.append(['NB', 'OVL', new_mat_id, series_id, None, None])
            # if structures:
            #     for id, series in structures.data.bc.time_series.items():
            #         series_id += 1
            #         new_series = copy.deepcopy(series)
            #         new_series.series_id = series_id
            #         structure_time_to_new_time[id] = series_id
            #         self.sim.boundary_conditions.time_series[series_id] = new_series
            if self.bc_comp:
                # NB OUT cards pair the outflow and inflow string ids positionally.
                for row_data in self.bc_comp.data.nb_out.itertuples():
                    for new_id1, new_id2 in zip(
                            bc_id_to_string_id[row_data.OUT_COMP_ID], bc_id_to_string_id[row_data.IN_COMP_ID]
                    ):
                        new_data.append(['NB', 'OUT', new_id1, new_id2, bc_time_to_new_time[row_data.SERIES_ID], None])

                for row_data in self.bc_comp.data.bc.solution_controls.itertuples():
                    card = row_data.CARD
                    card_2 = row_data.CARD_2
                    old_id = int(row_data.STRING_ID)
                    # Find the component id that owns this BC id.
                    comp_id_data = self.bc_comp.data.comp_id_to_ids[self.bc_comp.data.comp_id_to_ids.BC_ID == old_id]
                    comp_id = comp_id_data.COMP_ID.iloc[0]
                    # A missing card or an OFF card turns the strings off.
                    if not card or card == 'OFF':
                        for new_id in bc_id_to_string_id[comp_id]:
                            new_data.append(['OFF', new_id, None, None, None, None])
                        continue
                    # XY_ID_* hold old time series ids; for OUT cards XY_ID_0 is instead the paired BC id.
                    xy_0 = int(row_data.XY_ID_0)
                    xy_1 = int(row_data.XY_ID_1)
                    xy_2 = int(row_data.XY_ID_2)
                    if card_2 == 'OUT':
                        for new_id1, new_id2 in zip(bc_id_to_string_id[comp_id], bc_id_to_string_id[xy_0]):
                            new_data.append([card, card_2, new_id1, new_id2, bc_time_to_new_time[xy_1], None])
                    elif card == 'DB' and card_2 == 'OVL':
                        for new_id in bc_id_to_string_id[comp_id]:
                            new_data.append(
                                [card, card_2, new_id, bc_time_to_new_time[xy_0], bc_time_to_new_time[xy_1], None]
                            )
                    elif card_2 == 'OVH':
                        for new_id in bc_id_to_string_id[comp_id]:
                            new_data.append(
                                [
                                    card, card_2, new_id, bc_time_to_new_time[xy_0], bc_time_to_new_time[xy_1],
                                    bc_time_to_new_time[xy_2]
                                ]
                            )
                    elif card_2 in ['OF', 'TID']:
                        # These cards take no time series.
                        for new_id in bc_id_to_string_id[comp_id]:
                            new_data.append([card, card_2, new_id, None, None, None])
                    elif comp_id in bc_id_to_string_id:
                        # Remaining cards use one time series; rows whose series has no new id are skipped.
                        for new_id in bc_id_to_string_id[comp_id]:
                            if xy_0 in bc_time_to_new_time:
                                new_data.append([card, card_2, new_id, bc_time_to_new_time[xy_0], None, None])

                # Transport constituent BCs become NB/DB TRN cards with the new constituent and series ids.
                for row_data in self.bc_comp.data.transport_assignments.itertuples():
                    old_id = int(row_data.TRAN_ID)
                    old_constituent_id = int(row_data.CONSTITUENT_ID)
                    # ``series_id`` is reused for the assignment's old series id; the running counter is no
                    # longer needed because no more time series are added after this point.
                    series_id = int(row_data.SERIES_ID)
                    old_type = row_data.TYPE
                    if old_type == 'Natural':
                        type_str = 'NB'
                    elif old_type == 'Dirichlet':
                        type_str = 'DB'
                    else:
                        continue
                    # Check that we have the new string id.
                    if old_id in bc_transport_id_to_string_id and \
                            old_constituent_id in bc_transport_id_to_string_id[old_id] and \
                            old_constituent_id in transport_constituent_old_id_to_new_id:
                        constituent_id = transport_constituent_old_id_to_new_id[old_constituent_id]
                        for new_id in bc_transport_id_to_string_id[old_id][old_constituent_id]:
                            new_data.append(
                                [type_str, 'TRN', new_id, constituent_id, bc_time_to_new_time[series_id], None]
                            )

                # Sediment constituent BCs become NB/DB/EQ TRN cards with the new constituent and series ids.
                for row_data in self.bc_comp.data.sediment_assignments.itertuples():
                    old_id = int(row_data.TRAN_ID)
                    old_constituent_id = int(row_data.CONSTITUENT_ID)
                    series_id = int(row_data.SERIES_ID)
                    old_type = row_data.TYPE
                    if old_type == 'Natural':
                        type_str = 'NB'
                    elif old_type == 'Dirichlet':
                        type_str = 'DB'
                    elif old_type == 'Equilibrium':
                        type_str = 'EQ'
                    else:
                        continue
                    if old_id in bc_sediment_transport_id_to_string_id and \
                            old_constituent_id in bc_sediment_transport_id_to_string_id[old_id] and \
                            old_constituent_id in sediment_constituent_old_id_to_new_id:
                        constituent_id = sediment_constituent_old_id_to_new_id[old_constituent_id]
                        for new_id in bc_sediment_transport_id_to_string_id[old_id][old_constituent_id]:
                            new_data.append(
                                [type_str, 'TRN', new_id, constituent_id, bc_time_to_new_time[series_id], None]
                            )
            # Turn off the OFF material string, when one was created.
            if self.off_material_string_id > 0:
                new_data.append(['OFF', self.off_material_string_id, None, None, None, None])
            bc_df = pd.DataFrame(
                data=new_data, columns=['CARD', 'CARD_2', 'STRING_ID', 'XY_ID_0', 'XY_ID_1', 'XY_ID_2']
            )
            bc_df.sort_values(by='STRING_ID', inplace=True)
            # Nullable Int64 keeps the id columns integer while allowing the None placeholders.
            int_columns = ['STRING_ID', 'XY_ID_0', 'XY_ID_1', 'XY_ID_2']
            for int_column in int_columns:
                bc_df[int_column] = bc_df[int_column].astype(dtype='Int64')
            self.sim.boundary_conditions.solution_controls = bc_df
            new_sdr_data = []
            if self.bc_comp:
                for row_data in self.bc_comp.data.bc.stage_discharge_boundary.itertuples():
                    card = row_data.CARD
                    old_id = int(row_data.S_ID)
                    coef_a = float(row_data.COEF_A)
                    coef_b = float(row_data.COEF_B)
                    coef_c = float(row_data.COEF_C)
                    coef_d = float(row_data.COEF_D)
                    coef_e = float(row_data.COEF_E)
                    # NOTE(review): elsewhere ``bc_id_to_string_id`` values are iterated as lists of string ids,
                    # so S_ID here receives a list; confirm this is what adhparam expects.
                    new_sdr_data.append([card, bc_id_to_string_id[old_id], coef_a, coef_b, coef_c, coef_d, coef_e])
            sdr_df = pd.DataFrame(
                data=new_sdr_data, columns=['CARD', 'S_ID', 'COEF_A', 'COEF_B', 'COEF_C', 'COEF_D', 'COEF_E']
            )
            sdr_df.sort_values(by='S_ID', inplace=True)
            # self.sim.boundary_conditions.stage_discharge_boundary =\
            #     self.sim.boundary_conditions.stage_discharge_boundary.append(sdr_df)
            self.sim.boundary_conditions.stage_discharge_boundary = \
                pd.concat([self.sim.boundary_conditions.stage_discharge_boundary, sdr_df])

            # if structures:
            #     self._add_weirs(structures, structures_old_id_to_new_id)
            #     self._add_flap_gates(structures, structures_old_id_to_new_id)
            #     self._add_sluice_gates(structures, structures_old_id_to_new_id, structure_time_to_new_time)

            # Material friction cards, duplicated for each new material string id.
            new_fric_data = []
            for row_data in self.mat_comp.data.materials.friction.itertuples():
                card = row_data.CARD
                card_2 = row_data.CARD_2
                old_id = int(row_data.STRING_ID)
                real_1 = float(row_data.REAL_01)
                real_2 = float(row_data.REAL_02)
                real_3 = float(row_data.REAL_03)
                real_4 = float(row_data.REAL_04)
                real_5 = float(row_data.REAL_05)
                if old_id in mat_id_to_string_id:
                    for new_mat_id in mat_id_to_string_id[old_id]:
                        new_fric_data.append([card, card_2, new_mat_id, real_1, real_2, real_3, real_4, real_5])
            # Boundary condition friction cards, mapped from friction id to component id to new string ids.
            if self.bc_comp:
                for row_data in self.bc_comp.data.bc.friction_controls.itertuples():
                    card = row_data.CARD
                    card_2 = row_data.CARD_2
                    old_id = int(row_data.STRING_ID)
                    comp_id_series = self.bc_comp.data.comp_id_to_ids[
                        self.bc_comp.data.comp_id_to_ids.FRICTION_ID == old_id].COMP_ID
                    comp_id = int(comp_id_series.iloc[0])
                    real_1 = float(row_data.REAL_01)
                    real_2 = float(row_data.REAL_02)
                    real_3 = float(row_data.REAL_03)
                    real_4 = float(row_data.REAL_04)
                    real_5 = float(row_data.REAL_05)
                    if comp_id in bc_fric_arc_id_to_string_id:
                        for new_id in bc_fric_arc_id_to_string_id[comp_id]:
                            new_fric_data.append([card, card_2, new_id, real_1, real_2, real_3, real_4, real_5])
            fric_df = pd.DataFrame(
                data=new_fric_data,
                columns=['CARD', 'CARD_2', 'STRING_ID', 'REAL_01', 'REAL_02', 'REAL_03', 'REAL_04', 'REAL_05']
            )
            fric_df.sort_values(by='STRING_ID', inplace=True)
            if self.sim.boundary_conditions.friction_controls.empty:
                self.sim.boundary_conditions.friction_controls = fric_df
            else:
                self.sim.boundary_conditions.friction_controls = \
                    pd.concat([self.sim.boundary_conditions.friction_controls, fric_df])

            # Copy each material's properties to every new material string id, rekeying the transport
            # properties by the new constituent ids.
            for old_id, new_ids in mat_id_to_string_id.items():
                transport_values = self.mat_comp.data.materials.material_properties[old_id].transport_properties
                for new_mat_id in new_ids:
                    self.sim.boundary_conditions.material_properties[new_mat_id] = \
                        copy.deepcopy(self.mat_comp.data.materials.material_properties[old_id])
                    # Clear out the dictionary, so we can add the values with the new constituent keys.
                    self.sim.boundary_conditions.material_properties[new_mat_id].transport_properties = {}
                    new_transport_values = \
                        self.sim.boundary_conditions.material_properties[new_mat_id].transport_properties
                    all_old_ids = list(transport_constituent_old_id_to_new_id.keys())
                    for old_constituent_id, values in transport_values.items():
                        if old_constituent_id in transport_constituent_old_id_to_new_id:
                            new_constituent_id = transport_constituent_old_id_to_new_id[old_constituent_id]
                            new_transport_values[new_constituent_id] = values
                            all_old_ids.remove(old_constituent_id)
                    # Add default values for any constituent that was not specified.
                    for old_constituent_id in all_old_ids:
                        new_constituent_id = transport_constituent_old_id_to_new_id[old_constituent_id]
                        new_transport_values[new_constituent_id] = MaterialTransportProperties()
            # Copy the model control, then point it at the flux strings and the remapped time series ids.
            self.sim.model_control = copy.deepcopy(self.sim_data.param_control)
            self.sim.model_control.output_control.output_flow_strings = pd.DataFrame(
                data=self.flux_cards, columns=['CARD', 'S_ID']
            )
            self.sim.model_control.output_control.oc_time_series_id = oc_ts_id
            self.sim.model_control.time_control.max_time_step_size_time_series = max_time_step_size_ts_id
            if self.num_transport:
                self.sim.model_control.operation_parameters.transport = self.num_transport
            if self.trans_comp:
                # NOTE(review): this assigns the component's param_control object itself (no copy), so the
                # assignments below may also modify the transport component's data; confirm acceptable.
                self.sim.model_control.constituent_properties = self.trans_comp.data.param_control
                self.sim.model_control.constituent_properties.general_constituents = self.general_constituents
                # Old ids 1, 2 and 3 are the salinity, temperature and vorticity constituents.
                if 1 in transport_constituent_old_id_to_new_id:
                    self.sim.model_control.constituent_properties.salinity_id = \
                        transport_constituent_old_id_to_new_id[1]
                if 2 in transport_constituent_old_id_to_new_id:
                    self.sim.model_control.constituent_properties.temperature_id = \
                        transport_constituent_old_id_to_new_id[2]
                if 3 in transport_constituent_old_id_to_new_id:
                    self.sim.model_control.constituent_properties.vorticity_id = \
                        transport_constituent_old_id_to_new_id[3]
            if self.sed_trans_comp:
                # NOTE(review): as above, the component object is assigned without a copy. Also,
                # ``DataFrame.replace`` with a flat dict replaces matching values in every column, not only the
                # id column; confirm no other sand/clay column can hold an old constituent id.
                self.sim.model_control.sediment_constituent_properties = self.sed_trans_comp.data.param_control
                self.sim.model_control.sediment_constituent_properties.sand.replace(
                    sediment_constituent_old_id_to_new_id, inplace=True
                )
                self.sim.model_control.sediment_constituent_properties.clay.replace(
                    sediment_constituent_old_id_to_new_id, inplace=True
                )

        except Exception:
            # Log a user-facing message, then re-raise so the caller still sees the original error.
            self._logger.error('Unable to retrieve AdH data from XMS for exporting.')
            raise

    def _setup_diversions(self, bc_diversion_id_to_string_id):
        """Sets sediment diversion data in an "adhparam" object.

        Args:
            bc_diversion_id_to_string_id (dict): A dictionary of sediment diversion component ids to string ids.
        """
        try:
            # Add sediment diversions
            new_sdv_data = []
            if self.bc_comp:
                for row_data in self.bc_comp.data.sediment_diversions.itertuples():
                    string_ids = bc_diversion_id_to_string_id[int(row_data.DIV_ID)]
                    for string_id in string_ids:
                        new_sdv_data.append(
                            [
                                'SDV', string_id,
                                float(row_data.TOP),
                                float(row_data.BOTTOM),
                                float(row_data.BOTTOM_MAIN)
                            ]
                        )
            self.sim.model_control.sediment_properties.sediment_diversion = \
                pd.DataFrame(data=new_sdv_data, columns=['CARD', 'S_ID', 'TOP', 'BOTTOM', 'BOTTOM_MAIN'])
        except Exception:
            self._logger.error('Unable to retrieve AdH sediment diversion data from XMS for exporting.')
            raise

    def _setup_sediment(self):
        """Copies the global sediment options from the sediment material component into adhparam.

        Nothing is done when there is no sediment material component. Option strings stored by the component are
        converted to their index in the matching ``SedimentMaterialsIO`` option list.
        """
        if self.sed_mat_comp is None:
            return
        attrs = self.sed_mat_comp.data.info.attrs
        props = self.sim.model_control.sediment_properties
        # Cohesive settling velocity method and its coefficients.
        props.cohesive_settling = SedimentMaterialsIO.CSV_OPTIONS.index(attrs['cohesive_settling_velocity_method'])
        props.cohesive_settling_a = attrs['a_csv']
        props.cohesive_settling_b = attrs['b_csv']
        props.cohesive_settling_m = attrs['m_csv']
        props.cohesive_settling_n = attrs['n_csv']
        props.wind_wave_stress = SedimentMaterialsIO.WWS_OPTIONS.index(attrs['wind_wave_shear_method'])
        # The entrainment indices are written one less than their position in the option list.
        props.suspended_entrainment = SedimentMaterialsIO.NSE_OPTIONS.index(attrs['noncohesive_suspended_method']) - 1
        props.bedload_entrainment = SedimentMaterialsIO.NBE_OPTIONS.index(attrs['noncohesive_bedload_method']) - 1
        props.critical_shear_sand = attrs['critical_shear_sand']
        props.critical_shear_clay = attrs['critical_shear_clay']
        props.hiding_factor = SedimentMaterialsIO.HID_OPTIONS.index(attrs['noncohesive_hiding_method'])
        props.hiding_factor_exponent = attrs['hiding_factor']
        props.use_infiltration_factor = int(attrs['use_sediment_infiltration_factor'])
        props.infiltration_factor = attrs['sediment_infiltration_factor']
        props.bed_layer_thickness_protocol = int(attrs['bed_layer_assignment_protocol'])

    def _setup_sediment_materials(self, mat_old_id_to_new_ids, constituent_old_id_to_new_id):
        """Sets the sediment material parameters in adhparam.

        The global sediment data (from the unassigned material) is set first. Then, for every other sediment
        material and each new material string id it maps to, the material's consolidation, bed layer grain
        fractions and any cohesive bed layer / bed layer overrides are appended to the adhparam tables. The
        displacement-off, local scour and bedload diffusion tables are rebuilt from scratch.

        Args:
            mat_old_id_to_new_ids (dict): A dictionary of old sediment material ids to a list of new material
                string ids.
            constituent_old_id_to_new_id (dict): A dictionary of old sediment transport constituent ids to new ids.
        """
        bed_layer_id_to_new_id = self._set_global_sediment(constituent_old_id_to_new_id)
        sed_props = self.sim.model_control.sediment_properties

        displacement_off = []
        local_scour = []
        diffusion = []
        for mat_id, string_ids in mat_old_id_to_new_ids.items():
            if mat_id == SedimentMaterialsIO.UNASSIGNED_MAT:
                # The unassigned material supplies the global values set above.
                continue
            material = self.sed_mat_comp.data.materials[mat_id]
            for string_id in string_ids:
                # Append a consolidation DataFrame
                consolidation = self._get_material_consolidation(material, string_id)
                sed_props.material_consolidation = self._append_sediment_rows(
                    sed_props.material_consolidation, consolidation, ('MATERIAL_ID', 'TIME_ID')
                )

                # Append a cohesive bed layers DataFrame
                if self.sed_mat_comp.data.info.attrs['use_cohesive_bed_layers'] and \
                        material.bed_layer_cohesive_override:
                    cohesive = self._get_material_bed_layer_cohesive(material, bed_layer_id_to_new_id, string_id)
                    sed_props.material_cohesive_bed = self._append_sediment_rows(
                        sed_props.material_cohesive_bed, cohesive, ('MATERIAL_ID', 'BED_LAYER_ID')
                    )

                # Append a bed layers DataFrame. The id columns of the bed layer table itself are cast to int,
                # matching the consolidation and cohesive tables.
                if material.bed_layer_override:
                    bed_layers = self._get_material_bed_layers(material, bed_layer_id_to_new_id, string_id)
                    sed_props.material_bed_layers = self._append_sediment_rows(
                        sed_props.material_bed_layers, bed_layers, ('MATERIAL_ID', 'BED_LAYER_ID')
                    )

                # Append a bed layers sediment constituents DataFrame
                constituents = self._get_material_grain_fractions(
                    material, string_id, bed_layer_id_to_new_id, constituent_old_id_to_new_id
                )
                sed_props.bed_layer_grain_fractions = self._append_sediment_rows(
                    sed_props.bed_layer_grain_fractions, constituents
                )

                # build these so the data frames can be built all at once
                if material.displacement_off:
                    displacement_off.append(['MP NDM', string_id])
                if material.local_scour:
                    local_scour.append(['MP LSM', string_id])
                if material.use_bedload_diffusion:
                    diffusion.append(['MP BLD', string_id, material.bedload_diffusion])
        # set new dataframes
        sed_props.material_displacement_off = pd.DataFrame(data=displacement_off, columns=['CARD', 'MATERIAL_ID'])
        sed_props.material_local_scour = pd.DataFrame(data=local_scour, columns=['CARD', 'MATERIAL_ID'])
        sed_props.material_diffusion = pd.DataFrame(data=diffusion, columns=['CARD', 'MATERIAL_ID', 'DIFFUSION'])

        if self.sed_mat_comp:
            sed_mat_props = {}
            for sed_mat_id in self.sed_mat_comp.data.materials.keys():
                props = self.sed_mat_comp.data.materials[sed_mat_id].sediment_material_properties
                sed_mat_props[sed_mat_id] = [mat_tran_prop for _index, mat_tran_prop in props.items()]
            self.sim.boundary_conditions.sediment_material_properties = sed_mat_props

    @staticmethod
    def _append_sediment_rows(existing, new_rows, int_columns=()):
        """Returns ``existing`` with ``new_rows`` appended, keeping only ``new_rows``' columns in their order.

        Args:
            existing (pandas.DataFrame): The table to append to.
            new_rows (pandas.DataFrame): The rows to append; its columns define the result's columns.
            int_columns (iterable): Names of columns to cast to int64 after concatenation.

        Returns:
            pandas.DataFrame: A new combined table.
        """
        combined = pd.concat([existing, new_rows])[new_rows.columns.tolist()]
        if int_columns:
            # Concatenation can widen integer id columns; cast them back.
            combined = combined.astype({column: 'int64' for column in int_columns})
        return combined

    def _set_global_sediment(self, constituent_old_id_to_new_id):
        """Sets the global sediment dataframes.

        The materials can override these global options.

        Args:
            constituent_old_id_to_new_id (dict): A dictionary of old sediment transport constituent ids to new ids.

        Returns:
            A dictionary of old bed layer ids to new.
        """
        if self.sed_mat_comp is None:
            return
        material = self.sed_mat_comp.data.materials[SedimentMaterialsIO.UNASSIGNED_MAT]
        self.sim.model_control.sediment_properties.number_bed_layers = len(material.bed_layers.index)
        self.sim.model_control.sediment_properties.number_consolidation_times = len(material.consolidation.index)
        # Get a mapping from the old bed layer IDs to new
        # The new layer IDs decrease so the first layer (top) has the highest value
        old_bed_layer_ids = material.bed_layers['layer_id'].values.tolist()
        layer_count = len(old_bed_layer_ids)
        bed_layer_id_to_new_id = {layer_id: layer_count - idx for idx, layer_id in enumerate(old_bed_layer_ids)}

        # Set the bed layer cohesive properties.
        if self.sed_mat_comp.data.info.attrs['use_cohesive_bed_layers']:
            cohesive = self._get_material_bed_layer_cohesive(material, bed_layer_id_to_new_id)
            self.sim.model_control.sediment_properties.global_cohesive_bed = cohesive

        # Set the bed layers.
        bed_layers = self._get_material_bed_layers(material, bed_layer_id_to_new_id)
        self.sim.model_control.sediment_properties.global_bed_layers = bed_layers

        # Set the sediment transport constituent grain size fractions per bed layer.
        constituents = self._get_material_grain_fractions(
            material, 0, bed_layer_id_to_new_id, constituent_old_id_to_new_id
        )
        self.sim.model_control.sediment_properties.bed_layer_grain_fractions = constituents

        # Set the consolidation
        consolidation = self._get_material_consolidation(material)
        self.sim.model_control.sediment_properties.global_consolidation = consolidation
        return bed_layer_id_to_new_id

    @staticmethod
    def _get_material_bed_layers(material, bed_layer_id_to_new_id, material_id=None):
        """Gets the bed layers from the sediment material component.

        Args:
            material (SedimentMaterial): The material to get bed layers from.
            bed_layer_id_to_new_id (dict): A dictionary of old bed layer ids to new.
            material_id (int): The new material id.
        """
        bed_layers = material.bed_layers.drop(
            ['porosity', 'critical_shear', 'erosion_constant', 'erosion_exponent'], axis=1
        )
        bed_layers.replace({'layer_id': bed_layer_id_to_new_id}, inplace=True)
        bed_layers = bed_layers.rename(columns={'layer_id': 'BED_LAYER_ID', 'thickness': 'THICKNESS'})
        if material_id is not None:
            bed_layers.insert(0, 'MATERIAL_ID', [material_id for _ in material.bed_layers.index])
            bed_layers.insert(0, 'CARD', ['MP SBM' for _ in material.bed_layers.index])
        else:
            bed_layers.insert(0, 'CARD', ['MP SBA' for _ in material.bed_layers.index])
        return bed_layers

    @staticmethod
    def _get_material_bed_layer_cohesive(material, bed_layer_id_to_new_id, material_id=None):
        """Gets the cohesive properties of the material bed layers.

        Args:
            material (SedimentMaterial): The material to get bed layer cohesive values from.
            bed_layer_id_to_new_id (dict): A dictionary of old bed layer ids to new.
            material_id (int): The new material id.

        Returns:
            A DataFrame of cohesive values to be appended.
        """
        cohesive = material.bed_layers.drop(['thickness'], axis=1)
        cohesive.replace({'layer_id': bed_layer_id_to_new_id}, inplace=True)
        cohesive = cohesive.rename(
            columns={
                'layer_id': 'BED_LAYER_ID',
                'porosity': 'POROSITY',
                'critical_shear': 'CRITICAL_SHEAR',
                'erosion_constant': 'EROSION_CONSTANT',
                'erosion_exponent': 'EROSION_EXPONENT'
            }
        )
        if material_id is not None:
            cohesive.insert(1, 'MATERIAL_ID', [int(material_id) for _ in material.bed_layers.index])
            cohesive.insert(0, 'CARD', ['CBM' for _ in material.bed_layers.index])
            cohesive.MATERIAL_ID = cohesive.MATERIAL_ID.astype('int64')
        else:
            cohesive.insert(0, 'CARD', ['CBA' for _ in material.bed_layers.index])
        cohesive.insert(0, 'MP', ['MP' for _ in material.bed_layers.index])
        return cohesive

    @staticmethod
    def _get_material_grain_fractions(material, material_id, bed_layer_id_to_new_id, constituent_old_id_to_new_id):
        """Gets a grain fraction dataframe from the material.

        Args:
            material (SedimentMaterial): The material to get bed layer grain fractions from.
            material_id (int): The new material id.
            bed_layer_id_to_new_id (dict): A dictionary of old bed layer ids to new.
            constituent_old_id_to_new_id (dict): A dictionary of old constituent ids to new.
        """
        constituents = copy.deepcopy(material.constituents)
        constituents.replace({'layer_id': bed_layer_id_to_new_id}, inplace=True)
        constituents.replace({'constituent_id': constituent_old_id_to_new_id}, inplace=True)
        constituents = constituents.rename(
            columns={
                'layer_id': 'BED_LAYER_ID',
                'constituent_id': 'CONSTITUENT_ID',
                'fraction': 'FRACTION'
            }
        )
        constituents.insert(0, 'MATERIAL_ID', [material_id for _ in constituents.index])
        constituents.MATERIAL_ID = constituents.MATERIAL_ID.astype('int64')
        constituents.reset_index(drop=True, inplace=True)
        return constituents

    @staticmethod
    def _get_material_consolidation(material, material_id=None):
        """Gets a consolidation dataframe from the material.

        Args:
            material (SedimentMaterial): The material to get consolidation from.
            material_id (int): The new material id.

        Returns:
            A DataFrame that can be appended to an adhparam DataFrame for consolidation.
        """
        old_time_ids = material.consolidation['time_id'].values.tolist()
        time_id_to_new_id = {time_id: idx + 1 for idx, time_id in enumerate(old_time_ids)}
        consolidation = material.consolidation
        consolidation.replace({'time_id': time_id_to_new_id}, inplace=True)
        consolidation = consolidation.rename(
            columns={
                'time_id': 'TIME_ID',
                'elapsed_time': 'ELAPSED_TIME',
                'porosity': 'porosity',
                'critical_shear': 'CRITICAL_SHEAR',
                'erosion_constant': 'EROSION_CONSTANT',
                'erosion_exponent': 'EROSION_EXPONENT'
            }
        )
        if material_id is not None:
            consolidation.insert(0, 'MATERIAL_ID', [material_id for _ in old_time_ids])
            consolidation.insert(0, 'CARD', ['CPM' for _ in old_time_ids])
        else:
            consolidation.insert(0, 'CARD', ['CPA' for _ in old_time_ids])
        consolidation.insert(0, 'MP', ['MP' for _ in old_time_ids])
        return consolidation

    def _add_weirs(self, structures, structures_old_id_to_new_id):
        # weirs
        new_weir_data = []
        for row_data in structures.data.bc.weirs.itertuples():
            weir_id = int(row_data.WRS_NUMBER)
            up_id = int(row_data.S_UPSTREAM)
            down_id = int(row_data.S_DOWNSTREAM)
            edge_up_id = int(row_data.WS_UPSTREAM)
            edge_down_id = int(row_data.WS_DOWNSTREAM)
            length = float(row_data.LENGTH)
            crest = float(row_data.CREST_ELEV)
            height = float(row_data.HEIGHT)
            new_weir_data.append(
                [
                    weir_id, structures_old_id_to_new_id[up_id], structures_old_id_to_new_id[down_id],
                    structures_old_id_to_new_id[edge_up_id], structures_old_id_to_new_id[edge_down_id], length, crest,
                    height
                ]
            )
        weir_df = pd.DataFrame(
            data=new_weir_data,
            columns=[
                'WRS_NUMBER', 'S_UPSTREAM', 'S_DOWNSTREAM', 'WS_UPSTREAM', 'WS_DOWNSTREAM', 'LENGTH', 'CREST_ELEV',
                'HEIGHT'
            ]
        )
        self.sim.boundary_conditions.weirs = pd.concat([self.sim.boundary_conditions.weirs, weir_df])

    def _add_flap_gates(self, structures, structures_old_id_to_new_id):
        # flap gates
        new_flap_data = []
        for row_data in structures.data.bc.flap_gates.itertuples():
            gate_id = int(row_data.FGT_NUMBER)
            user = int(row_data.USER)
            up_id = int(row_data.S_UPSTREAM)
            down_id = int(row_data.S_DOWNSTREAM)
            edge_up_id = int(row_data.FS_UPSTREAM)
            edge_down_id = int(row_data.FS_DOWNSTREAM)
            coef_a = float(row_data.COEF_A)
            coef_b = float(row_data.COEF_B)
            coef_c = float(row_data.COEF_C)
            coef_d = float(row_data.COEF_D)
            coef_e = float(row_data.COEF_E)
            coef_f = float(row_data.COEF_F)
            length = float(row_data.LENGTH)
            new_flap_data.append(
                [
                    gate_id, user, structures_old_id_to_new_id[up_id], structures_old_id_to_new_id[down_id],
                    structures_old_id_to_new_id[edge_up_id], structures_old_id_to_new_id[edge_down_id], coef_a, coef_b,
                    coef_c, coef_d, coef_e, coef_f, length
                ]
            )
        flap_df = pd.DataFrame(
            data=new_flap_data,
            columns=[
                'FGT_NUMBER', 'USER', 'S_UPSTREAM', 'S_DOWNSTREAM', 'FS_UPSTREAM', 'FS_DOWNSTREAM', 'COEF_A', 'COEF_B',
                'COEF_C', 'COEF_D', 'COEF_E', 'COEF_F', 'LENGTH'
            ]
        )
        self.sim.boundary_conditions.flap_gates = pd.concat([self.sim.boundary_conditions.flap_gates, flap_df])

    def _add_sluice_gates(self, structures, structures_old_id_to_new_id, structure_time_to_new_time):
        # sluice gates
        new_sluice_data = []
        for row_data in structures.data.bc.weirs.itertuples():
            sluice_id = int(row_data.SLS_NUMBER)
            up_id = int(row_data.S_UPSTREAM)
            down_id = int(row_data.S_DOWNSTREAM)
            edge_up_id = int(row_data.SS_UPSTREAM)
            edge_down_id = int(row_data.SS_DOWNSTREAM)
            length = float(row_data.LENGTH)
            opening = int(row_data.TS_OPENING)
            new_sluice_data.append(
                [
                    sluice_id, structures_old_id_to_new_id[up_id], structures_old_id_to_new_id[down_id],
                    structures_old_id_to_new_id[edge_up_id], structures_old_id_to_new_id[edge_down_id], length,
                    structure_time_to_new_time[opening]
                ]
            )
        sluice_df = pd.DataFrame(
            data=new_sluice_data,
            columns=[
                'SLS_NUMBER', 'S_UPSTREAM', 'S_DOWNSTREAM', 'SS_UPSTREAM', 'SS_DOWNSTREAM', 'LENGTH', 'TS_OPENING'
            ]
        )
        self.sim.boundary_conditions.sluice_gates = pd.concat([self.sim.boundary_conditions.sluice_gates, sluice_df])

    def _load_geometry(self, materials):
        """Loads the geometry into the param object.

        Args:
            materials (dict): A dictionary of cell id to material id.
        """
        pts = self._grid.ugrid.locations
        node_list = []
        for index in range(len(pts)):
            node_list.append(['ND', index + 1, pts[index][0], pts[index][1], pts[index][2]])
        self.sim.mesh.name = self.sim_name
        self.sim.mesh.nodes = pd.DataFrame(data=node_list, columns=['CARD', 'ID', 'X', 'Y', 'Z'])
        cell_stream = self._grid.ugrid.cellstream
        cell_id = 0
        elements = []
        for index in range(0, len(cell_stream), 5):
            mat_id = materials[cell_id]  # use 0-based
            cell_id += 1
            try:
                # change the node index from 0-based to 1-based
                elem = [
                    'E3T', cell_id, cell_stream[index + 2] + 1, cell_stream[index + 3] + 1, cell_stream[index + 4] + 1,
                    mat_id
                ]
                elements.append(elem)
            except Exception:
                pass
        if elements:
            self.sim.mesh.elements = pd.DataFrame(
                data=elements, columns=['CARD', 'ID', 'NODE_0', 'NODE_1', 'NODE_2', 'MATERIAL_ID']
            )

    def _set_hotstart_datasets(self):
        """Sets the hotstart datasets into the adhparam object."""
        num_cells = self._grid.ugrid.cell_count
        for name, dataset in self._hotstart_datsets.items():
            dset = HotStartDataSet()
            dset.name = name
            values = dataset.values[0]
            if dataset.num_components == 1:
                dset.values = pd.DataFrame(data=values)
            elif dataset.num_components == 2:
                values_col = {'x': [pair[0] for pair in values], 'y': [pair[1] for pair in values]}
                dset.values = pd.DataFrame(data=values_col)
            dset.number_of_cells = num_cells
            self.hotstarts.append(dset)

    def save_xms_data(self, base_path='../xms_data'):
        """Persist this simulation's components, coverages, and XMS data under ``base_path``.

        Args:
            base_path (str): Directory the data is written to.
        """
        components = dict(
            sim_component=self._sim_comp,
            bc_component=self._bc_comp,
            material_component=self._mat_comp,
            sediment_material_component=self._sed_mat_comp,
            transport_component=self._trans_comp,
            sediment_constituents_component=self._sed_trans_comp,
        )
        coverages = dict(
            bc_coverage=self._bc_cov,
            material_coverage=self._mat_cov,
            output_coverage=self._output_cov,
            sediment_material_coverage=self._sed_mat_cov,
        )
        # The bare name below resolves to the module-level function imported from
        # xms.adh.data.xms_test_data, not to this method.
        save_xms_data(base_path, components, coverages, self.xms_data)


def snap_materials(
        grid: Grid, mat_comp: MaterialConceptualComponent, sed_mat_comp: SedimentMaterialConceptualComponent,
        mat_cov: Coverage, sed_mat_cov: Coverage, mat_id_to_string_id: dict[int, list],
        sed_mat_id_to_string_id: dict[int, list],
        next_string_id: int, string_cards: list, mat_id_to_mat_sed: dict[int, tuple[int, int]]
):
    """Snap material and sediment material polygons to the grid and give each combination a material string.

    Each distinct (material id, sediment material id) pair found on cells with a nonzero material id gets
    its own string id. Ids are handed out in sorted pair order, starting at ``next_string_id``. Cells whose
    material id is 0 get string id 0.

    Args:
        grid: The grid whose cells the polygons are snapped to.
        mat_comp: Material component; skipped when falsy.
        sed_mat_comp: Sediment material component; skipped when falsy.
        mat_cov: Material coverage supplying the polygons for ``mat_comp``.
        sed_mat_cov: Sediment material coverage supplying the polygons for ``sed_mat_comp``.
        mat_id_to_string_id: Updated in place: material id -> list of string ids assigned to it.
        sed_mat_id_to_string_id: Updated in place: sediment material id -> list of string ids
            (sediment id 0 is never recorded).
        next_string_id (int): The first string id available for assignment.
        string_cards (list): Receives one 'MTS' card per newly assigned string id.
        mat_id_to_mat_sed: Updated in place: 1-based position of each pair in sorted order ->
            (material id, sediment material id).

    Returns:
        A tuple containing:
            - A list of string IDs for each cell in the geometry.
            - The updated next string ID.
    """
    cell_count = grid.ugrid.cell_count
    # One [material id, sediment material id] pair per cell; 0 means "not covered".
    cell_pairs = [[0, 0] for _ in range(cell_count)]

    def stamp_component(component, coverage, slot):
        """Write the component's ids into ``cell_pairs[cell][slot]`` for each cell inside its polygons."""
        if not component:
            return
        polys_by_id = component.comp_to_xms.get(component.cov_uuid, {}).get(TargetType.polygon, {})
        snapper = SnapPolygon()
        snapper.set_grid(grid, True)
        snapper.add_polygons(coverage.polygons if coverage else [])
        # Material ids come from material_properties; sediment material ids are the materials' own keys.
        if slot == 0:
            comp_ids = component.data.materials.material_properties.keys()
        else:
            comp_ids = component.data.materials.keys()
        for comp_id in comp_ids:
            for poly_id in polys_by_id.get(comp_id, []):
                for cell in snapper.get_cells_in_polygon(poly_id):
                    cell_pairs[cell][slot] = comp_id

    stamp_component(mat_comp, mat_cov, 0)
    stamp_component(sed_mat_comp, sed_mat_cov, 1)

    # Sorted unique pairs make the string id assignment deterministic.
    pairs = sorted({(m_id, s_id) for m_id, s_id in cell_pairs if m_id != 0})
    pair_to_string = {}
    for m_id, s_id in pairs:
        string_id = next_string_id
        next_string_id += 1
        mat_id_to_string_id.setdefault(m_id, []).append(string_id)
        if s_id != 0:
            sed_mat_id_to_string_id.setdefault(s_id, []).append(string_id)
        _add_string_card(string_cards, 'MTS', string_id, string_id)
        pair_to_string[(m_id, s_id)] = string_id

    mat_string_cell = [pair_to_string[(m_id, s_id)] if m_id != 0 else 0 for m_id, s_id in cell_pairs]

    # NOTE(review): these keys are 1..N in sorted pair order. They equal the string ids assigned above
    # only when next_string_id started at 1 — confirm that callers rely on that.
    for position, pair in enumerate(pairs, start=1):
        mat_id_to_mat_sed[position] = pair

    return mat_string_cell, next_string_id


def _add_string_card(string_cards, card, card_id, id_0, id_1=None):
    """Adds a string card.

    Args:
        string_cards (list): A list of lists that will become the strings DataFrame.
        card (str): The card text.
        card_id (int): The string id of the card.
        id_0 (int): The geometry node/cell id.
        id_1 (int): The node id, if applicable.
    """
    record = [card, int(card_id), int(id_0), int(id_1) if id_1 else float('NaN')]
    string_cards.append(record)


def _find_nearest_neighbor(rtree_index: rtree.index.Index, query_point) -> int | None:
    """Return the id of the index entry nearest to a point.

    Args:
        rtree_index: R-tree index to query.
        query_point: Indexable point; only elements 0 (x) and 1 (y) are used.

    Returns:
        The first id yielded by the index's nearest-neighbor query, or None when the index is empty.
    """
    if not len(rtree_index):
        return None
    px, py = query_point[0], query_point[1]
    # A degenerate (zero-area) box represents the query point.
    return next(rtree_index.nearest((px, py, px, py), 1))


def add_nodal_output(output_coverage, grid_locations, simulation):
    """Replace the simulation's nodal output table with one 'PRN' row per output coverage point.

    Each disjoint point of the coverage is snapped to its nearest grid node. Does nothing when
    ``output_coverage`` is None.

    Args:
        output_coverage: Coverage object containing nodal output points.
        grid_locations: Grid node coordinates; elements 0 and 1 of each are used as x and y.
        simulation: Simulation whose ``model_control.output_control.nodal_output`` is replaced.
    """
    if output_coverage is None:
        return

    # Spatial index over the grid nodes; item ids are 0-based node indices.
    node_tree = rtree.index.Index()
    for node_idx, loc in enumerate(grid_locations):
        px, py = loc[0], loc[1]
        node_tree.insert(node_idx, (px, py, px, py))

    # Snap each point to a node (should this ignore duplicate nodes?)
    prn_rows = []
    for point in output_coverage.get_points(FilterLocation.PT_LOC_DISJOINT):
        nearest = _find_nearest_neighbor(node_tree, (point.x, point.y))
        if nearest is None:
            continue
        # Node ids in the card are 1-based.
        prn_rows.append(['PRN', nearest + 1, point.id])

    output_control = simulation.model_control.output_control
    output_control.nodal_output = pd.DataFrame(prn_rows, columns=output_control.nodal_output.columns)
