"""Apply a Boundary Conditions coverage to an ADCIRC simulation."""

__copyright__ = "(C) Copyright Aquaveo 2019"
__license__ = "All rights reserved"


# 1. Standard Python modules
from functools import cached_property
import math
import os
from typing import cast
import uuid

# 2. Third party modules
import numpy as np
import pandas as pd
from rtree import index
import xarray as xr

# 3. Aquaveo modules
from xms.components.display.display_options_io import (
    read_display_options_from_json, write_display_option_line_locations, write_display_options_to_json
)
from xms.components.display.xms_display_message import DrawType, XmsDisplayMessage
from xms.constraint import read_grid_from_file, UGrid2d
from xms.core.filesystem import filesystem as io_util
from xms.coverage.arcs.arc_util import arcs_have_compatible_directions
from xms.data_objects.parameters import Component, FilterLocation, UGrid
from xms.grid.ugrid import UGrid as XmUGrid
from xms.guipy.data.category_display_option_list import CategoryDisplayOptionList
from xms.guipy.data.line_style import LineOptions
from xms.guipy.data.target_type import TargetType
from xms.HydraulicToolboxCalc.hydraulics.manning_n.manning_n_data import ManningNData
from xms.HydraulicToolboxCalc.model.hydraulic_toolbox import HydraulicToolbox
from xms.snap.snap_exterior_arc import SnapExteriorArc
from xms.snap.snap_point import SnapPoint

# 4. Local modules
from xms.adcirc.components.bc_component_display import BC_POINT_ID_FILE, BcComponentDisplay
from xms.adcirc.data import bc_data, mapped_flow_data as mfd
from xms.adcirc.data.adcirc_data import UNINITIALIZED_COMP_ID
from xms.adcirc.feedback.xmlog import XmLog
from xms.adcirc.file_io.grid_crc import compute_grid_crc
from xms.adcirc.mapping.mapping_util import (coordinate_hash, get_parametric_lengths, linear_interp_with_idw_extrap,
                                             map_levee_atts, populate_levee_atts_from_arcs)


class BcMapper:
    """Class for mapping BC coverage arcs and their attributes to a mesh linked to a simulation."""
    def __init__(
        self, new_main_file, source_comp, source_mesh, source_cov, flow_geoms, flow_amps, flow_phases, is_main_coverage,
            sim_duration, xd
    ):
        """Construct the mapper.

        Reading the linked grid happens here (through the ``_co_grid`` cached property), so a mesh containing
        non-triangle elements makes construction fail with :obj:`RuntimeError` before anything else is set up.

        Args:
            new_main_file (:obj:`str`): File location where the new mapped component's main file should be created.
                Assumed to be in a unique UUID folder in the temp directory.
            source_comp (:obj:`BcComponent`): The source component to map (the attributes)
            source_mesh (:obj:`xms.data_objects.parameters.UGrid`): The domain mesh. Should be an xmsconstraint impl
            source_cov (:obj:`xms.data_objects.parameters.Coverage`): The source BC coverage geometry to map
            flow_geoms (:obj:`dict`): {:obj:`str`: :obj:`InterpLinear`} Periodic flow forcing geometry UUID to geometry
                interpolator
            flow_amps (:obj:`list`): [:obj:`xms.datasets.dataset_reader.DatasetReader`] List of periodic flow forcing
                amplitude datasets. Parallel with flow_phases.
            flow_phases (:obj:`list`): [:obj:`xms.datasets.dataset_reader.DatasetReader`] List of periodic flow forcing
                phase datasets. Parallel with flow_amps.
            is_main_coverage: False if this is a levee only coverage
            sim_duration: (float) Duration of simulation in days.
            xd (:obj:`XmsData`): The XMS inter-process communicator

        Raises:
            RuntimeError: If the linked mesh has elements that are not triangles.
        """
        # NOTE(review): presumably column indices for levee attribute rows; not referenced in the visible code.
        self.LEVEE_COL_LENGTH = 0
        self.LEVEE_COL_CREST = 1
        self.LEVEE_COL_SUB_COEF = 2
        self.LEVEE_COL_SUP_COEF = 3
        self.mapped_comp = None  # Created by map_data()
        self._mapped_file = new_main_file
        self._bc_comp = source_comp
        self._bc_arcs = source_cov.arcs
        # Don't even bother setting up anything else if the mesh has non-tri elements: touching the _co_grid cached
        # property reads the grid and raises RuntimeError on non-triangle cells.
        XmLog().instance.info('Reading the ADCIRC mesh...')
        self._grid_file = source_mesh.cogrid_file
        self._ugrid = self._co_grid.ugrid
        # One entry per mesh node, initialized to -1 (presumably "no BC type assigned yet").
        self._node_bc_type = np.full(self._ugrid.point_count, -1)

        self._arcs = {arc.id: arc for arc in self._bc_arcs}
        self._cov_uuid = source_cov.uuid
        self._cov_name = source_cov.name
        # Disjoint (free) coverage points, kept as the pipe locations.
        self._pipe_pts = source_cov.get_points(FilterLocation.PT_LOC_DISJOINT)

        self._flow_times = None
        self._sim_dur = sim_duration
        self._xd = xd
        # The applied display is drawn in the current display projection, so that WKT is the one we keep. (The mesh's
        # own projection WKT used to be stored here first and then immediately overwritten; that dead store is gone.)
        self._wkt = self._xd.display_wkt
        self._coord_sys = self._xd.coordinate_system.upper()

        # Diagnostics collected while snapping arcs; reported by _log_warnings().
        self._self_closing_mainland_arcs = set()
        self._mainland_island_overlapping_arcs = []
        self._no_snap_arcs = set()

        self._flow_geoms = flow_geoms
        self._flow_amps = flow_amps
        self._flow_phases = flow_phases
        self._is_main_coverage = is_main_coverage
        self.idw_interps = {}
        # NOTE(review): the arc snapper is built lazily by the _arc_snapper cached property, not here; this message
        # predates that change.
        XmLog().instance.info('Finished constructing the arc snapping object.')

    @cached_property
    def _arc_snapper(self) -> SnapExteriorArc:
        """Build, on first access, the exterior-arc snapper bound to the linked constrained grid.

        The snapper is cached for the lifetime of the mapper, so construction (and its log message) happens once.
        """
        XmLog().instance.info('Constructing the arc snapping object.')
        exterior_snapper = SnapExteriorArc()
        grid_for_snapping = cast(UGrid, self._co_grid)  # typing-only narrowing; the object is passed through unchanged
        # target_cells=False: presumably snap to grid points rather than cells.
        exterior_snapper.set_grid(grid=grid_for_snapping, target_cells=False)
        return exterior_snapper

    @cached_property
    def _pt_snapper(self) -> SnapPoint:
        """Build, on first access, the point snapper used to place pipes on the linked constrained grid.

        The snapper is cached for the lifetime of the mapper.
        """
        XmLog().instance.info('Constructing the point snapping object for pipes.')
        point_snapper = SnapPoint()
        grid_for_snapping = cast(UGrid, self._co_grid)  # typing-only narrowing; the object is passed through unchanged
        # target_cells=False: presumably snap to grid points rather than cells.
        point_snapper.set_grid(grid=grid_for_snapping, target_cells=False)
        return point_snapper

    @cached_property
    def _co_grid(self) -> UGrid2d:
        """Load the constrained grid from ``self._grid_file`` once and cache it.

        Raises:
            RuntimeError: If any cell of the grid is not a triangle.
        """
        loaded_grid = read_grid_from_file(self._grid_file)
        triangle_type = XmUGrid.cell_type_enum.TRIANGLE
        if loaded_grid.check_all_cells_are_of_type(triangle_type):
            return loaded_grid
        raise RuntimeError(
            'Linked mesh has elements that are not triangles. Ensure all elements are triangles '
            'before attempting to apply Boundary Conditions to the simulation.'
        )

    def map_data(self):
        """Build the applied (mapped) BC component from the source coverage component and the snapped geometry.

        Steps:
        - copy the source main file and arc display options next to ``self._mapped_file``;
        - create the Python ``MappedBcComponent`` and stamp it with the grid CRC and the display WKT;
        - refresh the display UUIDs, snap and build the mapped data, and commit it;
        - create the data_objects component and the river flow component;
        - log mapping warnings.

        Returns:
            (:obj:`tuple`): ``(do_comp, flow_comp)``, where ``do_comp`` is the mapped BC
            :obj:`xms.data_objects.parameters.Component` and ``flow_comp`` is whatever
            ``_create_map_flow_component()`` returns.
        """
        XmLog().instance.info('Copying source Boundary Condition coverage attributes to applied data folder.')
        # The mapped component owns the folder containing its main file; this assumes we are in the component temp
        # directory.
        mapped_dir = os.path.dirname(self._mapped_file)
        io_util.copyfile(self._bc_comp.main_file, os.path.join(mapped_dir, bc_data.BC_MAIN_FILE))
        source_arc_disp = self._bc_comp.disp_opts_files[0]
        mapped_disp = os.path.join(mapped_dir, os.path.basename(source_arc_disp))
        io_util.copyfile(source_arc_disp, mapped_disp)

        # Imported here rather than at module level (kept as in the original; presumably avoids an import cycle).
        from xms.adcirc.components.mapped_bc_component import MappedBcComponent
        self.mapped_comp = MappedBcComponent(self._mapped_file)

        # The grid CRC is saved for future reference: if the user edits the mesh, this mapping is invalid.
        self.mapped_comp.data.info.attrs['grid_crc'] = compute_grid_crc(self._grid_file)
        # Free-location draw uses the current display projection (should match the mesh's native projection).
        self.mapped_comp.data.info.attrs['wkt'] = self._wkt

        self._update_display_uuids(mapped_disp, self._bc_comp.disp_opts_files[1])

        XmLog().instance.info('Creating applied boundary condition object.')
        self._create_map_component()
        XmLog().instance.info('Writing applied boundary condition data files.')
        self.mapped_comp.data.commit()

        suffix = '(applied)' if self._is_main_coverage else '(applied levees)'
        do_comp = Component(
            name=f'{self._cov_name} {suffix}',
            comp_uuid=os.path.basename(mapped_dir),
            main_file=self._mapped_file,
            model_name='ADCIRC',
            unique_name='Mapped_Bc_Component',
            locked=False,
        )

        # River constituents are mapped only if needed (decided by _create_map_flow_component).
        flow_comp = self._create_map_flow_component()
        self._log_warnings()
        return do_comp, flow_comp

    @staticmethod
    def _format_wrapped_list(items):
        """Join items with ', ', inserting a line break after the separator before every seventh item.

        Args:
            items (:obj:`iterable`): Values to list; each is rendered with an f-string.

        Returns:
            (:obj:`str`): The joined text, with no trailing separator.
        """
        parts = []
        for i, item in enumerate(items):
            if i != 0:
                parts.append(', ')
                if i % 7 == 0:
                    parts.append('\n')  # keep long arc lists readable in the log
            parts.append(f'{item}')
        return ''.join(parts)

    def _log_warnings(self):
        """Log the arcs that overlapped, formed mainland loops, or failed to snap while mapping.

        Nothing is logged when all three diagnostic collections are empty. Overlapping pairs are reported in the
        order they were found; self-closing and non-snapping arc ids are reported sorted.
        """
        overlapping = self._mainland_island_overlapping_arcs
        self_closing = self._self_closing_mainland_arcs
        no_snap = self._no_snap_arcs
        if not (overlapping or self_closing or no_snap):
            return

        # Use XmLog().instance everywhere (the rest of the file does); the bare XmLog.instance form used here before
        # only worked because an XmLog() had already been constructed.
        logger = XmLog().instance
        logger.warning('Mapping warnings:')
        if overlapping:
            msg = ('The following mainland and island arcs overlap each other. This is \n'
                   'typical when an island does not line up with a hole in the mesh.\n')
            logger.warning(msg + self._format_wrapped_list(overlapping))
        if self_closing:
            msg = ('The following mainland arcs form a loop. These would more commonly \n'
                   'be assigned as island arcs.\n')
            logger.warning(msg + self._format_wrapped_list(sorted(self_closing)))
        if no_snap:
            msg = ('The following arcs did not snap to the mesh. This may be due to a \n'
                   'degenerate arc or an arc that crosses over the mesh.\n')
            logger.warning(msg + self._format_wrapped_list(sorted(no_snap)))

    def _update_display_uuids(self, mapped_disp, mapped_point_disp):
        """Rewrite the copied arc display options for the applied component under a freshly generated UUID.

        The first category read from ``mapped_point_disp`` (the pipe points) is turned into a line category and
        appended to the arc categories. Its color is kept, and its width is half the point size, truncated, but at
        least 1. The combined list is assigned to the mapped component's UUID and switched to a free-location draw in
        the current display projection. It is then written back over ``mapped_disp``, and the new display UUID is
        saved on the mapped component.

        Args:
            mapped_disp (:obj:`str`): Filepath to the mapped BC arc display options; read, then overwritten.
            mapped_point_disp (:obj:`str`): Filepath to the BC point display options; only read.
        """
        XmLog().instance.info('Copying source Boundary Condition coverage display options to applied data folder.')
        arc_categories = CategoryDisplayOptionList()
        arc_categories.from_dict(read_display_options_from_json(mapped_disp))

        # Pipes are points in the coverage but are drawn as lines in the mapped component display.
        point_categories = CategoryDisplayOptionList()
        point_categories.from_dict(read_display_options_from_json(mapped_point_disp))
        pipe_category = point_categories.categories[0]
        point_style = pipe_category.options
        pipe_line_style = LineOptions()
        pipe_line_style.color = point_style.color
        pipe_line_style.width = max(1, int(point_style.size / 2))
        pipe_category.options = pipe_line_style
        arc_categories.categories.append(pipe_category)

        display_list_uuid = str(uuid.uuid4())
        arc_categories.comp_uuid = os.path.basename(os.path.dirname(self._mapped_file))
        arc_categories.uuid = display_list_uuid
        arc_categories.is_ids = False  # free-location draw instead of feature ids
        arc_categories.projection = {'wkt': self._wkt}
        write_display_options_to_json(mapped_disp, arc_categories)

        mapped = self.mapped_comp
        mapped.data.info.attrs['display_uuid'] = arc_categories.uuid
        mapped.display_option_list = [
            XmsDisplayMessage(file=mapped.disp_opts_files[0], draw_type=DrawType.draw_at_locations),
        ]

    def _create_map_component(self):
        """Create the mapped BC component Python object."""
        XmLog().instance.info('Snapping boundary condition arcs to domain mesh boundaries.')
        snap_arc_locations = [[], [], [], [], [], [], [], [], [], []]
        snap_pt_locations = []
        # Loop through all component ids defined on the source coverage.
        comp_map = self._bc_comp.comp_to_xms.get(self._bc_comp.cov_uuid, {})
        arc_comp_ids = comp_map.get(TargetType.arc, {})
        if not arc_comp_ids:
            XmLog().instance.warning(
                'The Boundary Conditions coverage has no ADCIRC boundary arcs defined. Reapply the coverage once '
                'Boundary Condition specification is complete.'
            )
        comp_id_data = []
        partner_id_data = []
        nodes_start_idx_data = []
        node_count_data = []
        nodes = []

        levee_comp_id_data = []
        levee_node1_id_data = []
        levee_node2_id_data = []
        levee_height_data = []
        levee_sub_coef_data = []
        levee_super_coef_data = []
        pipe_on_data = []
        pipe_height_data = []
        pipe_diameter_data = []
        pipe_coef_data = []
        node_id_to_pipe_ids = self._snap_pipe_pts()  # {node_id: (comp_id, att_id)}

        river_comp_ids = []
        river_node_ids = []
        river_flows = []

        next_nodestring_id = 0
        comp_id_to_atts = {}  # Mapping from component ID to an xarray dataset with values for the feature.
        comp_id_to_levees = {}
        total_bcs = len(arc_comp_ids)
        summary_issues = {
            'Direction mismatch': [],
            'Number of nodes mismatch': [],
            'Snapping failure': [],
            'Missing Z crest': [],
            'Duplicate ocean node': [],
            'Levee arc not paired': []
        }

        used_ocean_nodes = {}
        used_island_nodes = {}
        used_mainland_nodes = {}

        bc_id = 0
        for comp_id, att_ids in arc_comp_ids.items():
            use_elevs = False
            constant_q = True
            supercritical_coeff = 1.0
            subcritical_coeff = 1.0
            if comp_id not in comp_id_to_atts:
                comp_id_to_atts[comp_id] = self._bc_comp.data.arcs.sel(comp_id=comp_id)
            bc_atts = comp_id_to_atts[comp_id]
            bc_type = int(bc_atts['type'].data.item())

            is_levee_pair = bc_type == bc_data.LEVEE_INDEX
            is_levee = bc_type == bc_data.LEVEE_OUTFLOW_INDEX or is_levee_pair
            bc_id += 1

            if len(att_ids) != 2 and is_levee and is_levee_pair:
                XmLog().instance.warning(f'Levee bc {bc_id} does not contain two arcs. This bc will '
                                         'be skipped.')
                continue

            XmLog().instance.info(f'Processing bc {bc_id} of {total_bcs}. (arc ids in BC: {att_ids})')

            if is_levee:
                use_elevs = 'use_elevs' in bc_atts.data_vars and int(bc_atts['use_elevs'].data.item()) == 1
                if 'supercritical_coeff' in bc_atts.data_vars:
                    supercritical_coeff = float(bc_atts['supercritical_coeff'].data.item())
                first_nodestring = None
                first_locs = None
                first_ids = None
            if is_levee_pair:
                if 'subcritical_coeff' in bc_atts.data_vars:
                    subcritical_coeff = float(bc_atts['subcritical_coeff'].data.item())
                second_nodestring = None
                second_locs = None
                second_ids = None
            for att_id in att_ids:
                # Create nodestrings
                snap_arc = self._arc_snapper.get_snapped_points(self._arcs[att_id])
                if len(snap_arc['location']) < 2:
                    self._no_snap_arcs.add(att_id)
                    continue

                # Repeating ocean nodes results in duplicate tidal forcing.
                if bc_type == bc_data.OCEAN_INDEX:
                    if snap_arc['id'][-1] == snap_arc['id'][0]:
                        # If it's a loop, just break off the last node.
                        snap_arc['location'] = snap_arc['location'][:-1]  # tuples and ndarrays have no .pop()
                        snap_arc['id'] = snap_arc['id'][:-1]

                    first, second, node = _find_duplicate_snap(att_id, snap_arc, used_ocean_nodes)
                    if node != -1:
                        msg = f'Arcs {first} and {second} both snapped to node {node + 1}'
                        summary_issues['Duplicate ocean node'].append(msg)
                elif bc_type == bc_data.ISLAND_INDEX:
                    first, second, node = _find_duplicate_snap(att_id, snap_arc, used_island_nodes)
                    if node != -1:
                        msg = (
                            f'Island arcs {first} and {second} both snapped to node {node + 1}.'
                        )
                        XmLog().instance.warning(msg)

                    first, second, node = _find_duplicate_snap(att_id, snap_arc, used_mainland_nodes, False)
                    if node != -1:
                        self._mainland_island_overlapping_arcs.append((first, second))
                elif bc_type == bc_data.MAINLAND_INDEX:
                    if len(snap_arc['id']) > 1 and snap_arc['id'][0] == snap_arc['id'][-1]:
                        self._self_closing_mainland_arcs.add(att_id)
                    first, second, node = _find_duplicate_snap(att_id, snap_arc, used_island_nodes, False)

                    if node != -1:
                        self._mainland_island_overlapping_arcs.append((first, second))

                    # add the mainland nodes to the list
                    for node in snap_arc['id']:
                        used_mainland_nodes[node] = att_id
                elif bc_type == bc_data.RIVER_INDEX:
                    # get the next snapped points by extending the line so that we get one more node on each side
                    arc_locs = self._arcs[att_ids[0]].get_points(FilterLocation.LOC_ALL)
                    pts = [np.array((pt.x, pt.y, pt.z)) for pt in arc_locs]

                    # extrapolate to extend the arc
                    diffs = np.diff(pts, axis=0)
                    seg_lengths = np.linalg.norm(diffs, axis=1)
                    length = np.sum(seg_lengths) * 2.0

                    dir_start = pts[0] - pts[1]
                    dir_start /= np.linalg.norm(dir_start)
                    new_start = pts[0] + dir_start * length

                    dir_end = pts[-2] - pts[-1]
                    dir_end /= np.linalg.norm(dir_end)
                    new_end = pts[-1] - dir_end * length

                    extended_locs = np.vstack([new_start, pts, new_end]).tolist()
                    tmp_snapped_arc = self._arc_snapper.get_snapped_points(extended_locs)
                    extended_snapped_nodes = list(tmp_snapped_arc['id'])
                    extended_snapped_locs = list(tmp_snapped_arc['location'])
                    snapped_nodes = list(snap_arc['id'])
                    snapped_locs = list(snap_arc['location'])

                    # find the original end nodes that snapped
                    start_node = snapped_nodes[0]
                    end_node = snapped_nodes[-1]

                    start_i = extended_snapped_nodes.index(start_node)
                    if start_i > 0:
                        extended_snapped_nodes = extended_snapped_nodes[start_i - 1:]
                        extended_snapped_locs = extended_snapped_locs[start_i - 1:]

                    end_i = extended_snapped_nodes.index(end_node)
                    if end_i < len(extended_snapped_nodes) - 1:
                        extended_snapped_nodes = extended_snapped_nodes[:end_i + 2]
                        extended_snapped_locs = extended_snapped_locs[:end_i + 2]

                    if self._coord_sys == 'GEOGRAPHIC':
                        extended_snapped_locs = bc_data.convert_lat_lon_pts_to_utm(extended_snapped_locs, self._wkt,
                                                                                   None)
                        snapped_locs = extended_snapped_locs[1:-1]

                    # get the segment lengths
                    half_seg_lengths = []
                    mid_pts = []
                    for i in range(len(extended_snapped_locs) - 1):
                        pt1 = extended_snapped_locs[i]
                        pt2 = extended_snapped_locs[i + 1]
                        diff = np.subtract(pt1, pt2)
                        dist = np.linalg.norm(diff)

                        half_seg_length = dist / 2.0
                        half_seg_lengths.append(half_seg_length)

                        # we add this length twice unless it is on the end
                        if i != 0 and i < len(extended_snapped_locs) - 2:
                            half_seg_lengths.append(half_seg_length)

                        mid_pts.append(tuple((a + b) / 2 for a, b in zip(pt1, pt2)))

                    cross_section_def = [mid_pts[0]] + snapped_locs + [mid_pts[-1]]

                    # is it constant or do we need a time series
                    if 'constant_q' in bc_atts.data_vars:
                        constant_q = float(bc_atts['constant_q'].data.item())

                    if constant_q < 0.0:
                        input_q = self._bc_comp.data.q.sel(comp_id=comp_id, drop=True)
                    else:
                        data_dict = {'Time': [0.0], 'Flow': [constant_q], 'comp_id': [comp_id]}
                        input_q = pd.DataFrame.from_dict(data_dict).to_xarray()

                    flows = list(self._get_interpolated_flows(input_q))
                    wses = self._get_wses(flows, cross_section_def)

                    if 'wse_offset' in bc_atts.data_vars:
                        wse_offset = float(bc_atts['wse_offset'].data.item())
                        if wse_offset != 0.0:
                            wses = [wse + wse_offset for wse in wses]

                    flow_dict = {}
                    for ts, wse_ts in enumerate(wses):
                        q_segs = 0.0
                        node_qs = []
                        for i, node_loc in enumerate(snapped_locs):
                            if ts == 0:
                                # create the entry in the dictionary
                                flow_dict[snapped_nodes[i]] = []
                            half_seg_i = i * 2
                            q1, q2 = self._compute_flow_at_node(node_loc[2], mid_pts[i][2], mid_pts[i + 1][2], wse_ts,
                                                                half_seg_lengths[half_seg_i],
                                                                half_seg_lengths[half_seg_i + 1])
                            node_q = q1 + q2
                            node_qs.append(node_q)
                            q_segs += node_q

                        # make the scaling factor since they won't match up exactly
                        scale_factor = flows[ts] / q_segs
                        for i, node_q in enumerate(node_qs):
                            flow_dict[snapped_nodes[i]].append(node_q * scale_factor)

                    # store this with all other flow data
                    for node_id in snapped_nodes:
                        river_comp_ids.append(comp_id)
                        river_node_ids.append(node_id)
                        river_flows.append(flow_dict[node_id])

                # For levees, store off first and second nodestring ids in case this is a levee pair
                if is_levee:
                    if first_nodestring is None:
                        first_nodestring = next_nodestring_id
                        first_locs = snap_arc['location']
                        if use_elevs:
                            arc_locs = self._arcs[att_ids[0]].get_points(FilterLocation.LOC_ALL)
                            idx = index.Index()
                            for i, loc in enumerate(arc_locs):
                                idx.insert(i, [loc.x, loc.y])
                            for i in range(len(first_locs)):
                                # find the closest arc location and use the z
                                nearest = list(idx.nearest([first_locs[i][0], first_locs[i][1]]))
                                first_locs[i][2] = arc_locs[nearest[0]].z
                        first_arc_id = att_id
                if is_levee_pair:
                    snap_arc2 = self._arc_snapper.get_snapped_points(self._arcs[att_ids[1]])
                    second_nodestring = next_nodestring_id
                    second_locs = snap_arc2['location']
                    if use_elevs:
                        arc_locs = self._arcs[att_ids[1]].get_points(FilterLocation.LOC_ALL)
                        idx = index.Index()
                        for i, loc in enumerate(arc_locs):
                            idx.insert(i, [loc.x, loc.y])
                        for i in range(len(second_locs)):
                            # find the closest arc location and use the z
                            nearest = list(idx.nearest([second_locs[i][0], second_locs[i][1]]))
                            second_locs[i][2] = arc_locs[nearest[0]].z
                    # make sure the number of nodes is consistent across the levee
                    if len(first_locs) != len(second_locs):
                        msg = (
                            f'Mismatched number of nodes in levee defined by feature arcs {att_ids}.\n'
                            f' Node string for arc {first_arc_id} has {len(first_locs)} nodes.\n'
                            f' Node string for arc {att_id} has {len(second_locs)} nodes.'
                        )
                        summary_issues['Number of nodes mismatch'].append(msg)
                next_nodestring_id += 1

                # Add the snap locations to the display list
                snap_arc_locations[bc_type].append([coord_val for node in snap_arc['location'] for coord_val in node])

                # Build up the dataset
                comp_id_data.append(comp_id)
                partner_id_data.append(UNINITIALIZED_COMP_ID)
                nodes_start_idx_data.append(len(nodes))
                node_count_data.append(len(snap_arc['location']))
                snap_indexes = list(snap_arc['id'])
                node_ids = [node_id + 1 for node_id in snap_indexes]
                if is_levee:
                    if first_ids is None:
                        first_ids = node_ids
                if is_levee_pair:
                    if first_ids is not None:
                        second_ids = node_ids
                nodes.extend(node_ids)

                total_length = 0.0
                sorted_levee = []
                if bc_type == bc_data.LEVEE_OUTFLOW_INDEX:
                    defined = True
                    if comp_id not in comp_id_to_levees:
                        try:
                            comp_id_to_levees[comp_id] = self._bc_comp.data.levees.sel(comp_id=comp_id)
                            levees = comp_id_to_levees[comp_id]
                            sorted_levee = levees.sortby('Parametric __new_line__ Length')
                            total_length = sorted_levee['Parametric __new_line__ Length'][-1].data.item()
                        except KeyError:  # Dumb user is trying to map an undefined outflow levee
                            defined = False
                            XmLog().instance.error(f'Outflow levee on feature arc {att_id} is undefined.')
                    else:
                        comp_id_to_levees[comp_id] = self._bc_comp.data.levees.sel(comp_id=comp_id)
                        levees = comp_id_to_levees[comp_id]
                        sorted_levee = levees.sortby('Parametric __new_line__ Length')
                        total_length = sorted_levee['Parametric __new_line__ Length'][-1].data.item()

                    para_len = get_parametric_lengths(snap_arc['location'])
                    for node_id, length in zip(first_ids, para_len):
                        if defined:
                            length *= total_length
                            height, _ = map_levee_atts(sorted_levee, length)
                            levee_comp_id_data.append(comp_id)
                            levee_node1_id_data.append(node_id)
                            levee_node2_id_data.append(UNINITIALIZED_COMP_ID)
                            levee_height_data.append(height)
                            levee_sub_coef_data.append(subcritical_coeff)
                            levee_super_coef_data.append(supercritical_coeff)
                        else:  # Dumb user is trying to map an undefined outflow levee
                            levee_comp_id_data.append(comp_id)
                            levee_node1_id_data.append(UNINITIALIZED_COMP_ID)
                            levee_node2_id_data.append(UNINITIALIZED_COMP_ID)
                            levee_height_data.append(0.0)
                            levee_sub_coef_data.append(1.0)
                            levee_super_coef_data.append(1.0)
                        pipe_on_data.append(0)
                        pipe_height_data.append(0.0)
                        pipe_diameter_data.append(0.0)
                        pipe_coef_data.append(0.0)

            # ensure that snapping worked (2 arcs for levee pairs, 1 arc for others)
            # if snapping had issues, prevent the completion of processing the data as a levee pair
            if is_levee_pair and first_nodestring == second_nodestring:
                msg = f'Arc {att_ids[0]} is not part of a pair. Failed to map to the mesh.'
                is_levee_pair = False
                summary_issues['Levee arc not paired'].append(msg)
            if is_levee_pair and len(second_locs) < 1:
                is_levee_pair = False
                msg = f'Arc {att_ids[1]} failed to map to the mesh.'
                summary_issues['Snapping failure'].append(msg)
            if is_levee and len(first_locs) < 1:
                is_levee_pair = False
                msg = f'Arc {att_ids[0]} failed to map to the mesh.'
                summary_issues['Snapping failure'].append(msg)

            # finish processing the levee pair
            if is_levee_pair:
                if comp_id not in comp_id_to_levees:
                    comp_id_to_levees[comp_id] = self._bc_comp.data.levees.where(
                        self._bc_comp.data.levees.comp_id == comp_id, drop=True
                    )
                levees = comp_id_to_levees[comp_id]
                # Check if the user is being dumb and did not define the levee.
                if use_elevs:
                    levee_locs = [first_locs, second_locs]
                    lengths_and_zs = np.array(populate_levee_atts_from_arcs(levee_locs, None))
                    para_lengths = lengths_and_zs[:, 0]
                    zcrests = lengths_and_zs[:, 1]
                    num_rows = len(para_lengths)
                    data_dict = {
                        'Parametric __new_line__ Length': para_lengths,
                        'Parametric __new_line__ Length 2': para_lengths,
                        'Zcrest (m)': zcrests,
                        'Subcritical __new_line__ Flow Coef': [subcritical_coeff] * num_rows,
                        'Supercritical __new_line__ Flow Coef': [supercritical_coeff] * num_rows,
                        'comp_id': [comp_id] * num_rows,
                    }
                    levees = pd.DataFrame.from_dict(data_dict).to_xarray()
                elif levees.sizes['comp_id'] == 0:
                    msg = f'Zcrests not defined for levee defined by arcs {att_ids}.'
                    summary_issues['Missing Z crest'].append(msg)
                    continue

                # Check if the levee arcs go in opposite directions - swap the second and give a message
                arc1 = None
                arc2 = None
                # this is very slow. We should find a better way than looping through all the arcs
                feature_ids = self._bc_comp.comp_to_xms.get(self._cov_uuid, {}).get(TargetType.arc, {}).get(comp_id, [])
                for arc in self._bc_arcs:
                    if arc.id in feature_ids:
                        if arc1 is None:
                            arc1 = arc
                        else:
                            arc2 = arc
                            break
                if not arcs_have_compatible_directions(arc1, arc2):
                    # log message that the arcs are incompatible
                    msg = f'Arc directions incompatible. Swapped second arc ({att_ids[1]}) for mapping.'
                    summary_issues['Direction mismatch'].append(msg)
                    second_ids = np.flip(second_ids)
                    second_locs = np.flip(second_locs, 0)
                    second_locs = np.ascontiguousarray(second_locs)
                    # swap the last n nodes
                    nodes = nodes[:-len(second_ids)]
                    nodes.extend(second_ids)

                sorted_levee1 = levees.sortby('Parametric __new_line__ Length')
                total_length1 = sorted_levee1['Parametric __new_line__ Length'][-1].data.item()
                total_length2 = total_length1
                sorted_levee2 = sorted_levee1
                # flag for when there are 2 parametric curves (for each side of the levee) - only happens when unmapped
                two_curves = use_elevs
                # If the attributes and the geometry of this levee have not changed, use the second side.
                if not two_curves:
                    n = self._bc_comp.data.levee_flags.sizes['comp_id']
                    if n != 0 and self._bc_comp.data.levee_flags.use_second_side.loc[comp_id].item() == 1:
                        locs_hash = coordinate_hash(first_locs, second_locs)
                        if locs_hash == self._bc_comp.data.levee_flags.locs.loc[comp_id].item():
                            two_curves = True
                if two_curves:
                    sorted_levee2 = levees.sortby('Parametric __new_line__ Length 2')
                    total_length2 = sorted_levee2['Parametric __new_line__ Length 2'][-1].data.item()
                    sorted_levee2['Parametric __new_line__ Length'] =\
                        sorted_levee2['Parametric __new_line__ Length 2']
                para_len1 = get_parametric_lengths(first_locs)
                para_len2 = get_parametric_lengths(second_locs)
                for pt1, pt2, len1, len2 in zip(first_ids, second_ids, para_len1, para_len2):
                    len1 *= total_length1
                    len2 *= total_length2
                    height1, idx1 = map_levee_atts(sorted_levee1, len1)
                    height2, idx2 = map_levee_atts(sorted_levee2, len2)
                    levee_comp_id_data.append(comp_id)
                    levee_node1_id_data.append(pt1)
                    levee_node2_id_data.append(pt2)
                    levee_height_data.append(max(height1, height2))
                    levee_sub_coef_data.append(subcritical_coeff)
                    levee_super_coef_data.append(supercritical_coeff)

                    # Check if a pipe point snapped to either node of the levee node pair.
                    pipe_node = None
                    pt1_idx = pt1 - 1  # Convert from 0-based UGrid node index to 1-based mesh node id
                    pt2_idx = pt2 - 1
                    if pt1_idx in node_id_to_pipe_ids:
                        pipe_node = pt1_idx
                    elif pt2_idx in node_id_to_pipe_ids:
                        pipe_node = pt2_idx

                    if pipe_node is not None:  # There is a pipe on this levee
                        pipe_on_data.append(1)
                        # Select the source pipe attributes.
                        pipe_comp_id = node_id_to_pipe_ids[pipe_node][0][0]
                        # Get the locations of the two levee nodes in this pair so we can draw pipes on them.
                        pipe_line = list(self._ugrid.get_point_location(pt1_idx))
                        pipe_line.extend(self._ugrid.get_point_location(pt2_idx))
                        snap_pt_locations.append(pipe_line)
                        try:
                            pipe_data = self._bc_comp.data.pipes.sel(comp_id=pipe_comp_id)
                            pipe_height_data.append(pipe_data['Height'].data.item())
                            pipe_diameter_data.append(pipe_data['Diameter'].data.item())
                            pipe_coef_data.append(pipe_data['Coefficient'].data.item())
                        except KeyError:  # This happens when a pipe point is created but attributes are never defined.
                            XmLog.instance.warning(
                                f'Undefined pipe found with feature ID {node_id_to_pipe_ids[pipe_node][0][1]}'
                            )
                            pipe_height_data.append(0.0)
                            pipe_diameter_data.append(0.0)
                            pipe_coef_data.append(0.0)
                        node_id_to_pipe_ids.pop(pipe_node)
                    else:  # No pipe on this levee
                        pipe_on_data.append(0)
                        pipe_height_data.append(0.0)
                        pipe_diameter_data.append(0.0)
                        pipe_coef_data.append(0.0)
                partner_id_data[-1] = first_nodestring
                partner_id_data[-2] = second_nodestring

        # present summary of the levee issues found during mapping
        reports = {category: messages for category, messages in summary_issues.items() if len(messages) > 0}
        if reports:
            XmLog().instance.error('Summary of mapping errors.')

        for category, messages in reports.items():
            XmLog().instance.info(category)
            for message in messages:
                XmLog().instance.error(message)

        # Ensure that all the pipe points mapped to a levee pair.
        if node_id_to_pipe_ids:
            id_strs = []
            for _, pipe_ids in node_id_to_pipe_ids.items():
                id_str = ', '.join([str(pipe[1]) for pipe in pipe_ids])
                id_strs.append(id_str)
            id_str = ', '.join(id_strs)
            XmLog(
            ).instance.error(f'The pipe points with the following ids did not snap to a levee pair node: {id_str}.')

        data_dict = {
            'comp_id': xr.DataArray(data=np.array(comp_id_data, dtype=np.int32)),
            'partner_id': xr.DataArray(data=np.array(partner_id_data, dtype=np.int32)),
            'nodes_start_idx': xr.DataArray(data=np.array(nodes_start_idx_data, dtype=np.int32)),
            'node_count': xr.DataArray(data=np.array(node_count_data, dtype=np.int32)),
        }
        self.mapped_comp.data.nodestrings = xr.Dataset(data_vars=data_dict)
        node_dict = {'id': xr.DataArray(data=np.array(nodes, dtype=np.int32))}
        self.mapped_comp.data.nodes = xr.Dataset(data_vars=node_dict)
        levee_dict = {
            'Node1 Id': ('comp_id', np.array(levee_node1_id_data, dtype=np.int32)),
            'Node2 Id': ('comp_id', np.array(levee_node2_id_data, dtype=np.int32)),
            'Zcrest (m)': ('comp_id', np.array(levee_height_data, dtype=np.float64)),
            'Subcritical __new_line__ Flow Coef': ('comp_id', np.array(levee_sub_coef_data, dtype=np.float64)),
            'Supercritical __new_line__ Flow Coef': ('comp_id', np.array(levee_super_coef_data, dtype=np.float64)),
            'Pipe': ('comp_id', np.array(pipe_on_data, dtype=np.int32)),
            'Zpipe (m)': ('comp_id', np.array(pipe_height_data, dtype=np.float64)),
            'Pipe __new_line__ Diameter (m)': ('comp_id', np.array(pipe_diameter_data, dtype=np.float64)),
            'Bulk __new_line__ Coefficient': ('comp_id', np.array(pipe_coef_data, dtype=np.float64)),
        }
        coords = {'comp_id': levee_comp_id_data}
        self.mapped_comp.data.levees = xr.Dataset(data_vars=levee_dict, coords=coords)

        if len(river_comp_ids) > 0:
            flow_array = np.array(river_flows)
            self.mapped_comp.data.river_flows = xr.Dataset(
                {
                    'Flow': (('Node Id', 'TS'), flow_array)
                },
                coords={
                    'Node Id': river_node_ids,
                    'TS': self._flow_times,
                    'comp_id': ('Node Id', river_comp_ids)
                }
            )

        XmLog().instance.info('Creating display lists for applied boundary condition object.')
        for bc_type in range(bc_data.FLOW_AND_RADIATION_INDEX + 1):
            if snap_arc_locations[bc_type]:
                loc_file = BcComponentDisplay.get_display_id_file(bc_type, os.path.dirname(self.mapped_comp.main_file))
                write_display_option_line_locations(loc_file, snap_arc_locations[bc_type])

        if snap_pt_locations:
            loc_file = os.path.join(os.path.dirname(self.mapped_comp.main_file), BC_POINT_ID_FILE)
            write_display_option_line_locations(loc_file, snap_pt_locations)

    def _compute_flow_at_node(self, node_z, mid_1_z, mid_2_z, wse, half_seg_length_1, half_seg_length_2):
        """Create the mapped BC component Python object."""
        q_1 = self._compute_flow_for_half_seg(mid_1_z, node_z, wse, half_seg_length_1)
        q_2 = self._compute_flow_for_half_seg(node_z, mid_2_z, wse, half_seg_length_2)

        return q_1, q_2

    def _compute_flow_for_half_seg(self, z1, z2, wse, seg_len):
        """Create the mapped BC component Python object."""
        n = 0.03
        sf = 0.001

        a, pw, rh = self._compute_flow_variables(z1, z2, wse, seg_len)
        flow = (1 / n) * a * np.power(rh, (2.0 / 3.0)) * np.power(sf, 0.5) if a > 0.0 else 0.0
        return flow

    def _compute_flow_variables(self, z1, z2, wse, seg_len):
        """Create the mapped BC component Python object."""
        a = pw = rh = 0.0
        if wse > min(z1, z2):  # this is wet
            d1 = wse - z1
            d2 = wse - z2
            if d1 > 0.0 and d2 > 0.0:  # all wet
                dave = (d1 + d2) / 2.0
                a = seg_len * dave
                pw = np.sqrt(np.power(np.abs(d1 - d2), 2.0) + np.power(seg_len, 2.0))
            else:
                if d1 > 0.0:  # first half is wet
                    d = d1
                    wet_len = (d1 / (d1 - d2)) * seg_len
                else:  # second half is wet
                    d = d2
                    wet_len = (d2 / (d2 - d1)) * seg_len
                a = (d * wet_len) / 2.0
                pw = np.sqrt(np.power(d, 2.0) + np.power(wet_len, 2.0))

            rh = a / pw

        return a, pw, rh

    def _get_flow_times(self):
        """Create the mapped BC component Python object."""
        if self._flow_times is not None:
            return self._flow_times

        self._flow_times = []
        flow_ts = self._bc_comp.data.info.attrs['flow_ts']
        flow_ts_units = self._bc_comp.data.info.attrs['flow_ts_units']
        if flow_ts_units == 'seconds':
            duration = self._sim_dur * 86400
        elif flow_ts_units == 'minutes':
            duration = self._sim_dur * 1440
        elif flow_ts_units == 'hours':
            duration = self._sim_dur * 24
        else:
            duration = self._sim_dur

        cur_ts = 0
        while cur_ts < duration:
            self._flow_times.append(cur_ts)
            cur_ts += flow_ts

        return self._flow_times

    def _get_interpolated_flows(self, q):
        """Create the mapped BC component Python object."""
        self._get_flow_times()

        orig_times = list(q['Time'])
        orig_flows = list(q['Flow'])

        return np.interp(self._flow_times, orig_times, orig_flows)

    def _get_wses(self, flows, xs_locs, n=0.03, sf=0.001):
        """Compute cross-section water surface elevations for a list of flows.

        Stations are the cumulative plan-view (x, y) distances along the cross-section points, and elevations are
        their z values. The Hydraulic Toolbox Manning's n calculator is then run in SI units.

        Args:
            flows: The flows to compute water surface elevations for.
            xs_locs: The (x, y, z) locations of the cross-section points. They are first passed through
                bc_data.convert_lat_lon_pts_to_utm() when the coordinate system is 'GEOGRAPHIC'.
            n (:obj:`float`): Composite Manning's roughness coefficient. Defaults to 0.03.
            sf (:obj:`float`): Friction slope. Defaults to 0.001.

        Returns:
            The 'WSE' result of the Manning's n calculation (presumably one value per flow).
        """
        if self._coord_sys == 'GEOGRAPHIC':
            xs_locs = bc_data.convert_lat_lon_pts_to_utm(xs_locs, self._wkt, None)

        xy_locs = [(loc[0], loc[1]) for loc in xs_locs]
        elevations = [loc[2] for loc in xs_locs]
        # Station of each point is the cumulative plan-view distance from the first point.
        stations = [0.0]
        for (x0, y0), (x1, y1) in zip(xy_locs, xy_locs[1:]):
            stations.append(stations[-1] + math.hypot(x1 - x0, y1 - y0))

        hyd_toolbox = HydraulicToolbox()
        hyd_toolbox.app_data.set_setting('Selected unit system', 'SI Units (Metric)')
        project_uuid = hyd_toolbox.get_first_project_uuid()
        manning_n = ManningNData(app_data=hyd_toolbox.app_data, model_name=hyd_toolbox.model_name,
                                 project_uuid=project_uuid)
        # Select the 'Head' calculation for a cross-section shape.
        manning_n.input['Calculate'].set_index('Head')
        manning_n.set_shape('cross-section')

        manning_n.set_stations_and_elevation(stations, elevations)
        manning_n.input['Slope'].set_val(sf)
        manning_n.input['Composite n'].get_val().input['Composite n'].set_val(n)
        manning_n.input['Flows'].get_val().set_user_list(flows, app_data=hyd_toolbox.app_data)

        manning_n.compute_data()

        return manning_n.results['WSE'].get_val(hyd_toolbox.app_data)

    def _snap_pipe_pts(self):
        """Snap pipe points to the domain mesh.

        Each pipe point is snapped to its closest mesh node. Pipe points that cannot be snapped are logged as errors
        and skipped.

        Returns:
            (:obj:`dict`): Dictionary whose keys are node ids and whose values are lists of tuples containing the
            snapped pipe point's component id and XMS att id. If this list contains more than one tuple, an error
            is logged.
        """
        # Snap the pipe points to the mesh. Will map to levee pairs later.
        node_id_to_pipe_ids = {}  # {node_id: [(comp_id, att_id), ...]}
        for pipe_pt in self._pipe_pts:
            # Snap the pipe points to the closest mesh node (should be on a levee).
            snap_pt = self._pt_snapper.get_snapped_point(pipe_pt)
            # Node ids here are 0-based UGrid indices: the levee mapping looks them up as `mesh_node_id - 1`. So node
            # 0 is valid and must not be rejected by a truthiness test. None, empty, or negative results mean no node
            # was found.
            try:
                node_id = int(snap_pt['id'])
            except (TypeError, ValueError):
                node_id = -1
            if node_id < 0:
                XmLog().instance.error(f'Unable to find closest mesh node for pipe pt {pipe_pt.id}.')
                continue
            # Build up the list of pipes that snap to each node. More than one per node is an error reported below.
            comp_id = self._bc_comp.get_comp_id(TargetType.point, pipe_pt.id)
            node_id_to_pipe_ids.setdefault(node_id, []).append((comp_id, pipe_pt.id))

        # Check for multiple pipes on the same node.
        for node_id, pipe_ids in node_id_to_pipe_ids.items():
            if len(pipe_ids) > 1:
                id_str = ', '.join([str(pipe[1]) for pipe in pipe_ids])
                XmLog().instance.error(
                    f'Pipe points with ids {id_str} all snap to node {node_id}. This is invalid. Ensure that node '
                    f'{node_id} is the closest mesh node to only one pipe point.'
                )
        return node_id_to_pipe_ids

    def _create_map_flow_component(self):
        """Create the mapped river flow boundary component Python object.

        Creates a new mapped flow component folder next to the mapped BC component's folder. The flow constituent
        amplitude and phase datasets are interpolated to the river boundary nodes and committed to it.

        Returns:
            (:obj:`xms.data_objects.parameters.Component`): The mapped flow component if periodic flow boundaries have
            been defined. Otherwise, :obj:`None`. No flow values are written to the component in two cases: the
            amplitude/phase dataset selection is incomplete, or the flow constituent properties cannot be extracted.
        """
        # Get node ids of river boundary nodes
        river_nodes = self.mapped_comp.data.get_river_node_ids()
        if not river_nodes or self._bc_comp.data.info.attrs['periodic_flow'] == 0:
            return None  # No river boundaries or non-periodic flow

        # Make a new folder for the mapped flow component next to ours
        flow_comp_uuid = str(uuid.uuid4())
        flow_comp_dir = os.path.join(os.path.dirname(os.path.dirname(self._mapped_file)), flow_comp_uuid)
        os.makedirs(flow_comp_dir, exist_ok=True)
        flow_comp_mainfile = os.path.join(flow_comp_dir, mfd.MAPPED_FLOW_MAIN_FILE)
        flow_data = mfd.MappedFlowData(flow_comp_mainfile)

        # Create the data_objects component
        do_comp = Component(
            name='Flow Constituents (applied)',
            comp_uuid=flow_comp_uuid,
            main_file=flow_comp_mainfile,
            model_name='ADCIRC',
            unique_name='Mapped_Flow_Component',
            locked=False
        )

        if not self._flow_amps or len(self._flow_amps) != len(self._flow_phases):
            periodic_flow = self.mapped_comp.data.source_data.info.attrs['periodic_flow']
            if periodic_flow:
                XmLog().instance.warning(
                    'Incomplete periodic flow boundary specification. Unable to apply \n'
                    'periodic flow on river boundaries.'
                )
            return do_comp

        # The constituent table supplies the 'con' coordinate. If it cannot be extracted (the exception has already
        # been logged), skip the interpolation instead of failing on None further down.
        flow_cons = self._get_flow_constituent_properties()
        if flow_cons is None:
            return do_comp

        XmLog().instance.info('Interpolating amplitude and phase datasets to river boundary nodes.')
        river_nodes_0based = [node_id - 1 for node_id in river_nodes]
        river_locs = [(loc[0], loc[1], loc[2]) for loc in self._ugrid.get_points_locations(river_nodes_0based)]
        all_amps = []
        all_phases = []
        # NOTE(review): rows of all_amps/all_phases follow the order of self._flow_amps/self._flow_phases. The 'con'
        # coordinate comes from the constituent table sorted by name. Confirm these two orders agree.
        for flow_amp, flow_phase in zip(self._flow_amps, self._flow_phases):
            all_amps.append(
                linear_interp_with_idw_extrap(
                    self._flow_geoms[flow_amp.geom_uuid], river_locs, flow_amp, self.idw_interps
                )
            )

            all_phases.append(
                linear_interp_with_idw_extrap(
                    self._flow_geoms[flow_phase.geom_uuid], river_locs, flow_phase, self.idw_interps
                )
            )

        flow_data.cons = flow_cons
        cons = flow_data.cons['con'].data.tolist()
        flow_coords = {
            'con': cons,
            'node_id': river_nodes,
        }
        flow_values = {
            'amplitude': (('con', 'node_id'), all_amps),
            'phase': (('con', 'node_id'), all_phases),
        }
        # Any NaN amplitude/phase values are replaced with 0.0.
        flow_data.values = xr.Dataset(data_vars=flow_values, coords=flow_coords).fillna(0.0)
        flow_data.commit()

        return do_comp

    def _get_flow_constituent_properties(self):
        """Extract the flow constituent names, frequencies, nodal factors, and equilibrium arguments.

        The source BC component's constituent table is transformed as follows:
        - Re-indexed by constituent name.
        - Its 'Amplitude' and 'Phase' variables are dropped.
        - The remaining variables are renamed to 'con', 'frequency', 'nodal_factor', and 'equilibrium_argument'.

        Returns:
            (:obj:`xarray.Dataset`): The constituent properties indexed by 'con' and sorted by name. Returns None if
            they could not be extracted; the exception is logged.
        """
        flow_cons = self._bc_comp.data.flow_cons
        try:
            # Find the table's default index dimension inside the try, so a malformed table is logged and reported
            # as None like any other extraction failure.
            default_index = flow_cons.Name.dims[0]
            flow_cons = flow_cons.set_coords(['Name'])  # Set the constituent name as the index coordinate
            flow_cons = flow_cons.swap_dims({default_index: 'Name'})
            flow_cons = flow_cons.reset_coords([default_index], drop=True)  # Drop the default index
            flow_cons = flow_cons.drop_vars(['Amplitude', 'Phase'])
            flow_cons = flow_cons.rename({  # Rename the coords and data variables
                'Name': 'con',
                'Frequency': 'frequency',
                'Nodal __new_line__ Factor': 'nodal_factor',
                'Equilibrium __new_line__ Argument': 'equilibrium_argument',
            }).sortby('con')
        except Exception:
            XmLog().instance.exception('Unable to extract flow constituent properties.')
            return None
        return flow_cons


def _find_duplicate_snap(att_id, snap_arc: dict, used_nodes: dict, add_to_list=True) -> tuple[int, int, int]:
    """
    Find another arc that snaps to the same node(s) as a provided one.

    Args:
        att_id: Feature ID of the provided arc.
        snap_arc: The provided arc. The return value of SnapExteriorArc.get_snapped_points().
        used_nodes: Dictionary mapping node_id -> feature_id. Used to find other arcs that also snapped to the same
            node. Should be an empty dict the first time this is called. This will update its contents.

    Returns:
        Tuple of (first_feature_id, second_feature_id, shared_node_id). The first element is the smaller of att_id and
        the found other arc. second_feature_id is the greater. node_id is a node that they both snap to.
    """
    for node in snap_arc['id']:
        if node in used_nodes and used_nodes[node] != att_id:
            first, second = sorted([att_id, used_nodes[node]])
            return first, second, node
        elif add_to_list:
            used_nodes[node] = att_id
    return -1, -1, -1
