"""Map Boundary Conditions coverage locations and attributes to the SRH-2D domain."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
import shutil
import uuid

# 2. Third party modules
from shapely.geometry import LineString

# 3. Aquaveo modules
from xms.components.display.display_options_io import read_display_options_from_json, write_display_options_to_json
from xms.components.display.display_options_io import write_display_option_line_locations
from xms.data_objects.parameters import Component
from xms.guipy.data.category_display_option_list import CategoryDisplayOptionList
from xms.guipy.data.line_style import LineStyle
from xms.guipy.data.target_type import TargetType
from xms.snap.snap_exterior_arc import SnapExteriorArc
from xms.snap.snap_interior_arc import SnapInteriorArc

# 4. Local modules
from xms.srh.components.mapped_bc_component import MappedBcComponent
from xms.srh.data.bc_data import BcData
from xms.srh.data.par.bc_data_param import BcDataParam


class BcMapper:
    """Class for mapping bc coverage to a mesh for SRH-2D."""

    # Largest number of snapped nodes accepted for one boundary condition nodestring. Arcs snapping to more
    # nodes are rejected by _snap_output_has_errors.
    # NOTE(review): the original error message quoted 500 while the check enforced 1000; the enforced value is
    # kept here and the message now reports it. Confirm against the SRH-2D nodestring limit.
    _MAX_NODESTRING_NODES = 1000

    def __init__(self, coverage_mapper, wkt, generate_snap):
        """Constructor.

        Args:
            coverage_mapper: Owning coverage mapper. Its ``_logger``, ``co_grid``, ``bc_component_file``,
                ``bc_coverage`` and ``bc_component`` attributes are read here.
            wkt (:obj:`str`): Projection well-known text written into the snapped display options.
            generate_snap (:obj:`bool`): If True, ``do_map`` builds the snapped display component.
        """
        self._generate_snap = generate_snap
        self._logger = coverage_mapper._logger
        self._co_grid = coverage_mapper.co_grid
        self._bc_component_file = coverage_mapper.bc_component_file
        self._coverage_xml_str = 'bc_coverage'
        self._coverage_comp_xml_str = 'SRH-2D#Bc_Component'
        self._new_comp_unique_name = 'Mapped_Bc_Component'
        self._bc_coverage = coverage_mapper.bc_coverage
        self._bc_component = coverage_mapper.bc_component
        # Exterior snapper for ordinary boundary arcs, interior snapper for structures / internal arcs.
        self._snap_arc = SnapExteriorArc()
        self._snap_arc.set_grid(grid=self._co_grid, target_cells=False)
        self._snap_arc_interior = SnapInteriorArc()
        self._snap_arc_interior.set_grid(grid=self._co_grid, target_cells=False)
        self._comp_main_file = ''
        # display name -> list of flattened (x, y, z) coordinate lists, one per snapped arc
        self._arc_to_grid_points = {}
        # The following are keyed by the 1-based arc index (position in the coverage's arc list)
        self.arc_id_to_grid_ids = {}
        self.arc_id_to_node_string_length = {}
        # Keyed by the coverage arc id
        self.arc_id_to_grid_pts = {}
        self._arc_id_to_comp_id = {}
        self._arc_id_to_bc_id = {}
        self._arc_id_to_bc_param = {}
        # bc id -> {'up': arc index or -1, 'down': arc index or -1}
        self._structures = {}
        self._comp_path = ''
        self._grid_wkt = wkt
        self.bc_mapped_comp_uuid = None
        self.bc_mapped_comp_display_uuid = None
        self.ceiling_file = None
        # Optional list of 3D structure dicts to be snapped as weirs; None is treated as "no structures".
        self.structures_3d_weirs = None
        self.out_3d_structures = []

    def do_map(self):
        """Creates the mapped bc component.

        Returns:
            (:obj:`tuple`): ``(data_objects Component, MappedBcComponent)`` when snapping display output was
            requested, otherwise ``(None, None)``.
        """
        self._get_grid_points_from_arcs()
        self._map_3d_structures()
        if self._generate_snap:
            self._create_component_folder_and_copy_display_options()
            self._create_drawing()
            # Create the data_objects component
            comp_name = f'Snapped {self._bc_coverage.name} display'
            comp_uuid = os.path.basename(os.path.dirname(self._comp_main_file))
            do_comp = Component(
                main_file=self._comp_main_file,
                name=comp_name,
                model_name='SRH-2D',
                unique_name=self._new_comp_unique_name,
                comp_uuid=comp_uuid
            )
            comp = MappedBcComponent(self._comp_main_file)
            return do_comp, comp
        return None, None  # pragma: no cover

    def _create_drawing(self):
        """Write one display-ids line-location file per display category of snapped arcs."""
        for disp_name, arcs_list in self._arc_to_grid_points.items():
            filename = os.path.join(self._comp_path, f'display_ids/{disp_name}.display_ids')
            write_display_option_line_locations(filename, arcs_list)

    def _snap_interior(self, arc, arc_id, cov_name):
        """Snap an arc with the interior snapper.

        Args:
            arc: Arc (or arc points) accepted by the interior snapper.
            arc_id (:obj:`int`): Arc id used in log messages.
            cov_name (:obj:`str`): Coverage name used in log messages.

        Returns:
            (:obj:`tuple`): ``(flattened xyz coordinates, snapped node ids)``, or None if the snap had errors.
        """
        snap_output = self._snap_arc_interior.get_snapped_points(arc)
        if self._snap_output_has_errors(snap_output, arc_id, cov_name):
            return None
        points = [coord for location in snap_output['location'] for coord in location]
        return points, snap_output['id']

    def _map_3d_structures(self):
        """Maps 3D Structures as weirs.

        Each entry of ``self.structures_3d_weirs`` is read for the keys 'arc', 'cov', 'up_arc', 'down_arc',
        'bc_data' and the optional 'monitor_up', 'monitor_down', 'monitor_middle'. Structures whose upstream or
        downstream arc fails to snap are skipped (the snap check logs the reason).
        """
        for struct in self.structures_3d_weirs or []:
            arc = struct['arc']
            cov_name = struct['cov'].name

            up_snap = self._snap_interior(struct['up_arc'], arc.id, cov_name)
            if up_snap is None:
                continue
            down_snap = self._snap_interior(struct['down_arc'], arc.id, cov_name)
            if down_snap is None:
                continue
            up_points, up_ids = up_snap
            down_points, down_ids = down_snap

            self.out_3d_structures.append({
                'cov_uuid': struct['cov'].uuid,
                'bc_data': struct['bc_data'],
                'arc_id': arc.id,
                'up_snap_ids': up_ids,
                'down_snap_ids': down_ids,
                'monitor_up': struct.get('monitor_up', -1),
                'monitor_down': struct.get('monitor_down', -1),
                'monitor_middle': struct.get('monitor_middle', -1),
            })

            # Entry 13 of the display list holds the (upstream, downstream) display names used for weirs here.
            up_disp, down_disp = BcData.display_list[13]
            self._arc_to_grid_points.setdefault(up_disp, []).append(up_points)
            self._arc_to_grid_points.setdefault(down_disp, []).append(down_points)

    def _snap_output_has_errors(self, snap_output, arc_id, cov_name):
        """Checks for errors in the snap output.

        Logs a warning when nothing snapped or when exactly two nodes snapped, and an error when a single node
        snapped or more than ``_MAX_NODESTRING_NODES`` nodes snapped.

        Args:
            snap_output (:obj:`dict`): output from snapper
            arc_id (:obj:`int`): arc id
            cov_name (:obj:`str`): coverage name

        Returns:
            (:obj:`bool`): True if there are snap errors (the arc should be skipped)
        """
        num_nodes = len(snap_output['location']) if 'location' in snap_output else 0
        if num_nodes < 1:
            self._logger.warning(f'Unable to snap arc id: {arc_id} (Coverage: {cov_name}) to mesh.')
            return True

        hint = ' A boundary condition should snap to at least 3 nodes. Adjust the location of the arc.'
        if num_nodes == 1:
            self._logger.error(f'Arc id: {arc_id} (Coverage: {cov_name}) snaps to a single node.' + hint)
            return True
        if num_nodes == 2:
            # Two nodes is usable but discouraged, so only warn.
            self._logger.warning(
                f'Arc id: {arc_id} (Coverage: {cov_name}) snaps to less than 3 nodes on the mesh.' + hint
            )
            return False

        max_nodes = self._MAX_NODESTRING_NODES
        if num_nodes > max_nodes:
            msg = f'Arc id {arc_id} in Coverage: "{cov_name}" snaps to more than {max_nodes} points on the ' \
                  f'mesh. SRH-2D limits the number of nodes in a boundary condition nodestring to ' \
                  f'{max_nodes}. Edit this boundary condition arc so that it snaps to no more than ' \
                  f'{max_nodes} nodes.'
            self._logger.error(msg)
            return True

        return False

    def _get_grid_points_from_arcs(self):
        """Uses xmssnap to get the points from arcs.

        Arcs are identified by their 1-based position in the coverage's arc list ("arc index") in most of the
        output dictionaries; ``self.arc_id_to_grid_pts`` is keyed by the coverage arc id instead. Arcs without a
        valid component id are treated as walls. Structure (upstream/downstream) arcs and 'internal_sink' /
        'bc_data' arcs use the interior snapper; all other arcs use the exterior snapper.
        """
        self._logger.info('Bc coverage to mesh.')
        structure_arcs = {}  # bc id -> arc indices assigned to that structure
        arc_index_to_grid_pts = {}
        arc_index_to_arc_id = {}
        cov_name = self._bc_coverage.name
        arcs = self._bc_coverage.arcs
        bc_data = self._bc_component.data
        df = bc_data.comp_ids.to_dataframe()
        # comp id -> (bc id, display name), from the first three columns of the comp id table
        comp_id_to_bc_id_name = {row[0]: (row[1], row[2]) for row in df.to_numpy()}

        for arc_index, arc in enumerate(arcs, start=1):
            arc_index_to_arc_id[arc_index] = arc.id
            comp_id = self._bc_component.get_comp_id(TargetType.arc, arc.id)
            bc_id = -1
            if comp_id is None or comp_id < 0:
                comp_id = 0
                display_name = 'wall'
                bc_par = BcDataParam()
            else:
                bc_id, display_name = comp_id_to_bc_id_name[comp_id]
                self._arc_id_to_bc_id[arc_index] = bc_id
                bc_par = bc_data.bc_data_param_from_id(bc_id)

            if 'upstream' in display_name or 'downstream' in display_name:
                struct = self._structures.setdefault(bc_id, {'up': -1, 'down': -1})
                if 'upstream' in display_name:
                    struct['up'] = arc_index
                else:
                    struct['down'] = arc_index
                snap_output = self._snap_arc_interior.get_snapped_points(arc)
                structure_arcs.setdefault(bc_id, []).append(arc_index)
            elif display_name in ['internal_sink', 'bc_data']:
                snap_output = self._snap_arc_interior.get_snapped_points(arc)
            else:
                snap_output = self._snap_arc.get_snapped_points(arc)

            if self._snap_output_has_errors(snap_output, arc.id, cov_name):
                continue

            if self.ceiling_file and bc_par.bc_type == 'Pressure':
                msg = f'Error: pressure structure and ceiling file in same simulation.\n' \
                      f'Simulation has a ceiling file and the bc coverage has a pressure structure on ' \
                      f'arc id: {arc.id}. The pressure structures will be ignored.'
                self._logger.error(msg)
            self._arc_id_to_bc_param[arc_index] = bc_par
            self.arc_id_to_grid_ids[arc_index] = snap_output['id']
            self._arc_id_to_comp_id[arc_index] = comp_id
            points = [coord for location in snap_output['location'] for coord in location]
            arc_index_to_grid_pts[arc_index] = points
            self.arc_id_to_grid_pts[arc.id] = points
            ls = LineString([(p[0], p[1]) for p in snap_output['location']])
            self.arc_id_to_node_string_length[arc_index] = ls.length
            self._arc_to_grid_points.setdefault(display_name, []).append(points)

        self._check_structures(structure_arcs, arc_index_to_grid_pts, arc_index_to_arc_id)

    def _check_structures(self, structure_arcs, arc_index_to_grid_pts, arc_index_to_arc_id):
        """Log problems with structure boundary conditions found while snapping arcs.

        Args:
            structure_arcs (:obj:`dict`): bc id -> list of arc indices assigned to that structure
            arc_index_to_grid_pts (:obj:`dict`): arc index -> flattened snapped (x, y, z) coordinates
            arc_index_to_arc_id (:obj:`dict`): arc index -> coverage arc id
        """
        # check if any structures are missing arcs
        for struct in self._structures.values():
            if struct['up'] == -1 or struct['down'] == -1:
                arc_index = struct['up'] if struct['up'] != -1 else struct['down']
                msg = f'The structure associated with arc id: {arc_index_to_arc_id[arc_index]} is missing an arc. ' \
                      f'This structure must be corrected or it will not be exported to the SRH2D input files.'
                self._logger.warning(msg)
        # check if any structures are associated with too many arcs
        for arc_list in structure_arcs.values():
            if len(arc_list) > 2:
                arc_ids = [arc_index_to_arc_id[index] for index in arc_list]
                msg = f'A single structure is associated with more than 2 arcs. ' \
                      f'The boundary condition data must be fixed for the following arc ids: {arc_ids}.'
                self._logger.error(msg)
        # check if arc directions may be wrong with any structures
        for struct in self._structures.values():
            arc_1 = struct['up']
            arc_2 = struct['down']
            if -1 in (arc_1, arc_2):
                continue
            pts_1 = arc_index_to_grid_pts.get(arc_1)
            pts_2 = arc_index_to_grid_pts.get(arc_2)
            if pts_1 and pts_2 and self._arcs_may_be_reversed(pts_1, pts_2):
                msg = f'Check the structure associated with arc ids: {arc_index_to_arc_id[arc_1]}, ' \
                      f'{arc_index_to_arc_id[arc_2]}. Make sure that the arc directions are consistent.'
                self._logger.warning(msg)

    @staticmethod
    def _arcs_may_be_reversed(pts_up, pts_down):
        """Return True if the downstream arc lines up better with the upstream arc when its order is reversed.

        Args:
            pts_up (:obj:`list`): flattened (x, y, z) coordinates of the upstream arc
            pts_down (:obj:`list`): flattened (x, y, z) coordinates of the downstream arc

        Returns:
            (:obj:`bool`): True if the summed squared xy distance between paired nodes is smaller with the
            downstream nodes reversed (pairs beyond the shorter arc are ignored).
        """
        xy_up = [(pts_up[i], pts_up[i + 1]) for i in range(0, len(pts_up), 3)]
        xy_down = [(pts_down[i], pts_down[i + 1]) for i in range(0, len(pts_down), 3)]

        def _sum_dist_sq(first, second):
            total = 0
            for (x1, y1), (x2, y2) in zip(first, second):
                dx = x1 - x2
                dy = y1 - y2
                total += dx * dx + dy * dy
            return total

        return _sum_dist_sq(xy_up, xy_down[::-1]) < _sum_dist_sq(xy_up, xy_down)

    def _create_component_folder_and_copy_display_options(self):
        """Creates the folder for the mapped bc component and copies the display options from the bc coverage.

        The new folder is a sibling of the bc component's folder, named by ``bc_mapped_comp_uuid`` (or a new
        random UUID). Any existing folder with that name is removed first. Copied display categories are
        switched to thick dashed lines with labels off; if the source display file is missing, only an info
        message is logged and ``_comp_main_file`` is left unchanged.
        """
        if self.bc_mapped_comp_uuid is None:
            comp_uuid = str(uuid.uuid4())  # pragma: no cover
        else:
            comp_uuid = self.bc_mapped_comp_uuid
        self._logger.info('Creating component folder')
        bc_comp_path = os.path.dirname(self._bc_component_file)
        self._comp_path = os.path.join(os.path.dirname(bc_comp_path), comp_uuid)

        if os.path.exists(self._comp_path):
            shutil.rmtree(self._comp_path)  # pragma: no cover
        os.mkdir(self._comp_path)
        os.mkdir(os.path.join(self._comp_path, 'display_ids'))

        bc_comp_display_file = os.path.join(bc_comp_path, 'bc_display_options.json')
        comp_display_file = os.path.join(self._comp_path, 'bc_display_options.json')
        if os.path.isfile(bc_comp_display_file):
            shutil.copyfile(bc_comp_display_file, comp_display_file)
            categories = CategoryDisplayOptionList()  # Generates a random UUID key for the display list
            json_dict = read_display_options_from_json(comp_display_file)
            if self.bc_mapped_comp_display_uuid is None:
                json_dict['uuid'] = str(uuid.uuid4())  # pragma: no cover
            else:
                json_dict['uuid'] = self.bc_mapped_comp_display_uuid
            json_dict['comp_uuid'] = comp_uuid
            json_dict['is_ids'] = 0
            categories.from_dict(json_dict)
            categories.projection = {'wkt': self._grid_wkt}

            # Set all snapped arcs to be dashed and thick by default. Keep current color.
            for category in categories.categories:
                category.options.style = LineStyle.DASHEDLINE
                category.options.width = 4
                category.label_on = False

            write_display_options_to_json(comp_display_file, categories)
            self._comp_main_file = comp_display_file
        else:
            self._logger.info('Could not find bc_display_options.json file')  # pragma: no cover
