"""Map Monitor coverage locations and attributes to the SRH-2D domain."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
import math
import os
import shutil
import uuid

# 2. Third party modules

# 3. Aquaveo modules
from xms.components.display.display_options_io import read_display_options_from_json, write_display_options_to_json
from xms.components.display.display_options_io import write_display_option_line_locations
from xms.data_objects.parameters import Component, FilterLocation
from xms.extractor.ugrid_2d_data_extractor import UGrid2dDataExtractor
from xms.guipy.data.category_display_option_list import CategoryDisplayOptionList
from xms.guipy.data.line_style import LineStyle
from xms.guipy.data.target_type import TargetType
from xms.snap.snap_interior_arc import SnapInteriorArc

# 4. Local modules
from xms.srh.components.mapped_monitor_component import MappedMonitorComponent


class MonitorMapper:
    """Class for mapping monitor coverage to a mesh for SRH-2D.

    Gathers the monitor coverage's disjoint points (monitor points) and snaps the coverage's arcs, plus the
    centerline arcs of the given 3D structures, to mesh nodes (monitor lines). When a snap preview is
    requested, ``do_map`` also builds a mapped monitor component that draws the snapped lines.
    """
    def __init__(self, coverage_mapper, wkt, generate_snap, comp_path, structures_3d_monitor):
        """Constructor.

        Args:
            coverage_mapper: Object providing ``co_grid``, ``monitor_coverage`` and ``monitor_component``.
            wkt: Grid projection WKT, written into the preview component's display options
                (presumably a str — confirm against caller).
            generate_snap (bool): If truthy, ``do_map`` builds the snap preview component.
            comp_path (str): Directory in which the mapped component's folder is created.
            structures_3d_monitor (list[dict]): 3D structures whose centerline arc is added as a monitor line.
                Each dict is read via its 'cov', 'arc' and 'mid_arc' keys.
        """
        self._generate_snap = generate_snap
        self._logger = logging.getLogger('xms.srh')
        self._co_grid = coverage_mapper.co_grid
        self._new_comp_unique_name = 'Mapped_Monitor_Component'
        self._monitor_coverage = coverage_mapper.monitor_coverage
        self._monitor_component = coverage_mapper.monitor_component
        # Snaps arcs to grid points (target_cells=False), not to cells.
        self._snap_arc_interior = SnapInteriorArc()
        self._snap_arc_interior.set_grid(grid=self._co_grid, target_cells=False)
        self._comp_main_file = ''  # Set to the display options JSON once the component folder is created.
        self._arc_to_grid_points = []  # Flattened snapped coordinates of each accepted monitor line.
        self.arc_id_to_grid_ids = {}  # 1-based running line counter -> snapped grid point ids.
        self._monitor_points = []  # (pt_id, x, y, is_observation, wse, wse_weight) per monitor point.
        self._comp_path = comp_path
        self.mapped_comp_uuid = None
        self.mapped_comp_display_uuid = None
        self.grid_wkt = wkt
        self.structures_3d_monitor = structures_3d_monitor
        # (-structure arc id, monitor location label, line counter, coverage uuid) per accepted structure line.
        self.out_3d_structures = []

    def do_map(self):
        """Creates the mapped monitor component.

        Monitor points and snapped monitor lines are always gathered and stored on this object. The preview
        component is only built when snapping was requested, a monitor coverage exists, and at least one line
        snapped to the mesh; otherwise a warning may be logged.

        Returns:
            tuple: ``(data_objects Component, MappedMonitorComponent)`` for the snap preview, or
            ``(None, None)`` when no preview component was created.
        """
        self._get_coverage_points()
        self._get_grid_points_from_arcs()
        if self._generate_snap and self._monitor_coverage:
            if self._arc_to_grid_points:
                self._create_component_folder_and_copy_display_options()
                self._create_drawing()

                # Create the data_objects component
                comp_name = f'Snapped {self._monitor_coverage.name} arc display'
                # The component folder is named after the component UUID.
                # NOTE(review): if the default display options file was missing, _comp_main_file is still ''.
                comp_uuid = os.path.basename(os.path.dirname(self._comp_main_file))
                do_comp = Component(
                    main_file=self._comp_main_file,
                    name=comp_name,
                    model_name='SRH-2D',
                    unique_name=self._new_comp_unique_name,
                    comp_uuid=comp_uuid
                )

                comp = MappedMonitorComponent(self._comp_main_file)
                return do_comp, comp
            else:
                msg = f'Coverage "{self._monitor_coverage.name}" does not contain any arcs that snap to the ' \
                      f'mesh. This coverage will not be added as a snap preview.'
                self._logger.warning(msg)
        return None, None  # pragma: no cover

    def _create_drawing(self):
        """Writes the snapped monitor line locations to 'monitor.display_ids' in the component folder."""
        if self._arc_to_grid_points:
            filename = os.path.join(self._comp_path, 'monitor.display_ids')
            write_display_option_line_locations(filename, self._arc_to_grid_points)

    def _get_coverage_points(self):
        """Gathers the monitor points from the coverage and checks that each lies inside the mesh.

        Each disjoint coverage point is appended to ``self._monitor_points`` as
        ``(pt_id, x, y, is_observation, wse, wse_weight)`` with a 1-based ``pt_id``. Points that have no
        attributes in the monitor component get ``(False, 0.0, 0.0)`` for the last three fields. Points whose
        interpolated mesh elevation is NaN are logged as errors (outside the mesh).
        """
        points = []
        if self._monitor_coverage:
            points = self._monitor_coverage.get_points(FilterLocation.PT_LOC_DISJOINT)
        if len(points) < 1:  # pragma: no cover
            return
        locs = []
        for pt_id, point in enumerate(points, start=1):
            comp_id = -1
            if self._monitor_component:
                comp_id = self._monitor_component.get_comp_id(TargetType.point, point.id)
                comp_id = -1 if comp_id is None else comp_id
            locs.append((point.x, point.y, 0.0))
            obs = False
            wse_val = 0.0
            wse_weight = 0.0
            if comp_id >= 0:
                monitor_data = self._monitor_component.data.monitor_point_param_from_id(comp_id)
                obs = monitor_data.observation_point
                wse_val = monitor_data.water_surface_elevation
                # The 1.96 divisor presumably converts a 95% confidence half-width into a standard deviation;
                # the weight is the inverse variance.
                # NOTE(review): raises ZeroDivisionError if the interval is 0 — confirm the UI prevents this.
                standard_dev = monitor_data.water_surface_elevation_interval / 1.96
                wse_weight = 1 / (standard_dev * standard_dev)
            self._monitor_points.append((pt_id, point.x, point.y, obs, wse_val, wse_weight))

        # Interpolate the mesh elevations at every monitor point; a NaN result means the point is off the mesh.
        ug = self._co_grid.ugrid
        ug_elevations = [pt[2] for pt in ug.locations]
        ug_activity = [1] * len(ug_elevations)
        extractor = UGrid2dDataExtractor(ug)
        extractor.set_grid_point_scalars(ug_elevations, ug_activity, 'points')
        extractor.extract_locations = locs
        out_elev = extractor.extract_data()
        for idx, val in enumerate(out_elev):
            if math.isnan(val):
                msg = f'Monitor point id: {idx+1} at location {locs[idx]} is outside the mesh.'
                self._logger.error(msg)
                msg = 'Fix the monitor point location or SRH will not run.'
                self._logger.error(msg)

    def _get_grid_points_from_arcs(self):
        """Uses xmssnap to snap the monitor arcs and 3D structure centerlines to grid points.

        Accepted lines have their flattened snapped coordinates appended to ``self._arc_to_grid_points`` and
        their grid point ids stored in ``self.arc_id_to_grid_ids`` under a 1-based running counter. For 3D
        structures, ``self.out_3d_structures`` also records
        ``(-structure arc id, monitor location label, counter, coverage uuid)``.
        """
        arcs = []
        if self._monitor_coverage:
            self._logger.info('Monitor coverage to mesh.')
            arcs = self._monitor_coverage.arcs
        arc_cnt = 0
        for arc in arcs:
            arc_cnt += 1
            snap_output = _snap_arc_to_grid(self._snap_arc_interior, arc)
            if 'location' not in snap_output:
                continue  # pragma: no cover
            # The ids are recorded before validation, so rejected coverage arcs still get an entry.
            pt_ids = self.arc_id_to_grid_ids[arc_cnt] = snap_output['id']
            points = [item for sublist in snap_output['location'] for item in sublist]
            if not _num_snap_points_is_valid(len(pt_ids), arc.id, self._monitor_coverage.name, self._logger):
                continue
            self._arc_to_grid_points.append(points)

        mon_loc = {'up_arc': 'MONITOR_UPSTREAM', 'down_arc': 'MONITOR_DOWNSTREAM', 'mid_arc': 'MONITOR_MIDDLE'}
        for struct in self.structures_3d_monitor:
            cov_name = struct['cov'].name
            # Decided that we only wanted the middle monitor; add 'up_arc'/'down_arc' here to monitor them too.
            for arc_loc in ('mid_arc',):
                # Snap a reversed copy so the caller's point list is not modified in place.
                arc_pts = list(reversed(struct[arc_loc]))
                snap_output = self._snap_arc_interior.get_snapped_points(arc_pts)
                pt_ids = snap_output['id']
                points = [item for sublist in snap_output['location'] for item in sublist]
                if not _num_structure_snap_points_is_valid(len(pt_ids), arc_loc, cov_name, self._logger):
                    continue
                arc_cnt += 1
                self.arc_id_to_grid_ids[arc_cnt] = pt_ids
                self._arc_to_grid_points.append(points)
                self.out_3d_structures.append((-struct['arc'].id, mon_loc[arc_loc], arc_cnt, struct['cov'].uuid))

    def _create_component_folder_and_copy_display_options(self):
        """Creates the mapped monitor component folder and writes its display options.

        The folder ``<comp_path>/<component uuid>`` replaces any existing folder, and ``self._comp_path`` is
        updated to point at it. The default monitor display options JSON is then copied in. The copy is stamped
        with the component/display UUIDs and the grid projection, and every category is drawn as a width-4
        dashed line without labels. ``self._comp_main_file`` is set to the written JSON. If the default file is
        missing, only an info message is logged.
        """
        if self.mapped_comp_uuid is None:
            comp_uuid = str(uuid.uuid4())  # pragma: no cover
        else:
            comp_uuid = self.mapped_comp_uuid
        self._logger.info('Creating component folder')
        self._comp_path = os.path.join(self._comp_path, comp_uuid)

        if os.path.exists(self._comp_path):
            shutil.rmtree(self._comp_path)  # pragma: no cover
        os.mkdir(self._comp_path)

        src_display_file = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'gui', 'resources', 'default_data',
            'default_monitor_display_options.json'
        )
        comp_display_file = os.path.join(self._comp_path, 'monitor_display_options.json')
        if os.path.isfile(src_display_file):
            shutil.copyfile(src_display_file, comp_display_file)
            categories = CategoryDisplayOptionList()  # Generates a random UUID key for the display list
            json_dict = read_display_options_from_json(comp_display_file)
            if self.mapped_comp_display_uuid is None:
                json_dict['uuid'] = str(uuid.uuid4())  # pragma: no cover
            else:
                json_dict['uuid'] = self.mapped_comp_display_uuid
            json_dict['comp_uuid'] = comp_uuid
            json_dict['is_ids'] = 0
            categories.from_dict(json_dict)

            # Set all snapped arcs to be dashed and thick by default. Keep current color.
            for category in categories.categories:
                category.options.style = LineStyle.DASHEDLINE
                category.options.width = 4
                category.label_on = False

            categories.projection = {'wkt': self.grid_wkt}
            write_display_options_to_json(comp_display_file, categories)
            self._comp_main_file = comp_display_file
        else:
            self._logger.info('Could not find monitor_display_options.json file')  # pragma: no cover


def _snap_arc_to_grid(snap_arc_interior, arc):
    """Snaps an arc to the grid."""
    return snap_arc_interior.get_snapped_points(arc)


def _num_snap_points_is_valid(num_nodes, arc_id, cov_name, logger):
    """Check the number of grid points in the monitor snap and log errors if needed.

    Args:
        num_nodes (int): The number of nodes in the monitor snap.
        arc_id (int): The arc id.
        cov_name (str): The coverage name.
        logger (logging.Logger): The logger.

    Returns:
        (bool): True if the number of nodes is valid.
    """
    if num_nodes < 1:
        msg = f'Arc id: {arc_id} in coverage: "{cov_name}" does not snap ' \
              f'to the mesh. Edit this monitor line so that it snaps to the mesh or remove it from ' \
              f'the coverage.'
        logger.error(msg)
        return False
    elif num_nodes < 3:
        msg = f'Arc id: {arc_id} in coverage: "{cov_name}" does not snap ' \
              f'to at least 3 mesh points. Edit this monitor line so that it snaps to at least 3 points.'
        logger.error(msg)
        return False
    elif num_nodes > 1000:
        msg = f'Arc id: {arc_id} in coverage: "{cov_name}" snaps to more ' \
              f'than 1000 points on the mesh. SRH-2D limits the number of nodes in a monitor ' \
              f'nodestring to 1000. Edit this monitor line so that it snaps to fewer than 1000 ' \
              f'points.'
        logger.error(msg)
        return False
    return True


def _num_structure_snap_points_is_valid(num_nodes, arc_loc, cov_name, logger):
    """Check the number of grid points in the monitor snap and log errors if needed.

    Args:
        num_nodes (int): The number of nodes in the monitor snap.
        arc_loc (str): Either 'up_arc' or 'down_arc'.
        cov_name (str): The coverage name.
        logger (logging.Logger): The logger.

    Returns:
        (bool): True if the number of nodes is valid.
    """
    lookup = {'up_arc': 'Upstream', 'down_arc': 'Downstream', 'mid_arc': 'Centerline'}
    loc = lookup[arc_loc]
    if num_nodes < 3:
        msg = f'{loc} arc in 3D Structure coverage: "{cov_name}" does not snap ' \
              f'to at least 3 mesh points. Skipping this monitor line.'
        logger.warning(msg)
        return False
    elif num_nodes > 1000:
        msg = f'{loc} arc in 3D Structure coverage: "{cov_name}" snaps to more ' \
              f'than 1000 points on the mesh. SRH-2D limits the number of nodes in a monitor ' \
              f'nodestring to 1000. Skipping this monitor line.'
        logger.warning(msg)
        return False
    return True
