"""Class for reading rasters using GDAL."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
from osgeo import gdal, ogr

# 3. Aquaveo modules

# 4. Local modules
from xms.gdal.utilities import gdal_utils as gu


def _create_att_fields(layer, atts):
    if atts is not None:
        for k, v in atts.items():
            value_type = None
            if v:
                if type(v[0]) is str:
                    value_type = ogr.OFTString
                elif type(v[0]) is int or type(v[0]) is bool:
                    value_type = ogr.OFTInteger
                elif type(v[0]) is float:
                    value_type = ogr.OFTReal
            if value_type is not None:
                field = ogr.FieldDefn(k, value_type)
                layer.CreateField(field)


def _set_atts(atts, feature, index):
    if atts is not None:
        for k, v in atts.items():
            feature.SetField(k, v[index])


class VectorOutput:
    """Class for writing vector data using GDAL.

    Call ``initialize_file`` once per output file, then any number of ``write_point``,
    ``write_arc`` or ``write_polygon`` calls. The OGR layer is created lazily by the first
    write call ('active_points', 'active_arcs' or 'active_polygons', depending on which
    method it was), and every later feature is appended to that same layer. Each feature
    gets an integer "id" field numbered from 1.
    """

    def __init__(self):
        """Initializes the class."""
        gdal.SetConfigOption('GTIFF_REPORT_COMPD_CS', 'TRUE')
        self._vec_ds = None  # Output data source, created by initialize_file()
        self._vec_lyr = None  # Output layer, created by the first write_* call
        self._active_sr = None  # Spatial reference of the output file (from ``wkt``)
        self._from_sr = None  # Spatial reference of incoming coordinates (from ``from_wkt``)
        self._cur_id = 1  # "id" of the next feature; id - 1 indexes the ``atts`` value lists

    def initialize_file(self, filename, wkt, file_format='ESRI Shapefile', from_wkt=''):
        """Creates the output vector file and sets up its coordinate systems.

        Any state from a previous call (layer, coordinate systems, feature ids) is discarded,
        so an instance can be reused for several files.

        Args:
            filename (str): The full path and filename of the vector file to write.
            wkt (str): The WKT of the coordinate system to be written with this vector file.
                Ignored if it is not valid WKT.
            file_format (str): The "Short name" of the file format listed on this site:
             https://gdal.org/drivers/vector/index.html (such as "ESRI Shapefile")
            from_wkt (str): The WKT of the coordinate system that the points passed to the write
                methods are in. When given and valid, each geometry is tagged with it and, if
                ``wkt`` is also valid, transformed from it to ``wkt``. Leave as default for no
                conversion.

        Raises:
            ValueError: If ``file_format`` is not the name of an available OGR driver.
            RuntimeError: If the driver could not create ``filename``.
        """
        # Drop the layer before its data source, and reset everything tied to the previous
        # file, so a reused instance never writes into the old file's layer or carries over
        # its feature ids or coordinate systems.
        self._vec_lyr = None
        self._vec_ds = None
        self._active_sr = None
        self._from_sr = None
        self._cur_id = 1

        driver = ogr.GetDriverByName(file_format)
        if driver is None:
            raise ValueError(f"No OGR vector driver named '{file_format}'.")
        self._vec_ds = driver.CreateDataSource(filename)
        if self._vec_ds is None:
            raise RuntimeError(f"Unable to create vector file '{filename}' using the '{file_format}' driver.")
        if gu.valid_wkt(wkt):
            self._active_sr = gu.wkt_to_sr(wkt)
        if from_wkt and gu.valid_wkt(from_wkt):
            self._from_sr = gu.wkt_to_sr(from_wkt)

    def _ensure_layer(self, layer_name, geom_type, atts):
        """Creates the output layer, with its "id" and attribute fields, if it does not exist yet.

        Args:
            layer_name (str): Name for a newly created layer.
            geom_type (int): Geometry type for a newly created layer (from the ogr enum).
            atts (dict): Attribute values used to create the attribute fields; see ``write_point``.

        Raises:
            RuntimeError: If ``initialize_file`` has not been called or the layer could not be created.
        """
        if self._vec_lyr is not None:
            return
        if self._vec_ds is None:
            raise RuntimeError('initialize_file() must be called before writing features.')
        self._vec_lyr = self._vec_ds.CreateLayer(layer_name, srs=self._active_sr, geom_type=geom_type)
        if self._vec_lyr is None:
            raise RuntimeError(f"Unable to create vector layer '{layer_name}'.")
        self._vec_lyr.CreateField(ogr.FieldDefn("id", ogr.OFTInteger))
        _create_att_fields(self._vec_lyr, atts)

    @staticmethod
    def _add_points(geometry, points):
        """Appends each (x, y) or (x, y, z) point to an OGR geometry."""
        for point in points:
            if len(point) > 2:
                geometry.AddPoint(point[0], point[1], z=point[2])
            else:
                geometry.AddPoint(point[0], point[1])

    def _write_feature(self, geometry, atts):
        """Reprojects a geometry if needed and appends it to the layer as a new feature.

        Args:
            geometry (ogr.Geometry): The geometry to write, in the ``from_wkt`` system if one was given.
            atts (dict): Attribute values; see ``write_point``.
        """
        if self._from_sr is not None:
            geometry.AssignSpatialReference(self._from_sr)
            if self._active_sr is not None:
                geometry.TransformTo(self._active_sr)

        feature = ogr.Feature(feature_def=self._vec_lyr.GetLayerDefn())
        # SetGeometry stores a copy of the geometry, so no WKT export/re-parse is needed
        # (that would only add work and could round coordinates when formatting them as text).
        feature.SetGeometry(geometry)
        feature.SetField("id", self._cur_id)
        _set_atts(atts, feature, self._cur_id - 1)
        self._cur_id += 1
        self._vec_lyr.CreateFeature(feature)

    def write_polygon(self, points, holes=None, atts=None):
        """Writes a polygon.

        Args:
            points (list): Points on the outside of the polygon, each as (x, y) or (x, y, z).
            holes (list): List of polygons (each a list of points) making up any holes inside the
                outside polygon. May be empty or None.
            atts (dict): Maps each attribute name to a list of values, one per feature; the value at
                index (feature id - 1) is written. Fields are only created from the ``atts`` of the
                first write call.
        """
        self._ensure_layer('active_polygons', ogr.wkbPolygon, atts)
        polygon = ogr.Geometry(ogr.wkbPolygon)
        # The first ring added is the outer boundary; any following rings are the holes.
        rings = [points]
        if holes is not None:
            rings.extend(holes)
        for ring_points in rings:
            ring = ogr.Geometry(ogr.wkbLinearRing)
            self._add_points(ring, ring_points)
            polygon.AddGeometry(ring)
        self._write_feature(polygon, atts)

    def _write_arc(self, points, geom_type, atts):
        """Writes an arc.

        Args:
            points (list): Points on the arc, each as (x, y) or (x, y, z).
            geom_type (int): The geometry type to write (from the ogr enum).
            atts (dict): Attribute values; see ``write_polygon``.
        """
        self._ensure_layer('active_arcs', geom_type, atts)
        arc = ogr.Geometry(geom_type)
        self._add_points(arc, points)
        self._write_feature(arc, atts)

    def write_arc(self, points, atts=None):
        """Writes an arc as an ogr.wkbLineStringZM geometry.

        Args:
            points (list): Points on the arc, each as (x, y) or (x, y, z).
            atts (dict): Maps each attribute name to a list of values, one per feature; the value at
                index (feature id - 1) is written. Fields are only created from the ``atts`` of the
                first write call.
        """
        self._write_arc(points, ogr.wkbLineStringZM, atts)

    def write_point(self, pt, atts=None):
        """Writes a point as an ogr.wkbPoint25D geometry.

        Args:
            pt (list): The point to write, as (x, y) or (x, y, z).
            atts (dict): Maps each attribute name to a list of values, one per feature; the value at
                index (feature id - 1) is written. Fields are only created from the ``atts`` of the
                first write call.
        """
        self._ensure_layer('active_points', ogr.wkbPoint25D, atts)
        point = ogr.Geometry(ogr.wkbPoint25D)
        self._add_points(point, [pt])
        self._write_feature(point, atts)

    @property
    def ogr_layer(self):
        """Returns the OGR.Layer for this vector data (None until the first write call)."""
        return self._vec_lyr
