"""Map Boundary Conditions coverage locations and attributes to the AdH domain."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
import shutil
import uuid

# 2. Third party modules

# 3. Aquaveo modules
from xms.components.display.display_options_io import (
    read_display_options_from_json, write_display_option_line_locations, write_display_options_to_json
)
from xms.data_objects.parameters import Component, FilterLocation
from xms.guipy.data.category_display_option import CategoryDisplayOption
from xms.guipy.data.category_display_option_list import CategoryDisplayOptionList
from xms.guipy.data.line_style import LineOptions, LinePointStyle, LineStyle
from xms.guipy.data.point_symbol import PointOptions
from xms.guipy.data.target_type import TargetType
from xms.snap.snap_exterior_arc import SnapExteriorArc
from xms.snap.snap_interior_arc import SnapInteriorArc
from xms.snap.snap_point import SnapPoint

# 4. Local modules
from xms.adh.components.mapped_component import MappedComponent
from xms.adh.data.card_info import CardInfo
from xms.adh.gui.widgets.color_list import ColorList
from xms.adh.gui.widgets.friction_widget import FrictionWidget
from xms.adh.mapping.snap_node_string import SnapNodeString


class BcMapper:
    """Class for mapping boundary conditions coverage to a mesh for AdH.

    The snapped mesh locations of each kind of boundary condition in the coverage (NB/DB boundary conditions, friction,
    flux, sediment diversions and transport/sediment constituents) are computed, and, when requested, written out as
    display-only "mapped" components that preview where each boundary condition lands on the mesh.
    """

    def __init__(self, coverage_mapper, generate_snap):
        """Constructor.

        Args:
            coverage_mapper: The owning coverage mapper. Its logger, projection WKT, meshes, boundary conditions
                component/coverage and transport/sediment components are read here.
            generate_snap (bool): Whether do_map() should create the snap preview components.
        """
        self._logger = coverage_mapper._logger
        self._wkt = coverage_mapper._wkt
        self._generate_snap = generate_snap
        self._co_grid = coverage_mapper._mesh
        self._active_mesh = coverage_mapper._active_mesh
        self._bc_comp = coverage_mapper._bc_comp
        self._bc_component_file = coverage_mapper._bc_comp.main_file
        self._bc_cov = coverage_mapper._bc_cov
        self._bc_cov_uuid = self._bc_cov.uuid
        self._trans_comp = coverage_mapper._transport_comp
        self._sed_comp = coverage_mapper._sediment_comp
        self._new_comp_unique_name = 'Mapped_Component'
        # Caches of snapped locations keyed by arc feature id, one per snapping style.
        self._arc_to_edge_snap = {}
        self._arc_to_mid_snap = {}
        self._arc_to_point_snap = {}
        # Not read within this class; kept for compatibility.
        self._points_to_snap = {}
        self._pt_comp_to_xms_ids = None
        self._pt_comp_rows = None
        self._arc_comp_to_xms_ids = None
        self._arc_comp_rows = None
        self._arc_id_to_arc = {}
        self._pt_id_to_pt = {}
        self._ex_arc_snapper = None
        self._in_arc_snapper = None
        self._pt_arc_snapper = None
        self._pt_snapper = None

    def do_map(self):
        """Creates the mapped boundary condition components.

        Once the coverage has component ids, the snapped locations of every boundary condition type are computed.
        When snap generation is on, one display-only component is created for each type that has snapped locations.

        Returns:
            tuple(list, list): The xms.data_objects components to send to XMS and the matching MappedComponent
            objects, in the same order.
        """
        do_comps = []
        comps = []
        if self._bc_cov_uuid not in self._bc_comp.comp_to_xms:
            return do_comps, comps
        self._setup_component_id_data()
        self._setup_snappers()
        # NB/DB and friction snaps need a display category plus locations for each arc.
        # Flux and sediment diversion snaps only need locations.
        # Transport snaps need, per constituent, the category and locations of each snapped arc/point set.
        nb_db_arcs = self._get_nb_db_snap()
        friction_arcs = self._get_friction_snap()
        flux_arcs = self._get_flux_snap()
        sed_div_arcs = self._get_sediment_diversion_snap()
        transport_con_arcs, is_sediments = self._get_transport_snap()
        if not self._generate_snap:
            return do_comps, comps

        bc_comp_path = os.path.dirname(self._bc_component_file)

        if nb_db_arcs:
            comp_path, comp_uuid = self._create_component_folder(bc_comp_path, '')
            # The NB/DB preview reuses the coverage's own arc display options.
            self._copy_display_options(bc_comp_path, comp_path, comp_uuid, '')
            for cat, arcs in nb_db_arcs.items():
                self._create_drawing(cat.replace('.display_ids', ''), arcs, comp_path)
            self._create_component_to_send(comp_path, 'Snapped BC', do_comps, comps)

        if friction_arcs:
            comp_path, comp_uuid = self._create_component_folder(bc_comp_path, '')
            all_friction_categories = []
            for fric in FrictionWidget.friction_types:
                if fric == 'Off':
                    continue
                new_file = fric.replace(' ', '_').lower()
                all_friction_categories.append([new_file, fric, False])
                if fric in friction_arcs:
                    self._create_drawing(new_file, friction_arcs[fric], comp_path)
            self._create_display_options(all_friction_categories, comp_path, comp_uuid, '')
            self._create_component_to_send(comp_path, 'Snapped Friction', do_comps, comps)

        if flux_arcs:
            comp_path, comp_uuid = self._create_component_folder(bc_comp_path, '')
            self._create_display_options([['flux', 'Flux', False]], comp_path, comp_uuid, '')
            self._create_drawing('flux', flux_arcs, comp_path)
            self._create_component_to_send(comp_path, 'Snapped Flux', do_comps, comps)

        if sed_div_arcs:
            comp_path, comp_uuid = self._create_component_folder(bc_comp_path, '')
            self._create_display_options([['sed_div', 'Sediment Diversion', False]], comp_path, comp_uuid, '')
            self._create_drawing('sed_div', sed_div_arcs, comp_path)
            self._create_component_to_send(comp_path, 'Snapped Sediment Diversion', do_comps, comps)

        for con, con_arcs in transport_con_arcs.items():
            con_categories = [['natural', 'Natural', False], ['dirichlet', 'Dirichlet', True]]
            if is_sediments[con]:
                con_categories.append(['equilibrium', 'Equilibrium', True])
            comp_path, comp_uuid = self._create_component_folder(bc_comp_path, '')
            self._create_display_options(con_categories, comp_path, comp_uuid, '')
            for filename, category, _is_nodestring in con_categories:
                self._create_drawing(filename, con_arcs[category], comp_path)
            self._create_component_to_send(comp_path, f'Snapped {con}', do_comps, comps)

        return do_comps, comps

    def _create_component_to_send(self, comp_path, display_name, do_comps, comps):
        """Creates a xms.data_objects component for sending to XMS.

        Args:
            comp_path (str): The component's folder. Its base name is used as the component UUID.
            display_name (str): The name that will appear in the project explorer of XMS.
            do_comps (list): A list of xms.data_objects components. It will be appended to.
            comps (list): A list of MappedComponent. It will be appended to.
        """
        main_file = os.path.join(comp_path, 'mapped_display_options.json')
        do_comp = Component(
            main_file=main_file,
            name=display_name,
            unique_name=self._new_comp_unique_name,
            model_name='AdH',
            comp_uuid=os.path.basename(comp_path)
        )
        do_comps.append(do_comp)
        comps.append(MappedComponent(main_file))

    def _create_component_folder(self, bc_comp_path, comp_uuid):
        """Creates an empty folder for a new snap preview component.

        The folder is named after the component UUID and is created next to the boundary conditions component folder.
        An existing folder with the same name is removed first.

        Args:
            bc_comp_path (str): The boundary conditions component folder (the directory holding its main file).
            comp_uuid (str): The new snap preview component UUID. A random UUID is generated when empty.

        Returns:
            tuple(str, str): The path to the new component folder and the component UUID (the one passed in when
            it was not empty).
        """
        if not comp_uuid:
            comp_uuid = str(uuid.uuid4())  # pragma: no cover
        self._logger.info('Creating component folder')
        comp_path = os.path.join(os.path.dirname(bc_comp_path), comp_uuid)

        if os.path.exists(comp_path):
            shutil.rmtree(comp_path)  # pragma: no cover
        os.mkdir(comp_path)
        return comp_path, comp_uuid

    def _copy_display_options(self, bc_comp_path, comp_path, comp_uuid, display_uuid):
        """Copies the arc display options of the boundary conditions coverage into a snap preview component.

        Point categories are converted to dashed line options drawn with all points. Line categories are made dashed
        and thick. Category colors are kept.

        Args:
            bc_comp_path (str): The boundary conditions component folder.
            comp_path (str): The new snap preview component folder.
            comp_uuid (str): The new snap preview component UUID.
            display_uuid (str): The new snap preview component display UUID. A random UUID is used when empty.
        """
        bc_comp_display_file = os.path.join(bc_comp_path, 'bc_arc_display_options.json')
        comp_display_file = os.path.join(comp_path, 'mapped_display_options.json')
        if not os.path.isfile(bc_comp_display_file):
            self._logger.info('Could not find bc_arc_display_options.json file')  # pragma: no cover
            return

        shutil.copyfile(bc_comp_display_file, comp_display_file)
        categories = CategoryDisplayOptionList()  # Generates a random UUID key for the display list
        json_dict = read_display_options_from_json(comp_display_file)
        if not display_uuid:
            json_dict['uuid'] = str(uuid.uuid4())  # pragma: no cover
        else:
            json_dict['uuid'] = display_uuid
        json_dict['comp_uuid'] = comp_uuid
        json_dict['is_ids'] = 0
        categories.from_dict(json_dict)
        categories.projection = {'wkt': self._wkt}

        # Set all snapped arcs to be dashed and thick by default. Keep current color.
        for category in categories.categories:
            if isinstance(category.options, PointOptions):
                line_options = LineOptions()
                line_options.style = LineStyle.DASHEDLINE
                line_options.width = 4
                line_options.color = category.options.color
                line_options.end_point = category.options
                line_options.vertex = category.options
                line_options.point_style = LinePointStyle.all_points
                line_options.end_point.color = line_options.color
                line_options.vertex.color = line_options.color
                category.options = line_options
            else:
                category.options.style = LineStyle.DASHEDLINE
                category.options.width = 4

        write_display_options_to_json(comp_display_file, categories)

    def _create_display_options(self, new_categories, comp_path, comp_uuid, display_uuid):
        """Creates the display options of a snap preview component.

        Args:
            new_categories (list): A list of [category filename, category name, is nodestring] entries.
                Nodestrings are drawn with all points; other categories are drawn as thick lines.
            comp_path (str): The new snap preview component folder.
            comp_uuid (str): The new snap preview component UUID.
            display_uuid (str): The new snap preview component display UUID. A random one is kept when empty.
        """
        comp_display_file = os.path.join(comp_path, 'mapped_display_options.json')

        categories = CategoryDisplayOptionList()  # Generates a random UUID key for the display list
        if display_uuid:
            categories.uuid = display_uuid
        categories.comp_uuid = comp_uuid
        categories.is_ids = False
        categories.target_type = TargetType.arc
        colors = ColorList.colors
        for idx, (filename, name, is_nodestring) in enumerate(new_categories):
            cat = CategoryDisplayOption()
            cat.file = f'{filename}.display_ids'
            cat.description = name
            cat.is_unassigned_category = False
            cat.label_on = False
            cat.options = LineOptions()
            cat.options.style = LineStyle.DASHEDLINE
            # Cycle the palette so more categories than colors do not raise IndexError.
            cat.options.color = colors[idx % len(colors)]
            if is_nodestring:
                cat.options.point_style = LinePointStyle.all_points
                cat.options.end_point.color = cat.options.color
                cat.options.vertex.color = cat.options.color
            else:
                cat.options.width = 4
            categories.categories.append(cat)
        categories.projection = {'wkt': self._wkt}
        write_display_options_to_json(comp_display_file, categories)

    def _create_drawing(self, disp_name, arc_grid_points, comp_path):
        """Writes the snapped line locations of a single display category.

        Args:
            disp_name (str): The name for the category display ids file (without extension).
            arc_grid_points (list): A list of lists of snapped locations, one list per drawn line.
            comp_path (str): The folder of the snap preview component.
        """
        filename = os.path.join(comp_path, f'{disp_name}.display_ids')
        write_display_option_line_locations(filename, arc_grid_points)

    def _get_nb_db_snap(self):
        """Get the snapped boundary condition locations.

        Returns:
            dict: Display category ids file -> list of snapped location lists (one per bc arc or bc point set).
        """
        bc_cat_arcs = {}
        # Needed by both the arc and the point loops below.
        bc_df = self._bc_comp.data.bc.solution_controls

        # Get the arc locations that will make an edgestring/midstring/nodestring.
        if self._arc_comp_rows is not None:
            bc_arc_rows = self._arc_comp_rows.loc[self._arc_comp_rows['BC_ID'] > 0]
            for comp_id, bc_id in zip(bc_arc_rows['COMP_ID'].tolist(), bc_arc_rows['BC_ID'].tolist()):
                bc_type, snap = self._get_arc_bc_type_and_snap(comp_id, bc_id, bc_df)
                if bc_type == 'Off':
                    continue
                arcs = bc_cat_arcs.setdefault(bc_type, [])
                # Create a snapped arc per arc for this bc id.
                for arc_id in self._arc_comp_to_xms_ids[comp_id]:
                    if snap == 'Edgestring snap':
                        arcs.append(self._get_edgestring(arc_id))
                    elif snap == 'Midstring snap':
                        arcs.append(self._get_midstring(arc_id))
                    else:
                        arcs.append(self._get_pointstring(arc_id))

        # Get the point sets that will make a nodestring.
        if self._pt_comp_rows is not None:
            bc_pt_rows = self._pt_comp_rows.loc[self._pt_comp_rows['BC_ID'] > 0]
            for bc_id, comp_ids in self._group_comp_ids(bc_pt_rows, 'BC_ID').items():
                bc_row = bc_df.loc[bc_df['STRING_ID'] == bc_id]
                if bc_row.empty:
                    continue
                bc_card_type = bc_row.CARD_2.iloc[0]
                # We don't snap wind points
                if bc_card_type == 'WND':
                    continue
                # Despite using points, we will consolidate these into arc display options.
                bc_type = CardInfo.arc_card_to_id_file[bc_card_type]
                if bc_type == 'Off':
                    continue
                bc_cat_arcs.setdefault(bc_type, []).append(self._snap_point_set(comp_ids))
        return bc_cat_arcs

    def _get_arc_bc_type_and_snap(self, comp_id, bc_id, bc_df):
        """Gets the display category and snapping option of an arc boundary condition.

        The solution controls table is checked first, then the stage-discharge table, then the outflow/inflow pairs.

        Args:
            comp_id (int): The arc's component id.
            bc_id (int): The boundary condition string id of the arc.
            bc_df: The solution controls table.

        Returns:
            tuple(str, str): The display category ids file ('Off' if the arc is not drawn) and the snap option.
        """
        data = self._bc_comp.data
        snap = 'Edgestring snap'
        bc_row = bc_df.loc[bc_df['STRING_ID'] == bc_id]
        if not bc_row.empty:
            card_type = bc_row.CARD.iloc[0]
            bc_type = CardInfo.arc_card_to_id_file[bc_row.CARD_2.iloc[0]]
            # Only DB cards carry a user-selected snapping option; everything else snaps as an edgestring.
            if card_type == 'DB':
                snap_df = data.snap_types.loc[data.snap_types['ID'] == bc_id]
                if not snap_df.empty:
                    snap = snap_df.SNAP.iloc[0]
            return bc_type, snap

        sd_df = data.bc.stage_discharge_boundary
        if not sd_df.loc[sd_df['S_ID'] == bc_id].empty:
            return CardInfo.arc_card_to_id_file['SDR'], snap

        bc_type = 'Off'
        out_row = data.nb_out.loc[data.nb_out['BC_ID'] == bc_id]
        if not out_row.empty:
            # The same bc id is shared by an outflow arc and an inflow arc; the component id tells them apart.
            if out_row.OUT_COMP_ID.iloc[0] == comp_id:
                bc_type = CardInfo.arc_card_to_id_file['OUT outflow']
            elif out_row.IN_COMP_ID.iloc[0] == comp_id:
                bc_type = CardInfo.arc_card_to_id_file['OUT inflow']
        return bc_type, snap

    def _get_friction_snap(self):
        """Get the snapped boundary condition friction locations.

        Returns:
            dict: Friction type -> list of snapped location lists, one per friction arc.
        """
        friction_cat_arcs = {}
        if self._arc_comp_rows is None:
            return friction_cat_arcs

        fric_arc_rows = self._arc_comp_rows.loc[self._arc_comp_rows['FRICTION_ID'] > 0]
        fric_df = self._bc_comp.data.bc.friction_controls
        for comp_id, fric_id in zip(fric_arc_rows['COMP_ID'].tolist(), fric_arc_rows['FRICTION_ID'].tolist()):
            fric_row = fric_df.loc[fric_df['STRING_ID'] == fric_id]
            if fric_row.empty:
                continue
            fric_card_type = fric_row.CARD_2.iloc[0]
            if not fric_card_type:
                continue
            # A nonzero REAL_05 value selects midstring snapping; otherwise the arcs snap as edgestrings.
            is_mid = fric_row.REAL_05.iloc[0] != 0
            snap_method = self._get_midstring if is_mid else self._get_edgestring
            fric_type = FrictionWidget.card_to_type[fric_card_type]
            arcs = friction_cat_arcs.setdefault(fric_type, [])
            for arc_id in self._arc_comp_to_xms_ids[comp_id]:
                arcs.append(snap_method(arc_id))
        return friction_cat_arcs

    def _get_flux_snap(self):
        """Gets the snapped flux arcs locations.

        Returns:
            list: Snapped location lists. A flux arc contributes an edgestring and/or a midstring list.
        """
        flux_arcs = []
        if self._arc_comp_rows is None:
            return flux_arcs

        flux_arc_rows = self._arc_comp_rows.loc[self._arc_comp_rows['FLUX_ID'] > 0]
        flux_df = self._bc_comp.data.flux
        for comp_id, flux_id in zip(flux_arc_rows['COMP_ID'].tolist(), flux_arc_rows['FLUX_ID'].tolist()):
            flux_row = flux_df.loc[flux_df['ID'] == flux_id]
            if flux_row.empty or not flux_row.IS_FLUX.iloc[0]:
                continue
            is_edge = flux_row.EDGESTRING.iloc[0]
            is_mid = flux_row.MIDSTRING.iloc[0]
            for arc_id in self._arc_comp_to_xms_ids[comp_id]:
                if is_edge:
                    flux_arcs.append(self._get_edgestring(arc_id))
                if is_mid:
                    flux_arcs.append(self._get_midstring(arc_id))
        return flux_arcs

    def _get_sediment_diversion_snap(self):
        """Get the snapped sediment diversion locations.

        Returns:
            list: Snapped location lists, one per sediment diversion arc with 'Edgestring' or 'Midstring' snapping.
        """
        sed_div_arcs = []
        if self._arc_comp_rows is None:
            return sed_div_arcs

        div_arc_rows = self._arc_comp_rows.loc[self._arc_comp_rows['DIVERSION_ID'] > 0]
        div_df = self._bc_comp.data.sediment_diversions
        for comp_id, div_id in zip(div_arc_rows['COMP_ID'].tolist(), div_arc_rows['DIVERSION_ID'].tolist()):
            div_row = div_df.loc[div_df['DIV_ID'] == div_id]
            if div_row.empty:
                continue
            snap = div_row.SNAPPING.iloc[0]
            # Any other snapping value is not drawn.
            for arc_id in self._arc_comp_to_xms_ids[comp_id]:
                if snap == 'Edgestring':
                    sed_div_arcs.append(self._get_edgestring(arc_id))
                elif snap == 'Midstring':
                    sed_div_arcs.append(self._get_midstring(arc_id))
        return sed_div_arcs

    def _get_transport_snap(self):
        """Get the snapped transport locations.

        Returns:
            tuple(dict, dict): Constituent name -> {category ('Natural', 'Dirichlet' and, for sediment, 'Equilibrium')
            -> list of snapped location lists}. Also constituent name -> True if it is a sediment constituent.
        """
        transport_con_arcs = {}
        is_sediments = {}

        if not self._trans_comp and not self._sed_comp:
            return transport_con_arcs, is_sediments

        tran_con_id_to_name, sed_con_id_to_name = self._get_constituent_names()
        data = self._bc_comp.data
        assignment_dfs = (
            data.transport_assignments, data.uses_transport, data.sediment_assignments, data.uses_sediment
        )

        # Get the arc locations that will make an edgestring/midstring/nodestring.
        if self._arc_comp_rows is not None:
            trans_arc_rows = self._arc_comp_rows.loc[self._arc_comp_rows['TRANSPORT_ID'] > 0]
            for comp_id, trans_id in zip(trans_arc_rows['COMP_ID'].tolist(), trans_arc_rows['TRANSPORT_ID'].tolist()):
                tran_row, sed_row = self._get_transport_assignment_rows(trans_id, *assignment_dfs)
                # Transport constituents honor the snapping option only for Dirichlet; sediment also for Equilibrium.
                for rows, con_id_to_name, is_sediment, snapped_types in (
                    (tran_row, tran_con_id_to_name, False, ('Dirichlet',)),
                    (sed_row, sed_con_id_to_name, True, ('Dirichlet', 'Equilibrium')),
                ):
                    if rows is None or rows.empty:
                        continue
                    for con_id, con_type, snap in zip(
                        rows.CONSTITUENT_ID.tolist(), rows.TYPE.tolist(), rows.SNAPPING.tolist()
                    ):
                        con_name = con_id_to_name[con_id]
                        categories = self._get_constituent_categories(
                            transport_con_arcs, is_sediments, con_name, is_sediment
                        )
                        snap_method = self._get_transport_snap_method(con_type, snap, snapped_types)
                        for arc_id in self._arc_comp_to_xms_ids[comp_id]:
                            categories[con_type].append(snap_method(arc_id))

        # Get the point sets that will make a nodestring.
        if self._pt_comp_rows is not None:
            tran_pt_rows = self._pt_comp_rows.loc[self._pt_comp_rows['TRANSPORT_ID'] > 0]
            for trans_id, comp_ids in self._group_comp_ids(tran_pt_rows, 'TRANSPORT_ID').items():
                tran_row, sed_row = self._get_transport_assignment_rows(trans_id, *assignment_dfs)
                if tran_row is None and sed_row is None:
                    continue
                points = self._snap_point_set(comp_ids)
                for rows, con_id_to_name, is_sediment in (
                    (tran_row, tran_con_id_to_name, False),
                    (sed_row, sed_con_id_to_name, True),
                ):
                    if rows is None or rows.empty:
                        continue
                    for con_id, con_type in zip(rows.CONSTITUENT_ID.tolist(), rows.TYPE.tolist()):
                        con_name = con_id_to_name[con_id]
                        categories = self._get_constituent_categories(
                            transport_con_arcs, is_sediments, con_name, is_sediment
                        )
                        categories[con_type].append(points)
        return transport_con_arcs, is_sediments

    def _get_constituent_names(self):
        """Gets the transport and sediment constituent names keyed by constituent id.

        Sediment constituents share one name space with transport constituents (including the built-in Salinity,
        Temperature and Vorticity), so a colliding sediment name gets a ' (sediment)' suffix.

        Returns:
            tuple(dict, dict): Transport constituent id -> name, and sediment constituent id -> name.
        """
        tran_con_id_to_name = {}
        sed_con_id_to_name = {}
        if self._trans_comp:
            user_constituents = self._trans_comp.data.user_constituents
            tran_con_id_to_name = {1: 'Salinity', 2: 'Temperature', 3: 'Vorticity'}
            tran_con_id_to_name.update(
                zip(user_constituents.ID.data.tolist(), user_constituents.NAME.data.tolist())
            )
        if self._sed_comp:
            param_control = self._sed_comp.data.param_control
            sed_con_ids = param_control.sand.ID.tolist() + param_control.clay.ID.tolist()
            sed_con_names = param_control.sand.NAME.tolist() + param_control.clay.NAME.tolist()
            tran_names = set(tran_con_id_to_name.values())
            sed_con_names = [f'{name} (sediment)' if name in tran_names else name for name in sed_con_names]
            sed_con_id_to_name = dict(zip(sed_con_ids, sed_con_names))
        return tran_con_id_to_name, sed_con_id_to_name

    @staticmethod
    def _get_transport_assignment_rows(trans_id, trans_df, use_trans_df, sed_df, use_sed_df):
        """Gets the transport and sediment constituent assignment rows of a transport id.

        Args:
            trans_id (int): The transport id.
            trans_df: The transport constituent assignments table.
            use_trans_df: The table of which transport ids use transport.
            sed_df: The sediment constituent assignments table.
            use_sed_df: The table of which transport ids use sediment transport.

        Returns:
            tuple: The transport assignment rows (None when the id does not use transport) and the sediment
            assignment rows (None when the id does not use sediment transport).
        """
        use_tran_row = use_trans_df.loc[use_trans_df['TRAN_ID'] == trans_id]
        use_sed_row = use_sed_df.loc[use_sed_df['TRAN_ID'] == trans_id]
        use_tran = False if use_tran_row.empty else use_tran_row.USES_TRANSPORT.iloc[0]
        use_sed = False if use_sed_row.empty else use_sed_row.USES_SEDIMENT.iloc[0]
        tran_row = trans_df.loc[trans_df['TRAN_ID'] == trans_id] if use_tran else None
        sed_row = sed_df.loc[sed_df['TRAN_ID'] == trans_id] if use_sed else None
        return tran_row, sed_row

    @staticmethod
    def _get_constituent_categories(transport_con_arcs, is_sediments, con_name, is_sediment):
        """Gets (creating on first use) the category -> snapped locations dictionary of a constituent.

        Args:
            transport_con_arcs (dict): Constituent name -> category dictionary. Updated in place.
            is_sediments (dict): Constituent name -> is sediment flag. Updated in place.
            con_name (str): The constituent name.
            is_sediment (bool): True for a sediment constituent, which also gets an 'Equilibrium' category.

        Returns:
            dict: The constituent's category -> list of snapped location lists.
        """
        if con_name not in transport_con_arcs:
            categories = {'Natural': [], 'Dirichlet': []}
            if is_sediment:
                categories['Equilibrium'] = []
            transport_con_arcs[con_name] = categories
            is_sediments[con_name] = is_sediment
        return transport_con_arcs[con_name]

    def _get_transport_snap_method(self, con_type, snap, snapped_types):
        """Gets the arc snapping method for a constituent assignment.

        Args:
            con_type (str): The constituent boundary type (display category name).
            snap (str): The selected snapping option.
            snapped_types (tuple): Boundary types that honor the snapping option. Others snap as edgestrings.

        Returns:
            callable: One of _get_edgestring, _get_midstring or _get_pointstring.
        """
        if con_type in snapped_types:
            if snap == 'Midstring snap':
                return self._get_midstring
            if snap == 'Point snap':
                return self._get_pointstring
        return self._get_edgestring

    @staticmethod
    def _group_comp_ids(rows, id_column):
        """Groups the COMP_ID values of rows by the value in another column, preserving first-seen order.

        Args:
            rows: Rows with a 'COMP_ID' column and the id_column.
            id_column (str): The column to group by.

        Returns:
            dict: id_column value -> list of component ids.
        """
        grouped = {}
        for comp_id, group_id in zip(rows['COMP_ID'].tolist(), rows[id_column].tolist()):
            grouped.setdefault(group_id, []).append(comp_id)
        return grouped

    def _snap_point_set(self, comp_ids):
        """Snaps all coverage points of the given point component ids as one nodestring.

        Args:
            comp_ids (list): Point component ids.

        Returns:
            list: The snapped locations.
        """
        pt_locs = [
            self._pt_id_to_pt[pt_id] for comp_id in comp_ids for pt_id in self._pt_comp_to_xms_ids[comp_id]
        ]
        return self._flatten_locations(self._pt_snapper.get_snapped_points(pt_locs))

    @staticmethod
    def _flatten_locations(snap_output):
        """Flattens the nested 'location' lists returned by a snapper into one list of locations."""
        return [location for locations in snap_output['location'] for location in locations]

    def _get_edgestring(self, arc_id):
        """Gets the (cached) snapped locations of an edgestring from an arc.

        Args:
            arc_id (int): The arc feature id.

        Returns:
            A list of locations the arc snaps to following edgestring constraints.
        """
        if arc_id not in self._arc_to_edge_snap:
            snap_output = self._ex_arc_snapper.get_snapped_points(self._arc_id_to_arc[arc_id])
            self._arc_to_edge_snap[arc_id] = self._flatten_locations(snap_output)
        return self._arc_to_edge_snap[arc_id]

    def _get_midstring(self, arc_id):
        """Gets the (cached) snapped locations of a midstring from an arc.

        Midstrings cannot cross holes in the domain.

        Args:
            arc_id (int): The arc feature id.

        Returns:
            A list of locations the arc snaps to following midstring constraints.
        """
        if arc_id not in self._arc_to_mid_snap:
            snap_output = self._in_arc_snapper.get_snapped_points(self._arc_id_to_arc[arc_id])
            self._arc_to_mid_snap[arc_id] = self._flatten_locations(snap_output)
        return self._arc_to_mid_snap[arc_id]

    def _get_pointstring(self, arc_id):
        """Gets the (cached) snapped locations of a nodestring from an arc.

        The arc end points and vertices will each be individually snapped to the closest node of the geometry.

        Args:
            arc_id (int): The arc feature id.

        Returns:
            A list of locations the arc snaps to following nodestring constraints.
        """
        if arc_id not in self._arc_to_point_snap:
            snap_output = self._pt_arc_snapper.get_snapped_points(self._arc_id_to_arc[arc_id])
            self._arc_to_point_snap[arc_id] = self._flatten_locations(snap_output)
        return self._arc_to_point_snap[arc_id]

    def _setup_component_id_data(self):
        """Sets up the component id lookups and component id rows of the coverage's arcs and points."""
        comp_id_to_ids = self._bc_comp.data.comp_id_to_ids
        cov_comp_to_xms = self._bc_comp.comp_to_xms[self._bc_cov_uuid]
        if TargetType.arc in cov_comp_to_xms:
            self._arc_comp_to_xms_ids = cov_comp_to_xms[TargetType.arc]
            arc_comp_ids = list(self._arc_comp_to_xms_ids.keys())
            self._arc_comp_rows = comp_id_to_ids.loc[comp_id_to_ids['COMP_ID'].isin(arc_comp_ids)]
            self._arc_id_to_arc.update((arc.id, arc) for arc in self._bc_cov.arcs)

        if TargetType.point in cov_comp_to_xms:
            self._pt_comp_to_xms_ids = cov_comp_to_xms[TargetType.point]
            pt_comp_ids = list(self._pt_comp_to_xms_ids.keys())
            self._pt_comp_rows = comp_id_to_ids.loc[comp_id_to_ids['COMP_ID'].isin(pt_comp_ids)]
            points = self._bc_cov.get_points(FilterLocation.PT_LOC_DISJOINT)
            self._pt_id_to_pt.update((pt.id, pt) for pt in points)

    def _setup_snappers(self):
        """Sets up the snappers to be used."""
        # External arc snapper: uses the active mesh when one is set, otherwise the full mesh.
        self._ex_arc_snapper = SnapExteriorArc()
        ex_grid = self._active_mesh if self._active_mesh else self._co_grid
        self._ex_arc_snapper.set_grid(ex_grid, False)
        # Internal arc snapper
        self._in_arc_snapper = SnapInteriorArc()
        self._in_arc_snapper.set_grid(self._co_grid, False)
        # Point arc snapper
        self._pt_arc_snapper = SnapNodeString()
        self._pt_arc_snapper.set_grid(self._co_grid, False)
        # Point snapper
        self._pt_snapper = SnapPoint()
        self._pt_snapper.set_grid(self._co_grid, False)