"""Class for snapping bcs to model grid."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
from xms.data_objects.parameters import FilterLocation
from xms.guipy.data.target_type import TargetType
from xms.snap.snap_exterior_arc import SnapExteriorArc
from xms.snap.snap_interior_arc import SnapInteriorArc

# 4. Local modules
from xms.rsm.data import bc_data_def as bcdd
from xms.rsm.file_io import util


class MeshBcSnapper:
    """Snaps BC coverages to the UGrid.

    Results are stored in two public dictionaries, both keyed by coverage UUID:

    * ``pt_grid_cell_snap[cov_uuid][point_id]`` - snap info for each disjoint point that landed in a grid cell.
    * ``arc_grid_pt_snap[cov_uuid][arc_id]`` - snap info for each arc that snapped to at least two grid points.
    """
    def __init__(self, xms_data):
        """Constructor.

        Args:
            xms_data (XmsData): Simulation data retrieved from SMS
        """
        self._logger = util.get_logger()
        self._xms_data = xms_data
        self.pt_grid_cell_snap = {}
        self.arc_grid_pt_snap = {}
        self._ug = self._xms_data.xmugrid
        self._ug_locs = self._ug.locations
        # The extractor and arc snappers are created lazily in generate_snap().
        self._extractor = None
        self._arc_snap_int = None
        self._arc_snap_ext = None
        # Coverage/component pair currently being processed by generate_snap().
        self._bc_cov = None
        self._bc_comp = None
        self._arc_types = bcdd.ARC_TYPES
        self._pt_types = bcdd.PT_TYPES

    def generate_snap(self):
        """Snap the arcs and points of every BC coverage to the grid.

        Fills ``arc_grid_pt_snap`` and ``pt_grid_cell_snap`` with one entry per coverage. Does nothing if the
        simulation has no BC coverages.
        """
        # intersect coverages with the grid/mesh
        if not self._xms_data.bc_coverages:
            return
        self._logger.info('Processing Boundary Condition coverages.')
        self._extractor = self._xms_data.ugrid_extractor
        self._arc_snap_int = SnapInteriorArc()
        self._arc_snap_int.set_grid(self._xms_data.cogrid, target_cells=False)
        self._arc_snap_ext = SnapExteriorArc()
        self._arc_snap_ext.set_grid(self._xms_data.cogrid, target_cells=False)

        for bc_cov, bc_comp in self._xms_data.bc_coverages:
            self._bc_cov = bc_cov
            self._bc_comp = bc_comp
            self._logger.info(f'Processing coverage: {bc_cov.name}')
            self._process_cov_arcs()
            self._process_cov_points()

    def _process_cov_points(self):
        """Process the points in the current coverage.

        Each disjoint point that has component data, lies in a grid cell, and has a type listed in
        ``self._pt_types`` is recorded in ``self.pt_grid_cell_snap`` under the current coverage's UUID.
        Points without component data are skipped silently; the other rejected points are logged as warnings.
        """
        pts = self._bc_cov.get_points(FilterLocation.PT_LOC_DISJOINT)
        self._extractor.extract_locations = [(p.x, p.y, p.z) for p in pts]
        self._extractor.extract_data()
        cell_idx = self._extractor.cell_indexes
        gm_pt = bcdd.generic_model().point_parameters
        cov_name = self._bc_cov.name.replace(' ', '_')
        cov_snap = {}
        self.pt_grid_cell_snap[self._bc_cov.uuid] = cov_snap
        for p, idx in zip(pts, cell_idx):
            comp_id = self._bc_comp.get_comp_id(TargetType.point, p.id)
            if comp_id is None or comp_id < 0:
                continue  # no BC component data assigned to this point
            comp_type, comp_val = self._bc_comp.data.feature_type_values(TargetType.point, comp_id)
            # idx is passed unchanged to get_cell_centroid() below, i.e. it is used as a 0-based cell index, so
            # cell 0 is a valid hit. Only a negative index is treated as "outside the grid".
            # NOTE(review): presumably the extractor reports -1 for a location outside the grid - confirm.
            if idx < 0:
                self._logger.warning(f'Point id: {p.id} is not inside of the grid and was skipped.')
            elif comp_type not in self._pt_types:
                self._logger.warning(f'Point id: {p.id} is not assigned a valid type and was skipped.')
            else:
                cov_snap[p.id] = {
                    'gm_grp': gm_pt.group(comp_type),
                    'cov_name': cov_name,
                    'pt_id': p.id,
                    'cell_idx': idx,
                    'comp_val': comp_val,
                    'comp_type': comp_type,
                    # Second item of the get_cell_centroid() result is stored as the snapped location.
                    'location': self._ug.get_cell_centroid(idx)[1]
                }

    def _process_cov_arcs(self):
        """Process the arcs in the current coverage.

        Each arc that has component data, snaps to at least two grid points, and has a type listed in
        ``self._arc_types`` is recorded in ``self.arc_grid_pt_snap`` under the current coverage's UUID.
        Arcs without component data are skipped silently; the other rejected arcs are logged as warnings.
        """
        cov_name = self._bc_cov.name.replace(' ', '_')
        cov_snap = {}
        self.arc_grid_pt_snap[self._bc_cov.uuid] = cov_snap
        for arc in self._bc_cov.arcs:
            comp_id = self._bc_comp.get_comp_id(TargetType.arc, arc.id)
            if comp_id is None or comp_id < 0:
                continue  # no BC component data assigned to this arc

            comp_type, comp_val = self._bc_comp.data.feature_type_values(TargetType.arc, comp_id)
            # Try the interior snapper first and fall back to the exterior snapper when that yields < 2 points.
            o = self._arc_snap_int.get_snapped_points(arc)
            if len(o['id']) < 2:
                o = self._arc_snap_ext.get_snapped_points(arc)
            if len(o['id']) < 2:
                self._logger.warning(f'Arc id: {arc.id} does not snap to 2 or more grid locations and was skipped.')
            elif comp_type not in self._arc_types:
                self._logger.warning(f'Arc id: {arc.id} is not assigned a valid type and was skipped.')
            else:
                # Flat list of coordinates: every component of every snapped grid location, in snap order.
                locs = [coord for idx in o['id'] for coord in self._ug_locs[idx]]
                cov_snap[arc.id] = {
                    'comp_type': comp_type,
                    'comp_val': comp_val,
                    'cov_name': cov_name,
                    'arc_id': arc.id,
                    'pt_string': o['id'],
                    'locations': locs
                }
