"""WashMapper class."""

__copyright__ = "(C) Copyright Aquaveo 2023"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy
import logging

# 2. Third party modules

# 3. Aquaveo modules
from xms.coverage.att_table_coverage_dump import get_att
from xms.coverage.grid.ugrid_mapper import UgridMapper
from xms.coverage.xy import xy_util
from xms.coverage.xy.xy_series import XySeries
from xms.data_objects.parameters import FilterLocation
from xms.grid.ugrid import UGrid
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.wash.components.bc_coverage_component import BcCoverageComponent
from xms.wash.tools.bc import Bc, BcTypeEnum, FaceBc, PointBc, RateTypeEnum, ReliefWell, Well
from xms.wash.tools.types import XM_NONE

# Constant strings that match GMS fconatts.h
# Most are attribute names looked up in the coverage attribute tables with get_att(). The ones marked as a
# "'Type' value", "'Flow bc' value" or "'Transport bc' value" are attribute values compared against those attributes;
# the card name in front of them (DB1, CB1, ...) is the WASH123D BC type created for that value.
FA_NAME = 'Name'  # well name
FA_FW_TYPE = 'Type'  # point type; compared against FA_TYPE_WELL / FA_TYPE_WELLSUPR
FA_FW_WELTOPSCR = 'Top scr.'  # well screen top (used as z_max when mapping wells)
FA_FW_WELBOTSCR = 'Bot. scr.'  # well screen bottom (used as z_min when mapping wells)
FA_FW_WELFLOW = 'Flow rate'  # constant well flow rate
FA_FW_WELFLOWTS = 'Flow rate TS'  # well flow rate xy series id
FA_FW_CONC = 'Conc.'  # constant concentration
FA_FW_CONCTS = 'Conc. TS'  # concentration xy series id
FA_FW_FLOWBC = 'Flow bc'  # flow BC type of an arc or polygon
FA_FW_FLUX = 'Flux rate'  # constant flux
FA_FW_FLUXTS = 'Flux rate TS'  # flux xy series id
FA_FW_TSHEAD = 'spec. head'  # DB1 ('Flow bc' value)
FA_FW_CHDHEAD = 'Head'  # constant head
FA_FW_CHDHEADTS = 'Head TS'  # head xy series id
FA_FW_MASS = 'Mass flux'  # constant mass flux
FA_FW_MASSTS = 'Mass flux TS'  # mass flux xy series id
FA_FW_TRANSBC = 'Transport bc'  # transport BC type of an arc or polygon
FA_FW_TSCONC = 'spec. conc.'  # DB2 ('Transport bc' value)
FA_FW_TSFLUX = 'spec. flux'  # CB1 ('Flow bc' value)
FA_FW_TVFLUX = 'variable flux'  # RS1 ('Flow bc' value)
FA_FW_TSMFLUX = 'spec. mass flux'  # CB2 ('Transport bc' value)
FA_FW_TVCONC = 'variable conc.'  # RS2 ('Transport bc' value)
FA_FW_ZONE = "Zones"  # passed to the UGrid mapper as an arc's material
FA_TYPE_WELL = 'well'  # 'Type' value
FA_TYPE_WELLSUPR = 'well (super node)'  # 'Type' value


class WashMapper:
    """Maps a coverage to a UGrid and creates a WASH123D BC file.

    Patterned after GMS's fmiMapToFEMWATER() and fwiWriteFemwaterBCData().
    """
    def __init__(
        self,
        att_table_coverage,
        bc_coverage,
        co_grid_3d,
        cell_materials,
        bc_filepath,
        max_sat_k,
        min_resist_coeff,
        node_point_tolerance,
        logger,
        sort=True,
        query=None
    ):
        """Store the mapping inputs and create the (initially empty) BC containers.

        Args:
            att_table_coverage (AttTableCoverageDump): Reads the coverage dump file with attributes. May be None.
            bc_coverage: Coverage holding the relief well (RW) points. May be None.
            co_grid_3d (UGrid3d): The 3D UGrid; its ``ugrid`` supplies point locations and cell faces.
            cell_materials (list[int] | None): Cell material IDs (optional).
            bc_filepath (Path): Filepath to the BC file.
            max_sat_k: Maximum saturated hydraulic conductivity value for 1D flow in a well.
            min_resist_coeff: Minimum well screen resistance value allowed for headloss across the screen.
            node_point_tolerance: Tolerance between map nodes and grid nodes for assigning RWs (given to the mapper).
            logger (Logger): Python logger. When falsy, the 'xms.wash' logger is used.
            sort (bool): If true, BCs in the .3bc file are sorted, which might better match GMS.
            query: Object used to communicate with GMS; required when relief wells are mapped.
        """
        # Inputs
        self._att_table_coverage = att_table_coverage
        self._rw_coverage = bc_coverage
        self._co_grid_3d = co_grid_3d
        self._ugrid = co_grid_3d.ugrid
        self._cell_materials = cell_materials
        self._bc_filepath = bc_filepath
        self._max_sat_k = max_sat_k
        self._min_resist_coeff = min_resist_coeff
        self._node_point_tol = node_point_tolerance
        self._logger = logger or logging.getLogger('xms.wash')
        self._query = query

        # Data read from the attribute-table coverage
        self._att_tables = {}  # dict[str, DataFrame]; keys are 'points', 'arcs' and 'polys'
        self._xy_dict = {}  # dict[int, XySeries]; every XySeries keyed by series id
        self._max_xyseries_id = 0  # largest xy series id seen or assigned so far
        self._constant_vs_transient_arcs = set()  # arcs whose head is constant at one node and transient at the other

        # Mapping helpers
        self._ugrid_mapper = UgridMapper(co_grid_3d, cell_materials, self._logger, self._node_point_tol)
        self._sort = sort  # Sorting the BCs can make the output look more like GMS.
        # Face index lookup lists per cell shape, named UGrid -> mesh face (presumably used when writing faces).
        self._ugrid_to_mesh_face_for_wedge = [1, 0, 2, 3, 4]
        self._ugrid_to_mesh_face_for_hex = [5, 3, 2, 4, 0, 1]
        self._ugrid_to_mesh_face_for_tet_and_pyr = [2, 0, 1, 3]  # tets and pyramids share one list

        # Mapped BCs. FaceLocation = tuple(int, int): cell index, face index.
        self._wells = {}  # dict[int, list[Well]]: PS1, PS2
        self._point_bcs = {}  # dict[int, PointBc]: DB1, DB2
        self._face_bcs = {}  # dict[FaceLocation, FaceBc]: CB1, CB2, RS1, RS2
        self._rs1_faces = set()  # every RS1 FaceLocation; used to drop DB1 BCs from RS1 cells
        self._relief_wells = []  # list[ReliefWell]

    def map(self):
        """Run every mapping step and write the BC file.

        The attribute tables and xy series are read first. The attribute-table coverage steps run only when that
        coverage is present, and relief wells are mapped only when the RW coverage is present. The BC file is always
        written at the end.
        """
        self._read_att_tables()
        self._read_xy_series()

        if self._att_table_coverage:
            coverage_steps = (
                self._map_well_points_to_points,  # PS1 and PS2
                self._map_head_polygons_to_top_points,  # DB1
                self._map_head_arcs_to_side_points,  # DB1
                self._map_conc_arcs_to_side_points,  # DB2
                self._map_flow_arcs_to_side_faces,  # CB1 and RS1
                self._map_mass_arcs_to_side_faces,  # CB2 and RS2
                self._map_mass_polygons_to_top_faces,  # CB2 and RS2
                self._map_flow_polygons_to_top_faces,  # CB1 and RS1
                self._remove_db1_from_rs1_cells,  # needs every DB1 and RS1 BC, so it runs last
            )
            for step in coverage_steps:
                step()
        if self._rw_coverage:
            self._map_rw_points_to_points()  # RW1 cards
        self._write_bc_file()

    def _read_xy_series(self):
        """Reads all the xy series to self._xy_dict."""
        if self._att_table_coverage is None:
            return
        xy_list = self._att_table_coverage.get_xy_series()
        for xy_series in xy_list:
            self._xy_dict[xy_series.series_id] = xy_series
            if xy_series.series_id > self._max_xyseries_id:
                self._max_xyseries_id = xy_series.series_id

    def _read_att_tables(self):
        """Reads and stores all the attribute files."""
        if self._att_table_coverage is None:
            return
        self._att_tables['points'] = self._att_table_coverage.get_table('points')
        self._att_tables['arcs'] = self._att_table_coverage.get_table('arcs')
        self._att_tables['polys'] = self._att_table_coverage.get_table('polys')

    def _map_well_points_to_points(self):
        """Create PS1 (flow) and PS2 (concentration) BCs for the well points in the coverage.

        Every disjoint point whose type attribute is a well or a super-node well is mapped to the UGrid points at its
        location, bounded by its screen bottom (``z_min``) and top (``z_max``) with ``inclusive=False``. The well's
        flow/concentration values and time series are then applied at those points, scaled by the per-point factors the
        mapper returns.
        """
        self._logger.info('Mapping wells to points...')
        fpoints = self._att_table_coverage.do_coverage.get_points(FilterLocation.PT_LOC_DISJOINT)
        point_atts = self._att_tables['points']
        for fpoint in fpoints:
            point_id = fpoint.id
            if get_att(point_atts, point_id, FA_FW_TYPE) not in {FA_TYPE_WELL, FA_TYPE_WELLSUPR}:
                continue
            # TODO: FA_TYPE_WELLSUPR (super-node wells currently take the same path as ordinary wells)
            screen_top = get_att(point_atts, point_id, FA_FW_WELTOPSCR)
            screen_bottom = get_att(point_atts, point_id, FA_FW_WELBOTSCR)
            well_name = get_att(point_atts, point_id, FA_NAME)
            ugrid_points, factors = self._ugrid_mapper.get_ugrid_points_at_point(
                fpoint, z_min=screen_bottom, z_max=screen_top, inclusive=False
            )
            self._add_wells_at_points(
                get_att(point_atts, point_id, FA_FW_WELFLOW),
                get_att(point_atts, point_id, FA_FW_WELFLOWTS),
                get_att(point_atts, point_id, FA_FW_CONC),
                get_att(point_atts, point_id, FA_FW_CONCTS),
                ugrid_points,
                factors,
                well_name,
            )

    def _map_rw_points_to_points(self):
        """Map the relief well points of the RW coverage to UGrid points (RW1 BCs).

        The coverage's BcCoverageComponent is loaded through ``self._query``. Each disjoint point with a valid component
        id, attribute values and a matching parameter group is then passed to ``_map_rw_point``.
        """
        self._logger.info('Mapping relief wells to points...')
        do_comp = self._query.item_with_uuid(
            self._rw_coverage.uuid, model_name='Wash', unique_name="BcCoverageComponent"
        )
        component = BcCoverageComponent(do_comp.main_file)
        self._query.load_component_ids(component, points=True, arcs=False, polygons=False)
        point_section = component.data.generic_model.point_parameters
        for fpoint in self._rw_coverage.get_points(FilterLocation.PT_LOC_DISJOINT):
            comp_id = component.get_comp_id(TargetType.point, fpoint.id)
            if comp_id is None or comp_id < 0:
                continue  # point has no attributes assigned

            att_type, values = component.data.feature_type_values(TargetType.point, comp_id)
            if not values:
                continue  # pragma no cover - should never happen (? I think) and can't test

            point_section.restore_values(values)
            if not point_section.has_group(att_type):
                continue  # pragma no cover - should never happen and can't test

            self._map_rw_point(component, fpoint, point_section.group(att_type))

    def _map_rw_bc_data(self, bc_coverage_component, group):
        """Build the RW1 head Bc for one relief well.

        When the 'head' parameter's first entry is 1 the head is transient. In that case a new XySeries is created from
        the component's curve and stored in ``self._xy_dict`` under the next free id (``self._max_xyseries_id + 1``).
        Otherwise the head is constant and its value is the parameter's second entry.

        Args:
            bc_coverage_component (BcCoverageComponent): The component for boundary conditions.
            group (:obj:xms.gmi.components.pointparameters): Attributes applied to RW point

        Returns:
            head_bc (:obj:'Bc'): The RW BC object
        """
        head_bc = Bc()
        head_bc.bc_type = BcTypeEnum.RW1_relief_well
        head_value = group.parameter('head').value
        is_transient = head_value[0] == 1
        if not is_transient:
            head_bc.rate_type = RateTypeEnum.CONSTANT
            head_bc.constant = head_value[1]
            return head_bc

        head_bc.rate_type = RateTypeEnum.TRANSIENT
        head_bc.constant = 0.0
        # NOTE(review): head_value[0] is the transient flag tested above (always 1 here), yet it is what gets passed
        # as the curve argument to get_curve(); confirm against the GMI curve parameter layout.
        curve = bc_coverage_component.data.get_curve(head_value[0], False)
        new_series_id = self._max_xyseries_id + 1
        self._xy_dict[new_series_id] = XySeries(curve[0], curve[1])
        self._max_xyseries_id = new_series_id
        head_bc.xy_series_id = new_series_id
        return head_bc

    def _map_rw_check_screen(self, points_and_z_values, ref_elev, lowest_pt, screens):
        """Validate the screen information and separate the nodes in the well from those in the screen.

         Args:
            points_and_z_values (List[int,float]): List of nodes under the RW with associated z values
            ref_elev (Float): Reference elevation for RW
            lowest_pt (Float): Lowest elevation for RW (either bottom of single screen, or bottom of lowest screen)
            screens (List[float, float]): List of top and bottom of each screen in the RW

        Returns:
            node_list (List[int]): list of nodes associated with each RW
            screen_flags (List[int]): list of flags for the node list (1 for in screen, 0 for in casing)

        """
        for screen in screens:
            if screen[0] < screen[1]:  # check to make sure top and bottom are in the correct order
                self._logger.error('Screen bottom must be below screen top')
        screen_flags = []
        node_list = []
        for point_and_z_val in points_and_z_values:
            if point_and_z_val[1] <= ref_elev:  # remove wells above the reference elevation
                if point_and_z_val[1] >= lowest_pt:  # remove wells below the lowest screen bottom
                    this_flag = 0  # flag of zero means node is in casing, but not in a screen
                    for screen in screens:
                        if point_and_z_val[1] <= screen[0]:
                            if point_and_z_val[1] >= screen[1]:
                                this_flag += 1  # flag of 1 means node is in a screen
                    if this_flag > 1:  # flag over 1 means it is in two screens, which is not allowed
                        self._logger.error('Screens cannot overlap')
                    else:
                        screen_flags.append(this_flag)  # 0's and 1's for each node in the casing
                        node_list.append(point_and_z_val[0])  # all nodes in casing☻
        return node_list, screen_flags

    def _map_rw_grid_pts(self, fpoint, group):
        """Get nodes under each RW and set reference elevation based on user input.

        Args:
            fpoint (:obj:xms.data_objects.parameters.Point | tuple[float, float]): The point to be mapped to the grid
            group (:obj:xms.gmi.components.pointparameters): Attributes applied to RW point

        Returns:
            node_list (List[int]): list of nodes associated with each RW
            screen_flags (List[int]): list of flags for the node list (1 for in screen, 0 for in casing)
            ref_elev (Float): reference elevation for passive flow RW
        """
        ref_elev_opt = group.parameter('ref_elev_option').value
        ugrid_points, _t = self._ugrid_mapper.get_ugrid_points_at_point(
            fpoint, inclusive=False
        )  # this gives all ugrid points at the well point
        z_vals = []
        for ugrid_point in ugrid_points:
            xyz = self._ugrid.locations[ugrid_point]
            z_vals.append(xyz[2])  # list of the elevation for each point
        points_and_z_vals = list(zip(ugrid_points, z_vals))  # combine node id's and z values into a single list
        points_and_z_vals.sort(key=lambda x: x[1], reverse=True)  # sort the list by z value to make sure it's in order
        if ref_elev_opt == 'Specify value':
            ref_elev = group.parameter('ref_elev').value
        else:
            # if the user wants to use the top node for the reference elevation, this sets it
            ref_elev = points_and_z_vals[0][1]
        well_screen_opts = group.parameter('well_screen_opts').value
        if well_screen_opts == 'One well screen':
            screen_top = group.parameter('screen_top').value
            screen_bottom = group.parameter('screen_bottom').value
            screen_list = [[screen_top, screen_bottom]]
            node_list, screen_flags = self._map_rw_check_screen(points_and_z_vals, ref_elev, screen_bottom, screen_list)
        else:  # user has selected multiple screens
            screen_list = group.parameter('well_screens').value
            lowest_screen_pt = min(min(screen_list))
            node_list, screen_flags = self._map_rw_check_screen(
                points_and_z_vals, ref_elev, lowest_screen_pt, screen_list
            )
        return node_list, screen_flags, ref_elev

    def _map_rw_point(self, bc_coverage_component, fpoint, group):
        """Build a ReliefWell for one RW point and append it to ``self._relief_wells``.

        Args:
            bc_coverage_component (BcCoverageComponent): The component for boundary conditions.
            fpoint (:obj:xms.data_objects.parameters.Point | tuple[float, float]): The point to be mapped to the grid
            group (:obj:xms.gmi.components.pointparameters): Attributes applied to RW point
        """
        relief_well = ReliefWell()
        relief_well.head = self._map_rw_bc_data(bc_coverage_component, group)
        # Anything other than "Pressure head (0)" is treated as "Total head (1)".
        is_pressure_head = group.parameter('profile_type').value == "Pressure head (0)"
        relief_well.profile_type = 0 if is_pressure_head else 1
        relief_well.sat_k = group.parameter('sat_k').value
        relief_well.well_diameter = group.parameter('well_diameter').value
        relief_well.well_resistance = group.parameter('well_screen_resistance_coeff').value
        relief_well.nodes, relief_well.screen_flags, relief_well.ref_elev = self._map_rw_grid_pts(fpoint, group)
        self._relief_wells.append(relief_well)

    def _map_head_polygons_to_top_points(self):
        """Create DB1 (specified head) BCs at the top UGrid points inside each specified-head polygon."""
        self._logger.info('Mapping specified head polygons to top points...')
        poly_atts = self._att_tables['polys']
        for fpoly in self._att_table_coverage.do_coverage.polygons:
            if get_att(poly_atts, fpoly.id, FA_FW_FLOWBC) != FA_FW_TSHEAD:
                continue
            top_points = self._ugrid_mapper.get_ugrid_top_points_in_polygon(fpoly)
            head = get_att(poly_atts, fpoly.id, FA_FW_CHDHEAD)
            head_series = get_att(poly_atts, fpoly.id, FA_FW_CHDHEADTS)
            self._add_bcs_at_points(BcTypeEnum.DB1_spec_head, head, head_series, top_points)

    def _map_head_arcs_to_side_points(self):
        """Maps the specified head arcs to the UGrid and creates DB1 BCs.

        For each arc whose flow BC attribute is 'spec. head', the exterior UGrid points on the arc are found (the arc's
        zone attribute is passed to the mapper as ``material``), each with a ``t`` value that is used as the
        interpolation fraction between the arc's node 0 and node 1 head values:

        - Both ends constant: a constant Bc with ``head1 + t * (head2 - head1)``.
        - Otherwise: a transient Bc whose xy series is built with ``xy_util.interpolate`` and named
          ``Head_<point index + 1>``. Arcs with one constant end and one transient end are recorded in
          ``self._constant_vs_transient_arcs``.

        A point with the same x, y and ``t`` as the point just before it reuses that point's series instead of creating
        a new one. Each point's new head Bc replaces any head Bc already stored for that point.
        """
        self._logger.info('Mapping specified head arcs to side points...')
        atts = self._att_tables['arcs']
        ugrid_locations_3d = self._ugrid.locations
        for farc in self._att_table_coverage.do_coverage.arcs:
            if get_att(atts, farc.id, FA_FW_FLOWBC) == FA_FW_TSHEAD:
                material = get_att(atts, farc.id, FA_FW_ZONE)
                points, t_values = self._ugrid_mapper.get_ugrid_points_on_arc(farc, exterior=True, material=material)
                # Head value and head xy series id at each end (node 0 and node 1) of the arc.
                head1 = get_att(atts, farc.id, FA_FW_CHDHEAD, node=0)
                head2 = get_att(atts, farc.id, FA_FW_CHDHEAD, node=1)
                ts1 = get_att(atts, farc.id, FA_FW_CHDHEADTS, node=0)
                ts2 = get_att(atts, farc.id, FA_FW_CHDHEADTS, node=1)

                # Previous point's x/y, t and series id, so a point matching them can reuse the same series.
                last_loc = (XM_NONE, XM_NONE)
                last_t = XM_NONE
                last_ts = XM_NONE
                for point_idx, t in zip(points, t_values):
                    point_bc = self._point_bcs.get(point_idx)
                    if not point_bc:
                        point_bc = PointBc()
                        self._point_bcs[point_idx] = point_bc
                    head_bc = Bc(bc_type=BcTypeEnum.DB1_spec_head)
                    point_bc.head = head_bc  # replaces any head BC previously assigned to this point

                    if ts1 == XM_NONE and ts2 == XM_NONE:
                        # Both ends constant: linear interpolation along the arc.
                        head_bc.rate_type = RateTypeEnum.CONSTANT
                        head_bc.constant = head1 + (t * (head2 - head1))
                    else:
                        head_bc.rate_type = RateTypeEnum.TRANSIENT
                        loc = ugrid_locations_3d[point_idx]
                        if loc[0] != last_loc[0] or loc[1] != last_loc[1] or t != last_t:
                            new_name = f'Head_{point_idx + 1}'  # 1-based point number
                            if ts1 == XM_NONE and ts2 != XM_NONE:  # one node is trans and the other is const
                                # The transient series is at node 1, so t is mirrored (1 - t). This assumes interpolate's
                                # t is measured from xy1's end, as in the next branch - confirm in xy_util.interpolate.
                                series = xy_util.interpolate(
                                    xy1=self._xy_dict[ts2], constant=head1, t=1 - t, xy_dict=self._xy_dict
                                )
                                series.name = new_name
                                self._xy_dict[series.series_id] = series
                                self._constant_vs_transient_arcs.add(farc)
                            elif ts1 != XM_NONE and ts2 == XM_NONE:  # one node is trans and the other is const
                                series = xy_util.interpolate(
                                    xy1=self._xy_dict[ts1], constant=head2, t=t, xy_dict=self._xy_dict
                                )
                                series.name = new_name
                                self._xy_dict[series.series_id] = series
                                self._constant_vs_transient_arcs.add(farc)
                            elif ts1 == ts2:  # pragma: no cover - xy series in coverages are never shared
                                # Same series at both ends: reuse it (renamed after the arc) rather than interpolating.
                                series = self._xy_dict[ts1]
                                series.name = f'Head_{farc.id}'
                            else:
                                # Two different transient series: interpolate between them.
                                series = xy_util.interpolate(
                                    xy1=self._xy_dict[ts1], xy2=self._xy_dict[ts2], t=t, xy_dict=self._xy_dict
                                )
                                series.name = new_name
                                self._xy_dict[series.series_id] = series

                            head_bc.xy_series_id = series.series_id
                            last_loc = loc
                            last_t = t
                            last_ts = series.series_id
                        else:
                            # Same x, y and t as the previous point: share its series.
                            head_bc.xy_series_id = last_ts

    def _map_conc_arcs_to_side_points(self):
        """Maps the specified concentration arcs to the UGrid and creates DB2 BCs.

        Each arc whose transport BC attribute is 'spec. conc.' is mapped to the exterior UGrid points on the arc (the
        arc's zone attribute is passed to the mapper as ``material``). Each of those points gets a DB2 BC built from
        the arc's concentration value or concentration time series.
        """
        self._logger.info('Mapping specified concentration arcs to side points...')
        atts = self._att_tables['arcs']
        for farc in self._att_table_coverage.do_coverage.arcs:
            # Named type_ (as in the other mapping methods) so the builtin type() is not shadowed.
            type_ = get_att(atts, farc.id, FA_FW_TRANSBC)
            if type_ == FA_FW_TSCONC:
                material = get_att(atts, farc.id, FA_FW_ZONE)
                point_idxs, _ = self._ugrid_mapper.get_ugrid_points_on_arc(farc, exterior=True, material=material)
                conc = get_att(atts, farc.id, FA_FW_CONC)
                conc_ts = get_att(atts, farc.id, FA_FW_CONCTS)
                self._add_bcs_at_points(BcTypeEnum.DB2_spec_conc, conc, conc_ts, point_idxs)

    def _add_bcs_at_points(self, bc_type, constant_value, xy_series_id, point_indexes):
        """Create one Bc of ``bc_type`` at every point in ``point_indexes`` and store it in ``self._point_bcs``.

        A DB1 Bc replaces the point's head BC; any other type replaces the point's concentration BC. When the BC is
        transient, every Bc created by one call shares the single xy series id chosen by ``_set_bc``.

        Args:
            bc_type (BcTypeEnum): The type of BC.
            constant_value (float): A constant value.
            xy_series_id (int): XySeries ID.
            point_indexes (list[int]): 3D UGrid point indexes where Bcs will be created.
        """
        is_head = bc_type == BcTypeEnum.DB1_spec_head
        shared_xy = XM_NONE
        for point_idx in point_indexes:
            point_bc = self._point_bcs.get(point_idx)
            if not point_bc:
                point_bc = self._point_bcs[point_idx] = PointBc()
            new_bc = Bc(bc_type=bc_type)
            if is_head:
                point_bc.head = new_bc
            else:
                point_bc.conc = new_bc
            shared_xy = self._set_bc(new_bc, constant_value, xy_series_id, shared_xy)

    def _add_wells_at_points(self, flow_constant, flow_xy, conc_constant, conc_xy, point_indexes, factors, name_att):
        """Append a new Well to ``self._wells`` at every point in ``point_indexes``.

        The Well gets a PS1 pump Bc when a flow constant or flow series is given, and a PS2 concentration Bc when a
        concentration constant or series is given. Each Bc is scaled by the point's factor.

        Args:
            flow_constant (float|None): The constant value for flow.
            flow_xy (int|None): The XySeries ID for flow.
            conc_constant (float|None): The constant value for conc.
            conc_xy (int|None): The XySeries ID for conc.
            point_indexes (list[int]): 3D UGrid point indexes where wells will be created.
            factors (list[float]): Factors (0.0 to 1.0) used to scale the value(s), parallel to point_indexes.
            name_att (str): Name from att table.
        """
        has_flow = not (flow_constant is None and flow_xy is None)
        has_conc = not (conc_constant is None and conc_xy is None)
        for i, point_idx in enumerate(point_indexes):
            wells_here = self._wells.get(point_idx)
            well = Well()
            if wells_here:
                wells_here.append(well)
            else:
                self._wells[point_idx] = [well]
            if has_flow:
                well.pump = Bc(bc_type=BcTypeEnum.PS1_well_flow)
                self._set_well_bc(well.pump, flow_constant, flow_xy, factors[i], name_att)
            if has_conc:
                well.conc = Bc(bc_type=BcTypeEnum.PS2_well_conc)
                self._set_well_bc(well.conc, conc_constant, conc_xy, factors[i], name_att)

    def _add_bcs_at_faces(self, bc_type, constant_value, xy_series_id, faces):
        """Create one Bc of ``bc_type`` on every face in ``faces`` and store it in ``self._face_bcs``.

        CB1 and RS1 Bcs replace the face's flux BC, and RS1 faces are also recorded in ``self._rs1_faces``. Any other
        type replaces the face's concentration BC. When the BC is transient, every Bc created by one call shares the
        single xy series id chosen by ``_set_bc``.

        Args:
            bc_type (BcTypeEnum): The type of BC.
            constant_value (float): Constant value.
            xy_series_id (int): XySeries ID.
            faces (list[tuple(int, int)]): List of faces (cell index, face index).
        """
        is_flux = bc_type in [BcTypeEnum.CB1_spec_flux, BcTypeEnum.RS1_var_flux]
        is_rs1 = bc_type == BcTypeEnum.RS1_var_flux
        shared_xy = XM_NONE
        for cell_idx, face_idx in faces:
            face_loc = (cell_idx, face_idx)
            face_bc = self._face_bcs.get(face_loc)
            if not face_bc:
                face_bc = self._face_bcs[face_loc] = FaceBc()
            new_bc = Bc(bc_type=bc_type)
            if not is_flux:
                face_bc.conc = new_bc
            else:
                face_bc.flux = new_bc
                if is_rs1:
                    self._rs1_faces.add(face_loc)
            shared_xy = self._set_bc(new_bc, constant_value, xy_series_id, shared_xy)

    def _set_bc(self, bc, constant_value, xy_series_id, last_xy):
        """Fill in ``bc`` as a constant or transient BC.

        With no xy series (``XM_NONE``) the Bc is constant with ``constant_value``. Otherwise it is transient. The
        series id is resolved once through ``xy_util.add_or_match`` (when ``last_xy`` is still ``XM_NONE``) and the
        resolved id is reused for the rest of the batch.

        Args:
            bc (Bc): The boundary condition object.
            constant_value (float): A constant value.
            xy_series_id (int): A time-varying xy series ID.
            last_xy (int): XY series ID already resolved for this batch of Bcs, or XM_NONE.

        Returns:
            (int): The XY series ID to pass as ``last_xy`` for the next Bc in the batch.
        """
        if xy_series_id == XM_NONE:
            bc.rate_type = RateTypeEnum.CONSTANT
            bc.constant = constant_value
            return last_xy

        bc.rate_type = RateTypeEnum.TRANSIENT
        if last_xy == XM_NONE:
            last_xy = xy_util.add_or_match(self._xy_dict[xy_series_id], self._xy_dict)
        bc.xy_series_id = last_xy
        return last_xy

    def _set_well_bc(self, bc, constant_value, xy_series_id, factor, name_att):
        """Fill in a well Bc as constant or transient, scaled by ``factor``.

        A constant Bc stores ``factor * constant_value``. A transient Bc uses ``xy_series_id`` directly when ``factor``
        is 1.0. Otherwise a deep copy of that series is scaled and passed to ``xy_util.add_or_match``, and the returned
        id is used. In the transient case the chosen series is also renamed via ``_well_xy_name``.

        Args:
            bc (Bc): The boundary condition object.
            constant_value (float): A constant value.
            xy_series_id (int): A time-varying xy series ID.
            factor (float): Factor (0.0 to 1.0) used to scale the value(s).
            name_att (str): Name from att table.
        """
        if xy_series_id == XM_NONE:
            bc.rate_type = RateTypeEnum.CONSTANT
            bc.constant = factor * constant_value
            return

        bc.rate_type = RateTypeEnum.TRANSIENT
        if factor == 1.0:
            bc.xy_series_id = xy_series_id
        else:
            scaled_series = copy.deepcopy(self._xy_dict[xy_series_id])
            xy_util.scale_y_data(scaled_series, factor)
            bc.xy_series_id = xy_util.add_or_match(scaled_series, self._xy_dict)

        self._xy_dict[bc.xy_series_id].name = self._well_xy_name(bc.xy_series_id, name_att)

    def _well_xy_name(self, xy_series_id, name_att):
        """Returns the name of the xy series for the well bc.

        Args:
            xy_series_id (int): The xy series id.
            name_att (str): Name from att table.

        Returns:
            (str): The name.
        """
        xy_name = self._xy_dict[xy_series_id].name
        if xy_name != 'Curve' and 'others' not in xy_name:
            xy_name += '_others'
        else:
            xy_name = name_att
        return xy_name

    def _map_flow_arcs_to_side_faces(self):
        """Create CB1 (specified flux) and RS1 (variable flux) BCs on the side faces along flow arcs.

        FA_FW_TSFLUX = 'spec. flux' -> CB1, FA_FW_TVFLUX = 'variable flux' -> RS1. Both use the arc's flux rate and
        flux rate time series.
        """
        self._logger.info('Mapping specified flow arcs to side faces...')
        arc_atts = self._att_tables['arcs']
        for farc in self._att_table_coverage.do_coverage.arcs:
            flow_bc = get_att(arc_atts, farc.id, FA_FW_FLOWBC)
            if flow_bc not in (FA_FW_TSFLUX, FA_FW_TVFLUX):
                continue
            material = get_att(arc_atts, farc.id, FA_FW_ZONE)
            side_faces = self._ugrid_mapper.get_ugrid_side_faces_on_arc(farc, material)
            flux = get_att(arc_atts, farc.id, FA_FW_FLUX)
            flux_series = get_att(arc_atts, farc.id, FA_FW_FLUXTS)
            bc_type = BcTypeEnum.CB1_spec_flux if flow_bc == FA_FW_TSFLUX else BcTypeEnum.RS1_var_flux
            self._add_bcs_at_faces(bc_type, flux, flux_series, side_faces)

    def _map_mass_arcs_to_side_faces(self):
        """Create CB2 (specified mass flux) and RS2 (variable conc.) BCs on the side faces along transport arcs.

        FA_FW_TSMFLUX = 'spec. mass flux' -> CB2 using the mass flux attributes.
        FA_FW_TVCONC = 'variable conc.' -> RS2 using the concentration attributes.
        """
        self._logger.info('Mapping specified mass arcs to side faces...')
        arc_atts = self._att_tables['arcs']
        for farc in self._att_table_coverage.do_coverage.arcs:
            transport_bc = get_att(arc_atts, farc.id, FA_FW_TRANSBC)
            if transport_bc not in (FA_FW_TSMFLUX, FA_FW_TVCONC):
                continue
            material = get_att(arc_atts, farc.id, FA_FW_ZONE)
            side_faces = self._ugrid_mapper.get_ugrid_side_faces_on_arc(farc, material)
            if transport_bc == FA_FW_TSMFLUX:  # CB2
                bc_type, value_att, series_att = BcTypeEnum.CB2_spec_mass_flux, FA_FW_MASS, FA_FW_MASSTS
            else:  # RS2
                bc_type, value_att, series_att = BcTypeEnum.RS2_var_conc, FA_FW_CONC, FA_FW_CONCTS
            self._add_bcs_at_faces(
                bc_type, get_att(arc_atts, farc.id, value_att), get_att(arc_atts, farc.id, series_att), side_faces
            )

    def _map_mass_polygons_to_top_faces(self):
        """Create CB2 (specified mass flux) and RS2 (variable conc.) BCs on the top faces inside transport polygons.

        FA_FW_TSMFLUX = 'spec. mass flux' -> CB2 using the mass flux attributes.
        FA_FW_TVCONC = 'variable conc.' -> RS2 using the concentration attributes.
        """
        self._logger.info('Mapping specified mass polygons to top faces...')
        poly_atts = self._att_tables['polys']
        for fpoly in self._att_table_coverage.do_coverage.polygons:
            transport_bc = get_att(poly_atts, fpoly.id, FA_FW_TRANSBC)
            if transport_bc not in (FA_FW_TSMFLUX, FA_FW_TVCONC):
                continue
            top_faces = self._ugrid_mapper.get_ugrid_top_faces_in_polygon(fpoly)
            if transport_bc == FA_FW_TSMFLUX:  # CB2
                bc_type, value_att, series_att = BcTypeEnum.CB2_spec_mass_flux, FA_FW_MASS, FA_FW_MASSTS
            else:  # RS2
                bc_type, value_att, series_att = BcTypeEnum.RS2_var_conc, FA_FW_CONC, FA_FW_CONCTS
            self._add_bcs_at_faces(
                bc_type, get_att(poly_atts, fpoly.id, value_att), get_att(poly_atts, fpoly.id, series_att), top_faces
            )

    def _map_flow_polygons_to_top_faces(self):
        """Create CB1 (specified flux) and RS1 (variable flux) BCs on the top faces inside flow polygons.

        FA_FW_TSFLUX = 'spec. flux' -> CB1, FA_FW_TVFLUX = 'variable flux' -> RS1. Both use the polygon's flux rate and
        flux rate time series.
        """
        self._logger.info('Mapping specified flow polygons to top faces...')
        poly_atts = self._att_tables['polys']
        for fpoly in self._att_table_coverage.do_coverage.polygons:
            flow_bc = get_att(poly_atts, fpoly.id, FA_FW_FLOWBC)
            if flow_bc not in (FA_FW_TSFLUX, FA_FW_TVFLUX):
                continue
            top_faces = self._ugrid_mapper.get_ugrid_top_faces_in_polygon(fpoly)
            flux = get_att(poly_atts, fpoly.id, FA_FW_FLUX)
            flux_series = get_att(poly_atts, fpoly.id, FA_FW_FLUXTS)
            bc_type = BcTypeEnum.CB1_spec_flux if flow_bc == FA_FW_TSFLUX else BcTypeEnum.RS1_var_flux
            self._add_bcs_at_faces(bc_type, flux, flux_series, top_faces)

    def _remove_db1_from_rs1_cells(self):
        """Drop DB1 BCs from every point that lies on a face carrying an RS1 BC.

        According to the WASH folks, DB1 BCs should not exist in cells that have RS1 BCs. The points of each recorded
        RS1 face are walked, and any point whose head BC is DB1 is removed from ``self._point_bcs``.
        """
        for cell_idx, face_idx in self._rs1_faces:
            for point_idx in self._ugrid.get_cell_3d_face_points(cell_idx, face_idx):
                point_bc = self._point_bcs.get(point_idx)
                if not (point_bc and point_bc.head):
                    continue
                if point_bc.head.bc_type == BcTypeEnum.DB1_spec_head:
                    # NOTE(review): this deletes the whole PointBc entry, so a DB2 concentration BC on the same point
                    # is dropped too; confirm that is intended.
                    del self._point_bcs[point_idx]

    def bcs_of_type(self, bc_type):
        """Returns a list of all the Bcs of the given type.

        Args:
            bc_type (BcTypeEnum): The Bc type.

        Returns:
            (list[Bc]): See description.
        """
        bcs = []
        match bc_type:
            case BcTypeEnum.PS1_well_flow:
                bcs = [(pt_idx, w.pump) for pt_idx, well_list in self._wells.items() for w in well_list if w.pump]
            case BcTypeEnum.PS2_well_conc:
                bcs = [(pt_idx, w.conc) for pt_idx, well_list in self._wells.items() for w in well_list if w.conc]
            case BcTypeEnum.DB1_spec_head:
                bcs = [(pt_idx, pt_bc.head) for pt_idx, pt_bc in self._point_bcs.items() if pt_bc.head]
            case BcTypeEnum.DB2_spec_conc:
                bcs = [(pt_idx, p.conc) for pt_idx, p in self._point_bcs.items() if p.conc]
            case BcTypeEnum.CB1_spec_flux:
                bcs = [
                    (f_loc, f.flux) for f_loc, f in self._face_bcs.items() if f.flux and  # noqa: W504 (line break)
                    f.flux.bc_type == BcTypeEnum.CB1_spec_flux
                ]
            case BcTypeEnum.CB2_spec_mass_flux:
                bcs = [
                    (f_loc, f.conc) for f_loc, f in self._face_bcs.items() if f.conc and  # noqa: W504 (line break)
                    f.conc.bc_type == BcTypeEnum.CB2_spec_mass_flux
                ]
            case BcTypeEnum.RS1_var_flux:
                bcs = [
                    (f_loc, f.flux)
                    for f_loc, f in self._face_bcs.items() if f.flux and f.flux.bc_type == BcTypeEnum.RS1_var_flux
                ]
            case BcTypeEnum.RS2_var_conc:
                return [
                    (f_loc, f.conc)
                    for f_loc, f in self._face_bcs.items() if f.conc and f.conc.bc_type == BcTypeEnum.RS2_var_conc
                ]
            case BcTypeEnum.BcNone:  # pragma: no cover
                return []
        bcs.sort(key=lambda a: a[0])  # Sort by point index
        return bcs

    def _create_constant_xy_series(self, name, constant, factor):
        """Builds a two-point XySeries that holds y = constant * factor from x = 0.0 to x = 1e99.

        Args:
            name (str): Name to give to new XySeries.
            constant (float): The constant.
            factor (float): A factor used to scale constant.

        Returns:
            (XySeries): The new XySeries, created with series_id 1.
        """
        scaled = constant * factor
        return XySeries(x=[0.0, 1e99], y=[scaled, scaled], name=name, series_id=1)

    def _add_constant_xy_series(self, bc):
        """Gives a constant-rate bc a constant XY series and stores the resulting series ID on the bc.

        The series is passed to xy_util.add_or_match together with self._xy_dict, so an identical existing series
        is reused rather than duplicated. Missing or non-constant bcs are left untouched.

        Args:
            bc (Bc): The boundary condition (may be None).
        """
        if not bc or bc.rate_type != RateTypeEnum.CONSTANT:
            return
        series = self._create_constant_xy_series('constant', bc.constant, bc.factor)
        bc.xy_series_id = xy_util.add_or_match(series, self._xy_dict)

    def _add_constant_xy_series_rw(self, relief_well):
        """Gives a constant-rate relief well head a constant XY series and stores its ID on the head.

        Unlike the generic bc version, the head constant is not scaled (a factor of 1.0 is used), and
        relief_well.head is assumed to exist.

        Args:
            relief_well (ReliefWell): The relief well.
        """
        head = relief_well.head
        if head.rate_type != RateTypeEnum.CONSTANT:
            return
        series = self._create_constant_xy_series('relief well', head.constant, 1.0)
        head.xy_series_id = xy_util.add_or_match(series, self._xy_dict)

    def _add_all_constant_xy_series(self):
        """Creates a constant xy series, finds a matching series if any, and sets bc.xy_series_id if needed."""
        for well_list in self._wells.values():
            for well in well_list:
                self._add_constant_xy_series(well.pump)
                self._add_constant_xy_series(well.conc)
        for point_bc in self._point_bcs.values():
            self._add_constant_xy_series(point_bc.head)
            self._add_constant_xy_series(point_bc.conc)
        for face_bc in self._face_bcs.values():
            self._add_constant_xy_series(face_bc.flux)
            self._add_constant_xy_series(face_bc.conc)
        for relief_well in self._relief_wells:
            self._add_constant_xy_series_rw(relief_well)

    def _add_xy_series_if_needed(self, bc, xy_series_set):
        """Records bc's xy series ID in xy_series_set when bc exists and its ID is not XM_NONE.

        Args:
            bc (Bc): The Bc (may be None).
            xy_series_set (set{int}): Set of xy series IDs, updated in place.
        """
        if not bc:
            return
        # rate_type is not checked: by now constant bcs also carry an xy_series_id
        if bc.xy_series_id != XM_NONE:
            xy_series_set.add(bc.xy_series_id)

    def _get_xy_series_ids_in_use(self):
        """Returns the set of unique xy series IDs in use by the BCs."""
        xy_series_set = set()
        for well_list in self._wells.values():
            for well in well_list:
                self._add_xy_series_if_needed(well.pump, xy_series_set)
                self._add_xy_series_if_needed(well.conc, xy_series_set)
        for point_bc in self._point_bcs.values():
            self._add_xy_series_if_needed(point_bc.head, xy_series_set)
            self._add_xy_series_if_needed(point_bc.conc, xy_series_set)
        for face_bc in self._face_bcs.values():
            self._add_xy_series_if_needed(face_bc.flux, xy_series_set)
            self._add_xy_series_if_needed(face_bc.conc, xy_series_set)
        for relief_well in self._relief_wells:
            xy_series_set.add(relief_well.head.xy_series_id)
        return xy_series_set

    def _write_xy_series_for_bcs(self, file):
        """Writes all the BC xy series to the file."""
        xy_series_set = self._get_xy_series_ids_in_use()
        for xy_series_id in xy_series_set:
            self._xy_dict[xy_series_id].write(file, xy1_format=True)

    def _write_wells(self, file):
        """Writes the well BCs to the file, one 'type-card point-id xy-series-id' line per BC.

        When sorting is enabled, wells are written in descending point-index order. BCs that are missing or of
        type BcNone are skipped.

        Args:
            file: The open file being written to.
        """
        # Iterate the sorted items directly; the keys are unique, so sorting by key alone gives the same order
        if self._sort:
            items = sorted(self._wells.items(), key=lambda item: item[0], reverse=True)
        else:
            items = self._wells.items()
        for location, well_list in items:
            for well in well_list:
                for bc in (well.pump, well.conc):
                    if bc and bc.bc_type != BcTypeEnum.BcNone:
                        file.write(f'{bc.bc_type} {location + 1} {bc.xy_series_id}\n')

    def _write_relief_wells(self, file):
        """Write the relief wells to the file.

        Args:
            file: The file being written to.
        """
        if not self._relief_wells:
            return
        file.write(f'RW0 {self._max_sat_k} {self._min_resist_coeff}\n')
        rw_counter = 1
        for rw_well in self._relief_wells:
            if len(rw_well.nodes) > 0:  # check to make sure at least one node was found
                file.write(
                    f'{rw_well.head.bc_type} {rw_counter} {len(rw_well.nodes)} {sum(rw_well.screen_flags)} '
                    f'{rw_well.head.xy_series_id} {rw_well.profile_type} {rw_well.ref_elev} '
                    f'{rw_well.sat_k} {rw_well.well_diameter} {rw_well.well_resistance}\n'
                )
                out_string = ""
                for node in rw_well.nodes:
                    # write list of all nodes in casing
                    out_string += str(node + 1)
                    out_string += " "
                file.write(f'{out_string.rstrip()}\n')
                out_string = ""
                for flag in rw_well.screen_flags:
                    # flags are 1 for nodes in screens and 0 for nodes in casing
                    out_string += str(flag)
                    out_string += " "
                file.write(f'{out_string.rstrip()}\n')
                rw_counter += 1

    def _write_boundary_points(self, file):
        """Writes the point BCs (head, then conc) to the file.

        When sorting is enabled, points are written in descending point-index order. Missing and BcNone BCs are
        skipped.

        Args:
            file: The file being written to.
        """
        if self._sort:
            ordered = sorted(self._point_bcs.items(), reverse=True)
        else:
            ordered = self._point_bcs.items()
        for pt_idx, point_bc in ordered:
            for bc in (point_bc.head, point_bc.conc):
                if not bc or bc.bc_type == BcTypeEnum.BcNone:
                    continue
                # write 'type-card point-id xy-series-id' e.g. 'DB1 3 5'
                file.write(f'{bc.bc_type} {pt_idx + 1} {bc.xy_series_id}\n')

    def _order_faces(self):
        """Return the faces sorted by cell index, then face index, if we're sorting, otherwise return the dict as is.

        Returns:
            (dict[tuple[int, int], FaceBc]): Faces dict.
        """
        if self._sort:
            return dict(sorted(self._face_bcs.items(), reverse=True, key=lambda pair: (pair[0], pair[1])))
        else:
            return self._face_bcs

    def _mesh_face_index_from_ugrid_face_index(self, cell_idx, ugrid_face_index):
        """Translates a UGrid face index into the mesh face index for the given cell.

        The translation table is chosen by the cell's type: wedge, hexahedron, or tetrahedron/pyramid.

        Args:
            cell_idx (int): UGrid cell index.
            ugrid_face_index (int): UGrid face index.

        Returns:
            (int): Mesh face index

        Raises:
            RuntimeError: If the cell is of any other type.
        """
        kind = self._ugrid.get_cell_type(cell_idx)
        cell_types = UGrid.cell_type_enum
        if kind == cell_types.WEDGE:
            table = self._ugrid_to_mesh_face_for_wedge
        elif kind == cell_types.HEXAHEDRON:
            table = self._ugrid_to_mesh_face_for_hex
        elif kind in (cell_types.TETRA, cell_types.PYRAMID):
            table = self._ugrid_to_mesh_face_for_tet_and_pyr
        else:
            raise RuntimeError(
                f'Could not get face index for cell {cell_idx}. Only tetrahedrons/pyramids'
                f' (4 sides), prisms/wedges (5 sides) and hexahedrons (6 sides) are supported.'
            )
        return table[ugrid_face_index]

    def _write_boundary_faces(self, file, bc_types):
        """Writes the boundary face BCs to the file.

        Args:
            file: The file being written to.
            bc_types (set{BcTypeEnum}): Set of bc types to write.
        """
        # Sort by cell index, then face index
        face_bcs = self._order_faces()
        for location, face_bc in face_bcs.items():
            for bc in [face_bc.flux, face_bc.conc]:
                if bc and bc.bc_type in bc_types:
                    mesh_face_index = self._mesh_face_index_from_ugrid_face_index(location[0], location[1])
                    # write 'type-card cell-id face-id xy-series-id' e.g. 'CB1 3 5 2'
                    file.write(f'{bc.bc_type} {location[0] + 1} {mesh_face_index + 1} {bc.xy_series_id}\n')

    def _write_bc_file(self):
        """Writes the BC file to self._bc_filepath.

        Constant BCs are first given constant XY series. The file then contains, in order: the 'WMS3BC' header,
        the XY series, wells, relief wells, RS1/RS2 face BCs, point BCs, CB1/CB2 face BCs, and a closing 'END'.
        """
        self._logger.info(f'Writing the bc file {self._bc_filepath}')
        self._add_all_constant_xy_series()
        variable_face_types = {BcTypeEnum.RS1_var_flux, BcTypeEnum.RS2_var_conc}
        specified_face_types = {BcTypeEnum.CB1_spec_flux, BcTypeEnum.CB2_spec_mass_flux}
        with self._bc_filepath.open('w') as bc_file:
            bc_file.write('WMS3BC\n')
            self._write_xy_series_for_bcs(bc_file)
            self._write_wells(bc_file)
            self._write_relief_wells(bc_file)
            self._write_boundary_faces(bc_file, variable_face_types)
            self._write_boundary_points(bc_file)
            self._write_boundary_faces(bc_file, specified_face_types)
            bc_file.write('END\n')
