"""CleanShapesTool class."""
__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"

# 1. Standard Python modules
from enum import Enum
import math
import os
import pickle

# 2. Third party modules
from geopandas import GeoDataFrame
from osgeo import osr
import pandas as pd
from rtree import index
from shapely.geometry import LineString
from shapely.geometry import Polygon as shPoly

# 3. Aquaveo modules
from xms.api.dmi import Query, XmsEnvironment as XmEnv
from xms.gdal.utilities import gdal_utils as gu
from xms.tool.utilities.coverage_conversion import get_polygon_point_lists
from xms.tool_core import IoDirection, Tool

# 4. Local modules
import xms.tool_xms.algorithms.geom_funcs_with_shapely as gsh

# Positions of the tool arguments in the list built by initial_arguments()
ARG_INPUT_COVERAGE, ARG_INPUT_GEOM_TYPE, ARG_OUTPUT_COVERAGE = range(3)

# Column names for the results spreadsheets. Length / segment / connection columns belong to the arc table,
# arc-count / area / perimeter columns to the polygon table, and the rest are shared by both.
STR_ID = 'ID'
STR_DISP_TYPE = 'Display'
STR_DIST_TO_P = 'Distance to Primary'
STR_GEOM_IDX = 'Geometry List Index'
STR_LEN = 'Length'
STR_NUM_SEGS = '# Segments'
STR_NUM_CX = '# Connected Arcs'
STR_NUM_ARCS = '# Arcs'
STR_AREA = 'Area'
STR_PERIM = 'Perimeter'


class DisplayEnum(Enum):
    """Display options for a shape in the results.

    NOTE(review): not referenced in this file; the member names appear to parallel the 'Normal' / 'Simplified'
    display types written to the spreadsheet, plus an option to exclude the shape -- confirm against the dialog.
    """

    NORMAL = 0
    SIMPLIFY = 1
    EXCLUDE = 2


class CleanShapesTool(Tool):
    """Tool to clean geometry.

    Loads the arcs (and, in 'Polygons' mode, the polygons) of the input coverage, builds an rtree over the
    geometry being cleaned, and pickles the spreadsheet and arc/polygon DataFrames for the results dialog.
    """

    def __init__(self, name='Clean Shapes'):
        """Initializes the class.

        Args:
            name (str): Tool name. 'Merge Polygons by Distance' makes the tool always work on polygons.
        """
        super().__init__(name=name)
        self._is_mpbd_tool = name == 'Merge Polygons by Distance'

        # we want to set these even if we don't display them
        if not self._is_mpbd_tool:
            self.results_dialog_module = 'xms.tool_xms.gui.clean_shapes_dialog'
            self.results_dialog_class = 'CleanShapesDialog'
        # else:
        #     self.results_dialog_module = 'xms.tool_xms.gui.merge_polygons_by_distance_dialog'
        #     self.results_dialog_class = 'MergePolysByDistanceInvisibleDialog'

        # values read from the tool arguments and the display projection in process_data()
        self._arguments = None
        self._geom_type = 'Polygons' if self._is_mpbd_tool else 'Arcs'
        self._output_coverage_name = ''
        self._coord_sys = ''
        self._horiz_units = ''

        self._rtree = None
        self._geometry = []  # shapely geometry: all arcs first, then the valid polygons
        self._geometry_locs = []  # point lists for each entry of self._geometry
        self._geometry_locs_latlon = []
        self._geometry_for_props = []  # geometry measured for the spreadsheet (may be a UTM copy if geographic)
        self._max_poly_id = 0
        self._max_arc_id = 0
        self._disjoint_points = []
        self._use_simplified = True
        self._simplify_tol = 100000  # shapes with at least this many segments/points are shown as 'Simplified'
        self.query = None
        self.abort = False

        self._native_wkt = None
        self._geographic_wkt = None
        self._non_geographic_wkt = None

        # (x, y, 0.0) end-node location -> set of geometry indexes of the arcs that start or end there
        self._node_locs_to_arc_idxs = {}
        self._num_arcs = 0

        self._arc_display = {   # DataFrame of results data
            STR_DISP_TYPE: [],
            STR_ID: [],
            STR_LEN: [],
            STR_NUM_SEGS: [],
            STR_NUM_CX: [],
            STR_DIST_TO_P: [],
            STR_GEOM_IDX: []
        }
        self._poly_display = {  # DataFrame of results data
            STR_DISP_TYPE: [],
            STR_ID: [],
            STR_NUM_ARCS: [],
            STR_AREA: [],
            STR_PERIM: [],
            STR_DIST_TO_P: [],
            STR_GEOM_IDX: []
        }

        # 'a_' means it's an arc attribute, 'p_' means poly -- all attributes are stored as geometry idxs
        self._arc_to_poly_df = {
            'geom_idx': [],
            'feature_id': [],
            'a_outer_polys': [],  # polygons that use this arc as an outer boundary
            'a_inner_polys': [],  # polygons that use this arc as an internal boundary
            'p_outer_arcs': [],  # outer arcs that make up the perimeter of this polygon
            'p_inner_arcs': [],  # inner arcs that make up interior boundaries of this polygon
            'p_containing_poly': [],  # the polygon that contains this polygon as a hole
            'p_hole_polys': []  # polygons that are holes in this polygon
        }

        # properties used by other tools
        self._mpbd_tol = 5.0

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.coverage_argument(name='input_coverage', description='Input coverage'),
            self.string_argument(name='geom_type', description='Geometry type', value='Arcs',
                                 choices=['Arcs', 'Polygons']),
            self.coverage_argument(name='output_coverage', description='Output coverage', value='cleaned',
                                   io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        Args:
            arguments(list): The tool arguments.
        """

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}

        return errors

    def _build_rtree(self, cov: GeoDataFrame):
        """Build the rtree for intersection operations.

        Also fills the spreadsheet display data and the arc/polygon relationship database. Sets self.abort and
        returns early when polygons are requested but the coverage has none.

        Args:
            cov (GeoDataFrame): The check coverage
        """
        # save the disjoint points for later
        self._disjoint_points = cov[cov['geometry_types'] == 'Point']

        polygons = cov[cov['geometry_types'] == 'Polygon']
        if self._geom_type == 'Polygons' and len(polygons) == 0:
            self.logger.info('The selected coverage has no polygons. Aborting.')
            self.abort = True
            return

        self.logger.info('Building rtree for operations...')
        is_geo = self._coord_sys == 'GEOGRAPHIC'
        # arcs are always loaded (polygons reference them by geometry index)
        arc_id_to_list_id = self._process_arcs(cov[cov['geometry_types'] == 'Arc'], is_geo)

        if self._geom_type == 'Arcs':
            # fill out the spreadsheet info - we only care if we are doing arcs
            self._count_arc_connections(arc_id_to_list_id)

        if self._geom_type == 'Polygons':
            self._process_polygons(polygons, arc_id_to_list_id, is_geo)

        self._create_rtree()
        self.logger.info('Done building rtree.')

    def _process_arcs(self, arcs, is_geo):
        """Store each arc's geometry, end-node connectivity, spreadsheet row (arc mode only) and database row.

        Args:
            arcs (GeoDataFrame): The arc rows of the input coverage.
            is_geo (bool): True if the display projection is geographic.

        Returns:
            (dict): Arc feature id -> index of the arc in self._geometry.
        """
        self.logger.info('Processing arcs...')
        arc_id_to_list_id = {}
        self._num_arcs = len(arcs)
        for arc in arcs.itertuples():
            arc_points = list(arc.geometry.coords)
            # flatten to z = 0; indexing (rather than unpacking exactly 3 values) also accepts 2D coordinates
            arc_locs = [(pt[0], pt[1], 0.0) for pt in arc_points]
            new_geom = LineString(arc_locs)
            self._geometry.append(new_geom)
            self._geometry_locs.append([arc_locs])

            # if this is latlon, we don't want to use it for the length in the spreadsheet
            if is_geo and self._geom_type == 'Arcs':
                geom_for_props, self._non_geographic_wkt = convert_latlon_geom_to_utm(new_geom, self._geographic_wkt,
                                                                                      self._non_geographic_wkt)
            else:
                geom_for_props = new_geom
            self._geometry_for_props.append(geom_for_props)

            # map the arc id to its spot in the list so that we can keep track of arcs in polygons (by list index
            # because the id's will change as the coverage is edited)
            list_idx = len(self._geometry) - 1
            arc_id_to_list_id[arc.id] = list_idx

            # remember which arcs touch each end node (a closed arc registers once, since it's a set)
            for node_loc in (arc_locs[0], arc_locs[-1]):
                self._node_locs_to_arc_idxs.setdefault(node_loc, set()).add(list_idx)

            # store the info for the spreadsheet - we don't display this stuff if we are doing polygons
            if self._geom_type == 'Arcs':
                num_segments = len(arc_points) - 1
                simplified = self._use_simplified and num_segments >= self._simplify_tol
                self._arc_display[STR_DISP_TYPE].append('Simplified' if simplified else 'Normal')
                self._arc_display[STR_ID].append(arc.id)
                self._arc_display[STR_LEN].append(geom_for_props.length)
                self._arc_display[STR_NUM_SEGS].append(num_segments)
                self._arc_display[STR_NUM_CX].append(0)  # replaced by _count_arc_connections
                self._arc_display[STR_DIST_TO_P].append(-1.0)
                self._arc_display[STR_GEOM_IDX].append(list_idx)

            # arc database
            self._arc_to_poly_df['geom_idx'].append(list_idx)
            self._arc_to_poly_df['feature_id'].append(arc.id)
            self._arc_to_poly_df['a_outer_polys'].append([])
            self._arc_to_poly_df['a_inner_polys'].append([])
            self._arc_to_poly_df['p_outer_arcs'].append(None)
            self._arc_to_poly_df['p_inner_arcs'].append(None)
            self._arc_to_poly_df['p_containing_poly'].append(None)
            self._arc_to_poly_df['p_hole_polys'].append([])

            # update max arc id
            self._max_arc_id = max(self._max_arc_id, arc.id)
        return arc_id_to_list_id

    def _count_arc_connections(self, arc_id_to_list_id):
        """Fill the '# Connected Arcs' spreadsheet column from the end-node connectivity map.

        A closed arc (both ends at one location) that touches no other arc gets -1.

        Args:
            arc_id_to_list_id (dict): Arc feature id -> index of the arc in self._geometry.
        """
        for list_id in arc_id_to_list_id.values():
            ends = gsh.endpt_locs_from_ls(self._geometry[list_id])
            if ends[0] == ends[-1]:
                num = len(self._node_locs_to_arc_idxs[ends[0]])
                cx = -1 if num == 1 else num - 1
            else:
                # each end's set includes this arc itself, hence the -2
                cx = len(self._node_locs_to_arc_idxs[ends[0]]) + len(self._node_locs_to_arc_idxs[ends[1]]) - 2
            # assumes spreadsheet rows line up with geometry indexes (true when the lists start out empty)
            self._arc_display[STR_NUM_CX][list_id] = cx

    def _process_polygons(self, polygons, arc_id_to_list_id, is_geo):
        """Store each valid polygon's geometry, spreadsheet row and database row; link it to its arcs.

        Args:
            polygons (GeoDataFrame): The polygon rows of the input coverage.
            arc_id_to_list_id (dict): Arc feature id -> index of the arc in self._geometry.
            is_geo (bool): True if the display projection is geographic.
        """
        self.logger.info('Processing polygons...')
        for poly in polygons.itertuples():
            lists = get_polygon_point_lists(poly)
            outer_locs = lists[0]
            inner_locs = None if len(lists) == 1 else lists[1:]

            # make the outer ring clockwise; reversing in place also updates `lists`, which is stored below
            new_geom = shPoly(outer_locs, inner_locs)
            if new_geom.exterior.is_ccw:
                outer_locs.reverse()
                new_geom = shPoly(outer_locs, inner_locs)
            # `not` (rather than `is False`) also handles a numpy bool
            if not new_geom.is_valid:
                self.logger.info(f'Polygon {poly.id} is invalid and will be ignored.')
                continue
            self._geometry.append(new_geom)
            self._geometry_locs.append(lists)
            list_idx = len(self._geometry) - 1

            if is_geo:
                geom_for_props, self._non_geographic_wkt = convert_latlon_geom_to_utm(new_geom,
                                                                                      self._geographic_wkt,
                                                                                      self._non_geographic_wkt)
            else:
                geom_for_props = new_geom
            self._geometry_for_props.append(geom_for_props)

            # store the info for the spreadsheet (honoring _use_simplified the same way the arc rows do)
            simplified = self._use_simplified and len(outer_locs) >= self._simplify_tol
            self._poly_display[STR_DISP_TYPE].append('Simplified' if simplified else 'Normal')
            self._poly_display[STR_ID].append(poly.id)
            self._poly_display[STR_NUM_ARCS].append(len(poly.polygon_arc_ids))
            self._poly_display[STR_AREA].append(abs(geom_for_props.area))
            self._poly_display[STR_PERIM].append(geom_for_props.exterior.length)
            self._poly_display[STR_DIST_TO_P].append(-1.0)
            self._poly_display[STR_GEOM_IDX].append(list_idx)

            # update max poly id
            self._max_poly_id = max(self._max_poly_id, poly.id)

            # keep track of which outer arcs make up the polygon in case they are edited
            outer_arcs = []
            for arc_id in poly.polygon_arc_ids:
                arc_list_idx = arc_id_to_list_id[arc_id]
                self._arc_to_poly_df['a_outer_polys'][arc_list_idx].append(list_idx)
                outer_arcs.append(arc_list_idx)

            # do the same for the interior arcs, one list of arc indexes per hole
            int_arcs = []
            for hole in poly.interior_arc_ids:
                hole_arcs = []
                for arc_id in hole:
                    arc_list_idx = arc_id_to_list_id[arc_id]
                    self._arc_to_poly_df['a_inner_polys'][arc_list_idx].append(list_idx)
                    hole_arcs.append(arc_list_idx)
                int_arcs.append(hole_arcs)

            # polygon dataframe
            self._arc_to_poly_df['geom_idx'].append(list_idx)
            self._arc_to_poly_df['feature_id'].append(poly.id)
            self._arc_to_poly_df['a_outer_polys'].append(None)
            self._arc_to_poly_df['a_inner_polys'].append(None)
            self._arc_to_poly_df['p_outer_arcs'].append(outer_arcs)
            self._arc_to_poly_df['p_inner_arcs'].append(int_arcs)
            self._arc_to_poly_df['p_containing_poly'].append(None)
            self._arc_to_poly_df['p_hole_polys'].append([])

    def _create_rtree(self):
        """Create self._rtree over the arcs in 'Arcs' mode, otherwise over the polygons.

        Each entry's id is the geometry index and its stored object is the geometry's bounding box.
        """
        if self._geom_type == 'Arcs':
            geom_idxs = range(self._num_arcs)
        else:
            geom_idxs = range(self._num_arcs, len(self._geometry))
        entries = []
        for geom_idx in geom_idxs:
            box = self._geometry[geom_idx].bounds
            entries.append((geom_idx, box, box))

        if entries:
            self._rtree = index.Index(iter(entries))
        else:
            # bulk loading from an empty stream raises ("Empty data stream given") in some rtree versions
            self._rtree = index.Index()

    def run(self, arguments):
        """Override to run the tool.

        Args:
            arguments (list): The tool arguments.
        """
        self.process_data(arguments)

    def process_data(self, arguments):
        """Read the projection and arguments, build the rtree, and pickle the results for the dialogs.

        The pickles are written even if _build_rtree aborted (the display data is then empty).

        Args:
            arguments (list): The tool arguments.
        """
        if self.query is None:
            self.query = Query(blocking=False)
        projection = self.query.display_projection
        self._native_wkt = projection.well_known_text
        self._coord_sys = ''
        self._horiz_units = ''
        if gu.valid_wkt(self._native_wkt):
            if gu.is_geographic(self._native_wkt):
                self._coord_sys = 'GEOGRAPHIC'
            if gu.get_horiz_unit_from_wkt(self._native_wkt) == gu.UNITS_METERS:
                self._horiz_units = 'METERS'
        if self._coord_sys == 'GEOGRAPHIC':
            self._geographic_wkt = self._native_wkt
        else:
            ref = osr.SpatialReference()
            ref.ImportFromEPSG(4326)  # This is geographic GCS_WGS_1984
            self._geographic_wkt = ref.ExportToWkt()
            self._non_geographic_wkt = self._native_wkt

        self._arguments = arguments
        input_coverage = self.get_input_coverage(arguments[ARG_INPUT_COVERAGE].text_value)
        # the Merge Polygons by Distance tool always works on polygons
        self._geom_type = 'Polygons' if self._is_mpbd_tool else arguments[ARG_INPUT_GEOM_TYPE].text_value
        self._output_coverage_name = arguments[ARG_OUTPUT_COVERAGE].text_value

        # create the RTree
        self._build_rtree(input_coverage)

        # send the info to the dialogs
        display_data = self._arc_display if self._geom_type == 'Arcs' else self._poly_display
        ss_df = pd.DataFrame(display_data)
        with open(clean_shapes_ss_display_file(), 'wb') as f:
            pickle.dump({'df': ss_df}, f)

        arc_to_poly_df = pd.DataFrame(self._arc_to_poly_df)
        with open(clean_shapes_arc_poly_db_file(), 'wb') as f:
            pickle.dump({'df': arc_to_poly_df}, f)


def clean_shapes_ss_display_file():
    """Return the path of the pickle holding the Clean Shapes spreadsheet (display) DataFrame.

    The file is in the XMS process temp directory, and its name includes the current process id.
    """
    temp_dir = XmEnv.xms_environ_process_temp_directory()
    file_name = f'clean_shapes_ss_display_{os.getpid()}.pkl'
    return os.path.join(temp_dir, file_name)


def clean_shapes_arc_poly_db_file():
    """Return the path of the pickle holding the Clean Shapes arc/polygon relationship DataFrame.

    The file is in the XMS process temp directory, and its name includes the current process id.
    """
    temp_dir = XmEnv.xms_environ_process_temp_directory()
    file_name = f'clean_shapes_arc_poly_db_file_{os.getpid()}.pkl'
    return os.path.join(temp_dir, file_name)


def needs_axis_swap(srs: osr.SpatialReference) -> bool:
    """Returns True if the axis order is Y,X (e.g., Northing,Easting).

    Looks only at the first two AXIS lines of the pretty WKT (case-insensitively). Returns True when the first
    names northing/latitude and the second names easting/longitude; otherwise returns False.

    NOTE(review): for a projected CRS the first AXIS lines may belong to the nested geographic CRS rather than
    the projected one -- confirm this is the intended check.
    """
    stripped_lines = (line.strip() for line in srs.ExportToPrettyWkt().splitlines())
    axes = [line.lower() for line in stripped_lines if line.startswith("AXIS")]
    if len(axes) < 2:
        return False  # not enough axis info: treat as traditional (X,Y)

    first_axis, second_axis = axes[0], axes[1]
    # EPSG axis order is often Y,X like ["northing", "easting"]
    first_is_y = "northing" in first_axis or "latitude" in first_axis
    second_is_x = "easting" in second_axis or "longitude" in second_axis
    return first_is_y and second_is_x


def convert_latlon_geom_to_utm(geom, geo_wkt, non_geo_wkt):
    """Convert a lat/lon shapely LineString or Polygon to UTM coordinates.

    Args:
        geom (Object): shapely LineString or Polygon in geographic coordinates
        geo_wkt (str): wkt of the geographic coordinate system of geom
        non_geo_wkt (str): wkt of the UTM system to convert to, or None to pick a UTM zone from the geometry

    Returns:
        (tuple): the converted shapely object and the wkt used for the conversion (pass it to later calls so
        every geometry is converted to the same zone)

    Raises:
        TypeError: if geom is not a LineString or Polygon
    """
    if geom.geom_type == 'LineString':
        converted_locs, non_geo_wkt = clean_convert_lat_lon_pts_to_utm(gsh.pts_from_ls(geom), geo_wkt, non_geo_wkt)
        return LineString(converted_locs), non_geo_wkt
    if geom.geom_type == 'Polygon':
        # first list is the outer ring, the rest are holes
        all_locs = gsh.pts_from_shapely_poly(geom)
        converted_outer, non_geo_wkt = clean_convert_lat_lon_pts_to_utm(all_locs[0], geo_wkt, non_geo_wkt)
        converted_holes = []
        for hole_locs in all_locs[1:]:
            converted_inner, non_geo_wkt = clean_convert_lat_lon_pts_to_utm(hole_locs, geo_wkt, non_geo_wkt)
            converted_holes.append(converted_inner)
        return shPoly(converted_outer, converted_holes), non_geo_wkt
    # previously fell through and returned None, which failed later when callers unpacked the result
    raise TypeError(f'Unsupported geometry type for UTM conversion: {geom.geom_type}')


def clean_convert_lat_lon_pts_to_utm(pts, geo_wkt, non_geo_wkt):
    """Converts a list of lat/lon points to UTM.

    When non_geo_wkt is None, the UTM zone and hemisphere are chosen from the center of the points' bounding
    box, using the datum of geo_wkt.

    Args:
        pts (list): A list of lat/lon points.
        geo_wkt (string): geographic wkt
        non_geo_wkt (string): non-geographic wkt to convert to, or None to pick a UTM zone from pts

    Returns:
        (list): A list of UTM points.
        (string): The wkt of the coordinate system of the returned points.
    """
    if non_geo_wkt is None:
        # get the center of the bounding box of the points
        pt_x = [pt[0] for pt in pts]
        pt_y = [pt[1] for pt in pts]
        lon = (min(pt_x) + max(pt_x)) / 2.0
        lat = (min(pt_y) + max(pt_y)) / 2.0

        spatial_ref = osr.SpatialReference()
        spatial_ref.ImportFromWkt(geo_wkt)

        utm_spatial_ref = spatial_ref.CloneGeogCS()
        # UTM zones are 1-60; clamp so lon == 180 (which the formula maps to 61) still gets a valid zone
        zone = min(max(int((lon + 180.0) / 6) + 1, 1), 60)
        b_north = lat > 0
        utm_spatial_ref.SetUTM(zone, b_north)
        non_geo_wkt = utm_spatial_ref.ExportToWkt()

    # convert the points to utm
    trans = gu.get_coordinate_transformation(geo_wkt, non_geo_wkt)
    out_pts = trans.TransformPoints(pts)

    return out_pts, non_geo_wkt


def meters_to_decimal_degrees(length_meters, latitude):
    """Convert a length in meters to decimal degrees at a given latitude.

    The divisor is factor_from_latitude(latitude), i.e. meters per degree at that latitude.

    Args:
        length_meters (:obj:`float`): length in meters
        latitude (:obj:`float`): latitude in decimal degrees

    Returns:
        (:obj:`float`): length in decimal degrees
    """
    meters_per_degree = factor_from_latitude(latitude)
    return length_meters / meters_per_degree


FEET_PER_METER = 3.28083989501  # number of feet in one meter


def feet_to_decimal_degrees(length_feet, latitude):
    """Convert a length in feet to decimal degrees at a given latitude.

    The feet are first converted to meters, then passed to meters_to_decimal_degrees.

    Args:
        length_feet (:obj:`float`): length in feet
        latitude (:obj:`float`): latitude in decimal degrees

    Returns:
        (:obj:`float`): length in decimal degrees
    """
    return meters_to_decimal_degrees(length_feet / FEET_PER_METER, latitude)


def factor_from_latitude(latitude):
    """Return the meters-per-decimal-degree conversion factor at a latitude.

    The value is 111.32 km scaled by cos(latitude), i.e. the east-west length of one degree of longitude.

    Args:
        latitude (:obj:`float`): the latitude in decimal degrees

    Returns:
        (:obj:`float`): conversion factor in meters per degree
    """
    meters_per_degree_at_equator = 111.32 * 1000
    return meters_per_degree_at_equator * math.cos(math.radians(latitude))
