"""Generates a folium HTML preview of a structure's coverage arcs and mesh."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math

# 2. Third party modules
import folium
import folium.plugins

# 3. Aquaveo modules
from xms.data_objects.parameters import FilterLocation
from xms.gdal.utilities.gdal_utils import get_coordinate_transformation, wkt_from_epsg

# 4. Local modules


def _factor_from_latitude(latitude):
    """Return the approximate number of meters per degree of longitude at ``latitude``.

    Starts from 111.32 km per degree at the equator and scales it by the cosine of the latitude.

    Args:
        latitude (:obj:`float`): the latitude, in decimal degrees

    Returns:
        (:obj:`float`): meters per decimal degree at that latitude
    """
    meters_per_degree_at_equator = 111.32 * 1000
    return meters_per_degree_at_equator * math.cos(math.radians(latitude))


def meters_to_decimal_degrees(length_meters, latitude):
    """Express a length in meters as an approximate length in decimal degrees.

    The length is divided by the meters-per-degree factor that ``_factor_from_latitude`` returns for ``latitude``.

    Args:
        length_meters (:obj:`float`): length in meters
        latitude (:obj:`float`): latitude, in decimal degrees, at which the conversion is made

    Returns:
        (:obj:`float`): length in decimal degrees
    """
    meters_per_degree = _factor_from_latitude(latitude)
    return length_meters / meters_per_degree


class PreviewHtml:
    """Class for generating the HTML preview of the 3d structure lines and mesh.

    All drawing happens in ``__init__``; the finished map is saved to ``dlg_data['url']``.
    """

    # Polyline color for each arc type; any type not listed here is drawn black.
    _TYPE_TO_COLOR = {
        'Bridge': 'blue',
        'Embankment': 'blue',
        'Pier': 'red',
        'Culvert': 'green',
        'Abutment': 'orange',
        'Unassigned': 'black',
    }
    # For culvert structures, arcs typed 'Bridge' are shown as 'Embankment' and 'Pier' arcs as 'Culvert'.
    _CULVERT_TYPE_REMAP = {'Bridge': 'Embankment', 'Pier': 'Culvert'}

    def __init__(self, dlg_data):
        """Initializes the class, draws the coverage arcs and mesh, and saves the HTML file.

        Args:
            dlg_data (:obj:`dict`): Data for the dialog. Requires 'projection' and 'url' (output HTML path);
                optionally reads 'arc_data', 'struct_type', 'tool_arc_data', 'arc_df', and 'ugrid'.
        """
        self._data = dlg_data
        self._polyline_data = None
        # Mesh edges as pairs of (lat, lon) points; filled by _add_ugrid_edges_to_plot when a ugrid is given.
        self.grid_edges = []
        # Extents of all geometric data - [[min_lat, min_lon], [max_lat, max_lon]]; empty until points are added.
        self._global_extents = []
        # Projection info that we have to pass in because it does not reliably serialize to the Coverage H5 dump file.
        projection = self._data['projection']
        self._is_geographic = projection.coordinate_system == 'GEOGRAPHIC'
        self._is_local = projection.coordinate_system == 'None'
        self._transform = None
        if not self._is_geographic and not self._is_local:
            # Set up a transformation from input non-geographic to WGS 84
            _, tgt_wkt = wkt_from_epsg(4326)
            self._transform = get_coordinate_transformation(projection.well_known_text, tgt_wkt)
        self._folium_map = folium.Map(tiles='')
        self._draw_coverage()
        self._draw_mesh()

    def _get_arc_types(self):
        """Returns a dict mapping arc id to arc type.

        Uses the 'Arc ID' / 'Type' columns of ``arc_df`` when it is present, collapsing any type containing
        'Pier group' or 'Wall pier' to 'Pier'. Otherwise returns ``tool_arc_data``, or an empty dict when that is
        missing or None (so membership tests on the result never fail).

        Returns:
            (:obj:`dict`): arc id -> arc type string
        """
        arc_df = self._data.get('arc_df', None)
        if arc_df is None:
            return self._data.get('tool_arc_data', None) or {}
        arc_types = {}
        for record in arc_df.to_dict('records'):
            arc_type = record['Type']
            if 'Pier group' in arc_type or 'Wall pier' in arc_type:
                arc_type = 'Pier'
            arc_types[record['Arc ID']] = arc_type
        return arc_types

    def _draw_coverage(self):
        """Draws the coverage arcs on the folium map, colored and labeled by arc type.

        Does nothing when ``dlg_data['arc_data']`` has no 'coverage'. Otherwise also records each arc's color,
        transformed points, and label in ``self._polyline_data`` (for testing).
        """
        arc_data = self._data.get('arc_data', {})
        cov = arc_data.get('coverage', None)
        if cov is None:
            return
        is_culvert = self._data.get('struct_type', 'Bridge') == 'Culvert'
        fg = folium.map.FeatureGroup('Arcs').add_to(self._folium_map)
        arc_types = self._get_arc_types()
        attr = {'font-weight': 'bold', 'font-size': '12px'}
        self._polyline_data = {'color': [], 'pts': [], 'label': []}
        for arc in cov.arcs:
            # Arcs without an assigned type get an empty type (black line, label ends with "- ").
            arc_type = arc_types[arc.id] if arc.id in arc_types else ''
            if is_culvert:
                arc_type = self._CULVERT_TYPE_REMAP.get(arc_type, arc_type)
            color = self._TYPE_TO_COLOR.get(arc_type, 'black')
            arc_points = arc.get_points(FilterLocation.PT_LOC_ALL)
            pts = self._transform_points([(point.x, point.y) for point in arc_points])
            pl = folium.vector_layers.PolyLine(locations=pts, weight=3, color=color).add_to(fg)
            label = f'ID: {arc.id} - {arc_type}'
            folium.plugins.PolyLineTextPath(polyline=pl, text=label, offset=10, attributes=attr).add_to(fg)
            # for testing
            self._polyline_data['color'].append(color)
            self._polyline_data['pts'].append(pts)
            self._polyline_data['label'].append(label)

    def _draw_mesh(self):
        """Draws the mesh (if any), fits the view to all drawn geometry, adds a layer control, and saves the HTML."""
        if self._data.get('ugrid', None) is not None:
            self._add_ugrid_edges_to_plot()
        # fit_bounds needs a real [[south, west], [north, east]] box; when nothing was drawn keep folium's default view.
        if self._global_extents:
            self._folium_map.fit_bounds(self._global_extents)
        folium.map.LayerControl().add_to(self._folium_map)
        self._folium_map.save(self._data['url'])

    def _add_ugrid_edges_to_plot(self):
        """Add each unique edge of the ugrid to the folium map as one multi-segment black polyline."""
        ugrid = self._data['ugrid'].ugrid
        grid_points = self._transform_points(ugrid.locations)
        # Store every edge as (smaller index, larger index) so edges shared by neighboring cells appear only once.
        edge_set = set()
        for cell_idx in range(ugrid.cell_count):
            for edge in ugrid.get_cell_edges(cell_idx):
                first, second = edge[0], edge[1]
                edge_set.add((first, second) if first <= second else (second, first))

        fg = folium.map.FeatureGroup('Bridge mesh').add_to(self._folium_map)
        self.grid_edges = [
            [(grid_points[i][0], grid_points[i][1]), (grid_points[j][0], grid_points[j][1])] for i, j in edge_set
        ]
        # A grid without cells has no edges; folium rejects a PolyLine with no locations.
        if self.grid_edges:
            folium.vector_layers.PolyLine(locations=self.grid_edges, weight=1, color='black').add_to(fg)

    def _transform_points(self, points):
        """Converts points to the [lat, lon] order folium expects and grows the global extents to include them.

        - Projected input: transformed to WGS 84 with the GDAL transformation built in ``__init__``.
        - Geographic input: assumed to be (x=lon, y=lat) in decimal degrees, so it is only reordered.
        - Local input: x and y are treated as meters and scaled to degrees at the equator, which keeps the
          relative shape of the geometry.

        See https://gis.stackexchange.com/questions/226200/how-to-call-gdaltransform-with-its-input

        Args:
            points (:obj:`list`): Points indexable as point[0] (x) and point[1] (y); extra values are ignored.

        Returns:
            new_points (:obj:`list`): The transformed points as [lat, lon] lists, in input order.
        """
        new_points = []
        for point in points:
            if self._transform:
                coords = self._transform.TransformPoint(point[0], point[1])
                new_points.append([coords[1], coords[0]])
            elif self._is_geographic:
                new_points.append([point[1], point[0]])
            else:
                lat = meters_to_decimal_degrees(point[1], 0.0)
                lon = meters_to_decimal_degrees(point[0], 0.0)
                new_points.append([lat, lon])
        self._update_extents(new_points)
        return new_points

    def _update_extents(self, new_points):
        """Expands ``self._global_extents`` ([[min_lat, min_lon], [max_lat, max_lon]]) to cover ``new_points``.

        Args:
            new_points (:obj:`list`): [lat, lon] points; an empty list leaves the extents unchanged.
        """
        if not new_points:
            # min()/max() cannot be taken over an empty sequence.
            return
        lats = [p[0] for p in new_points]
        lons = [p[1] for p in new_points]
        if self._global_extents:
            (min_lat, min_lon), (max_lat, max_lon) = self._global_extents
            lats += [min_lat, max_lat]
            lons += [min_lon, max_lon]
        self._global_extents = [[min(lats), min(lons)], [max(lats), max(lons)]]
