"""TrimCoverage class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging

# 2. Third party modules
from pyproj.enums import WktVersion
from shapely.geometry import LineString, MultiPolygon, Polygon

# 3. Aquaveo modules
from xms.gdal.vectors import VectorOutput
from xms.tool_core.coverage_builder import CoverageBuilder

# 4. Local modules


def _create_trimmed_shapefile(trimmed_arcs, filename, wkt):  # pragma no cover (runs in Linux tests)
    """Write the non-empty trimmed arcs to a new shapefile.

    Args:
        trimmed_arcs (list): shapely geometries whose ``coords`` are written as arcs; empty ones are skipped
        filename (str): path of the shapefile to create
        wkt (str): projection WKT handed to ``VectorOutput.initialize_file``
    """
    writer = VectorOutput()
    writer.initialize_file(filename, wkt)
    non_empty_arcs = (arc for arc in trimmed_arcs if not arc.is_empty)
    for arc in non_empty_arcs:
        writer.write_arc(arc.coords)
    # Drop the only reference to the writer so the handles associated with the arc shapefile are released.
    del writer


class TrimCoverage:
    """Class to trim a coverage's arcs by the (buffered) polygons passed in.

    The polygons (with their holes) are optionally smoothed, combined into a MultiPolygon and buffered by
    ``buffer_distance``. When ``trim_out`` is truthy the polygons are grown by the buffer distance and the parts of
    every arc inside that region are removed; otherwise the polygons are shrunk by the buffer distance and only the
    parts of every arc inside that region are kept.
    """

    def __init__(self, polygon_data, hole_data, arc_data, buffer_distance, trim_out, new_cov_name, logger=None,
                 num_pts_to_smooth=2500):
        """Initializes the class.

        Args:
            polygon_data (list(dict)): polygon points ('poly_pts') and attributes from the source coverage; the first
                entry may carry a 'cov_geom' whose CRS is used for the output coverage
            hole_data (list(list(dict))): polygon hole points and attributes from the source coverage, one list per
                polygon (paired with polygon_data by position)
            arc_data (list(dict)): arc points ('arc_pts') and attributes from the coverage being trimmed
            buffer_distance (float): distance to buffer for trimming
            trim_out (int): flag for trimming inside the polygon or outside
            new_cov_name (str): name of the new coverage
            logger (logging.logger): logger; defaults to the 'xms.tool' logger
            num_pts_to_smooth (int): minimum number of points a polygon/hole must have before it is smoothed
        """
        self.logger = logger
        if self.logger is None:
            self.logger = logging.getLogger('xms.tool')
        self._polygon_data_list = polygon_data
        self._hole_data_list = hole_data  # Make sure the size of this list is same as polygon_data
        self._polygon_data = None
        self._poly_arcs = []
        self._arc_data = arc_data
        self._cov_geom = None if len(polygon_data) < 1 else polygon_data[0].get('cov_geom', None)
        self.new_cov = None
        self._buffer_distance = buffer_distance
        self._trim_out = trim_out
        self._new_cov_name = new_cov_name
        self._num_pts_to_smooth = num_pts_to_smooth
        self._trimmed_arcs = None  # cache filled lazily by the trimmed_arcs property

    def generate_coverage(self, cov_wkt=None):
        """Trims arcs from a polygon coverage and generates a coverage.

        Args:
            cov_wkt (str): WKT used when the source coverage geometry has no CRS

        Returns:
            The coverage built by CoverageBuilder (also stored in ``self.new_cov``).
        """
        trimmed = self.trimmed_arcs  # property computes and caches on first access
        self.logger.info('Creating new coverage.')
        self._create_trimmed_coverage(trimmed, cov_wkt)
        return self.new_cov

    def generate_shapefile(self, filename, wkt):  # pragma no cover (runs in Linux tests)
        """Trims arcs from a polygon coverage and generates a shapefile.

        Args:
            filename (str): path of the shapefile to create
            wkt (str): projection WKT for the shapefile
        """
        trimmed = self.trimmed_arcs  # property computes and caches on first access
        self.logger.info('Creating new shapefile.')
        _create_trimmed_shapefile(trimmed, filename, wkt)

    @property
    def trimmed_arcs(self):
        """Returns the trimmed arcs as shapely LineStrings (computed once, then cached)."""
        if self._trimmed_arcs is None:
            self._trimmed_arcs = self._trim_arcs()
        return self._trimmed_arcs

    def _trim_arcs(self):
        """Trims arcs from a polygon coverage.

        Returns:
            list: non-empty shapely LineStrings remaining after trimming
        """
        poly_buffer = self._build_trim_region()
        arc_lines = self._build_arc_lines()
        self.logger.info('Trimming coverage to buffered polygon.')
        trimmed_arcs = []
        for arc in arc_lines:
            if self._trim_out:
                trimmed = arc.difference(poly_buffer)
            else:
                trimmed = arc.intersection(poly_buffer)
            trimmed_arcs.extend(self._line_strings(trimmed))
        return trimmed_arcs

    def _build_trim_region(self):
        """Build the buffered (grown or shrunk) MultiPolygon the arcs are trimmed against."""
        poly_list = []
        for poly, holes in zip(self._polygon_data_list, self._hole_data_list):
            hole_pts = []
            for cur_hole in holes:
                self._smooth_poly_points(cur_hole)
                hole_pts.append(cur_hole['poly_pts'])
            self._smooth_poly_points(poly)
            poly_list.append(Polygon(poly['poly_pts'], hole_pts))
        multi_poly = MultiPolygon(poly_list)
        self.logger.info('Computing polygon buffer.')
        # Grow the polygons when trimming outside, shrink them when trimming inside.
        distance = self._buffer_distance if self._trim_out else -self._buffer_distance
        return multi_poly.buffer(distance)

    def _build_arc_lines(self):
        """Convert the arc data to LineStrings, skipping arcs with fewer than two points."""
        return [LineString(arc['arc_pts']) for arc in self._arc_data if len(arc['arc_pts']) > 1]

    @staticmethod
    def _line_strings(geometry):
        """Return every non-empty LineString contained in a trimming result.

        Multi-part results (MultiLineString, GeometryCollection) are flattened recursively; any other geometry
        type (e.g. Points produced where an arc only touches the region) is dropped.
        """
        geom_type = geometry.geom_type
        if geom_type == 'LineString':
            return [] if geometry.is_empty else [geometry]
        if geom_type in ('MultiLineString', 'GeometryCollection'):
            parts = []
            for part in geometry.geoms:
                parts.extend(TrimCoverage._line_strings(part))
            return parts
        return []

    def _smooth_poly_points(self, poly):
        """Smooth point locations that are connected to edges smaller than the buffer distance.

        Runs three smoothing passes. In each pass an interior point whose incoming and outgoing edges are both
        shorter than the buffer distance is moved to the average of itself and its two neighbors. The first and
        last points are never moved. Only polygons with at least ``num_pts_to_smooth`` (and at least 3) points
        are smoothed. The result replaces ``poly['poly_pts']``.

        Args:
            poly (dict): polygon information; 'poly_pts' is a sequence of (x, y) or (x, y, z) points
        """
        poly_pts = poly['poly_pts']
        # Smoothing needs at least one interior point between two neighbors.
        if len(poly_pts) < self._num_pts_to_smooth or len(poly_pts) < 3:
            return
        buffer_sq = self._buffer_distance * self._buffer_distance
        cur_pts = poly_pts
        for iteration in range(3):
            self.logger.info(f'Smoothing polygon boundary - iteration {iteration}')
            new_pts = [cur_pts[0]]
            dx = cur_pts[1][0] - cur_pts[0][0]
            dy = cur_pts[1][1] - cur_pts[0][1]
            next_dist_sq = (dx * dx) + (dy * dy)
            for i in range(1, len(cur_pts) - 1):
                prev_dist_sq = next_dist_sq
                prev_pt, pt, next_pt = cur_pts[i - 1], cur_pts[i], cur_pts[i + 1]
                dx = next_pt[0] - pt[0]
                dy = next_pt[1] - pt[1]
                next_dist_sq = (dx * dx) + (dy * dy)
                if prev_dist_sq < buffer_sq and next_dist_sq < buffer_sq:
                    x = (prev_pt[0] + pt[0] + next_pt[0]) / 3
                    y = (prev_pt[1] + pt[1] + next_pt[1]) / 3
                    # Keep the original z when present; support plain 2-D points too.
                    new_pts.append([x, y, pt[2]] if len(pt) > 2 else [x, y])
                else:
                    new_pts.append(pt)
            new_pts.append(cur_pts[-1])
            cur_pts = new_pts
        poly['poly_pts'] = cur_pts

    def _create_trimmed_coverage(self, trimmed_arcs, cov_wkt):
        """Create a coverage with the trimmed arcs and store it in ``self.new_cov``.

        The projection comes from the source coverage geometry's CRS when available, otherwise from ``cov_wkt``.
        Arc points without a z value get z = 0.0.
        """
        wkt = ''
        if self._cov_geom is not None and self._cov_geom.crs is not None:
            wkt = self._cov_geom.crs.to_wkt(version=WktVersion.WKT1_GDAL)
        if not wkt and cov_wkt is not None:
            wkt = cov_wkt
        coverage_builder = CoverageBuilder(wkt, self._new_cov_name)
        for arc in trimmed_arcs:
            if arc.is_empty:
                continue
            arc_pts = [(c[0], c[1], c[2] if len(c) >= 3 else 0.0) for c in arc.coords]
            coverage_builder.add_arc(arc_pts)
        self.new_cov = coverage_builder.build_coverage()
