"""Class for reading vectors using GDAL."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import sys

# 2. Third party modules
from osgeo import gdal, ogr
from osgeo.ogr import FieldDefn

# 3. Aquaveo modules

# 4. Local modules


def _read_line(layer_def, feature):
    """Reads the vertices of a feature's line geometry.

    Args:
        layer_def: The OGR layer definition describing the feature's fields.
        feature: The OGR feature to read.

    Returns:
        (list): List of (x, y, z) tuples, one per vertex. Empty if the feature has no geometry or the geometry
        reports no points. For 2D and measured (M) line strings / multi-line strings, z is taken from the
        feature's elevation field (0.0 if there is none); otherwise it is the vertex's own Z.
    """
    line_points = []
    geom = feature.GetGeometryRef()
    if geom is None:
        return line_points

    # Elevation from the first field whose name exactly matches one of the known elevation field names
    # (case-insensitive). Exact matching avoids picking up unrelated fields such as "e" or "val" that merely
    # appear inside one of these names. A null value keeps the default of 0.0 instead of crashing float().
    z_val = 0.0
    for i in range(layer_def.GetFieldCount()):
        field_defn = layer_def.GetFieldDefn(i)
        if field_defn.name.casefold() in {'elev', 'valdco', 'contour', 'wsel_reg', 'height'}:
            value = feature.GetField(i)
            if value is not None:
                z_val = float(value)
            break

    # The geometry type is the same for every vertex, so decide once whether the field elevation applies.
    geom_type = geom.GetGeometryType()
    use_field_z = geom_type in (ogr.wkbLineString, ogr.wkbLineStringM, ogr.wkbMultiLineString, ogr.wkbMultiLineStringM)
    for i in range(geom.GetPointCount()):
        point = geom.GetPoint(i)
        if use_field_z or len(point) < 3:
            line_points.append((point[0], point[1], z_val))
        else:
            line_points.append((point[0], point[1], point[2]))
    return line_points


def _read_point(layer_def, feature):
    """Reads the location of a feature's point geometry.

    Args:
        layer_def: The OGR layer definition describing the feature's fields.
        feature: The OGR feature to read.

    Returns:
        (tuple): The (x, y, z) of the point, or (0.0, 0.0, 0.0) if the feature has no geometry. For 2D and
        measured (M) points, z is taken from the feature's elevation field (0.0 if there is none); otherwise it
        is the geometry's own Z.
    """
    geom = feature.GetGeometryRef()
    if geom is None:
        return 0.0, 0.0, 0.0

    # Elevation from the first field whose name exactly matches one of the known elevation field names
    # (case-insensitive). Exact matching avoids picking up unrelated fields such as "e" or "val" that merely
    # appear inside one of these names. A null value keeps the default of 0.0 instead of crashing float().
    z_val = 0.0
    for i in range(layer_def.GetFieldCount()):
        field_defn = layer_def.GetFieldDefn(i)
        if field_defn.name.casefold() in {'elev', 'valdco', 'contour', 'wsel_reg', 'height'}:
            value = feature.GetField(i)
            if value is not None:
                z_val = float(value)
            break

    if geom.GetGeometryType() in (ogr.wkbPoint, ogr.wkbPointM):
        return geom.GetX(), geom.GetY(), z_val
    return geom.GetX(), geom.GetY(), geom.GetZ()


def _read_polygon(layer_def, feature):
    """Reads the rings of a feature's polygon geometry.

    Args:
        layer_def: The OGR layer definition describing the feature's fields.
        feature: The OGR feature to read.

    Returns:
        (list): 2D list with one list of (x, y, z) tuples per ring. For a multi-polygon, the rings of all member
        polygons are returned in order. Every vertex is given the feature's elevation field value (0.0 if there is
        none). Empty if the feature has no geometry.
    """
    poly_points = []
    geom = feature.GetGeometryRef()
    if geom is None:
        return poly_points

    # Elevation from the first field whose name exactly matches one of the known elevation field names
    # (case-insensitive). Exact matching avoids picking up unrelated fields such as "e" or "val" that merely
    # appear inside one of these names. A null value keeps the default of 0.0 instead of crashing float().
    z_val = 0.0
    for i in range(layer_def.GetFieldCount()):
        field_defn = layer_def.GetFieldDefn(i)
        if field_defn.name.casefold() in {'elev', 'valdco', 'contour', 'wsel_reg', 'height'}:
            value = feature.GetField(i)
            if value is not None:
                z_val = float(value)
            break

    geom_type = geom.GetGeometryType()
    if geom_type in (ogr.wkbMultiPolygon, ogr.wkbMultiPolygonM, ogr.wkbMultiPolygonZM, ogr.wkbMultiPolygon25D):
        # A multi-polygon's sub-geometries are polygons; their sub-geometries are the rings.
        for i in range(geom.GetGeometryCount()):
            poly_geom = geom.GetGeometryRef(i)
            for j in range(poly_geom.GetGeometryCount()):
                poly_points.append(_read_linear_ring(poly_geom.GetGeometryRef(j), z_val))
    else:
        # A single polygon's sub-geometries are its rings.
        for i in range(geom.GetGeometryCount()):
            poly_points.append(_read_linear_ring(geom.GetGeometryRef(i), z_val))
    return poly_points


def _read_linear_ring(ring, z_val):
    """Converts a ring geometry into a list of vertices at a fixed elevation.

    Only the X and Y of each vertex are read from the ring. Any Z stored on the ring is ignored and replaced by
    ``z_val``.

    Args:
        ring: The ring geometry to read (anything exposing ``GetPointCount()`` and ``GetPoint(i)``).
        z_val (float): Elevation assigned to every vertex.

    Returns:
        (list): List of (x, y, z_val) tuples in ring order.
    """
    vertices = (ring.GetPoint(index) for index in range(ring.GetPointCount()))
    return [(vertex[0], vertex[1], z_val) for vertex in vertices]


def get_point_features(vec_ds):
    """Collects one point per feature from every layer of a vector dataset.

    Args:
        vec_ds: The opened GDAL vector dataset to read.

    Returns:
        (list): List of (x, y, z) tuples, one per feature, ordered by layer and then by feature.
    """
    points = []
    for layer in (vec_ds.GetLayerByIndex(idx) for idx in range(vec_ds.GetLayerCount())):
        layer.ResetReading()
        definition = layer.GetLayerDefn()
        points.extend(_read_point(definition, feature) for feature in layer)
    return points


def get_line_features(vec_ds):
    """Reads the line features from the vector dataset.

    Multi-line features are split so that each of their parts becomes a separate line. Previously OGR reported
    no points for them, so they were silently dropped.

    Args:
        vec_ds: The opened GDAL vector dataset to read.

    Returns:
        (list): 2D list of tuples containing points making up all the lines in the vector data file. Lines with
        fewer than two points are skipped.
    """
    lines = []
    for layer_index in range(vec_ds.GetLayerCount()):
        layer = vec_ds.GetLayerByIndex(layer_index)
        layer.ResetReading()
        layer_def = layer.GetLayerDefn()
        for feature in layer:
            for part in _split_multi_line(feature):
                line_points = _read_line(layer_def, part)
                if len(line_points) > 1:
                    lines.append(line_points)
    return lines


def _split_multi_line(feature):
    """Splits a multi-line feature into one feature per part; any other feature is returned as the only item.

    OGR's GetPointCount() returns 0 for multi-part geometries, so each part is placed on a clone of the feature.
    The clone keeps the attribute values used for elevation and can then be read as a plain line.
    """
    geom = feature.GetGeometryRef()
    multi_line_types = (ogr.wkbMultiLineString, ogr.wkbMultiLineStringM, ogr.wkbMultiLineString25D,
                        ogr.wkbMultiLineStringZM)
    if geom is None or geom.GetGeometryType() not in multi_line_types:
        return [feature]
    parts = []
    for i in range(geom.GetGeometryCount()):
        part = feature.Clone()
        # SetGeometry stores a copy, so the clone does not depend on the original feature's geometry.
        part.SetGeometry(geom.GetGeometryRef(i))
        parts.append(part)
    return parts


def get_poly_features(vec_ds):
    """Collects the polygon rings of every feature in every layer of a vector dataset.

    Args:
        vec_ds: The opened GDAL vector dataset to read.

    Returns:
        (list): 3D list with one entry per feature (ordered by layer and then by feature). Each entry is that
        feature's list of rings, and each ring is a list of (x, y, z) tuples.
    """
    polys = []
    for layer_index in range(vec_ds.GetLayerCount()):
        layer = vec_ds.GetLayerByIndex(layer_index)
        layer.ResetReading()
        definition = layer.GetLayerDefn()
        polys += [_read_polygon(definition, feature) for feature in layer]
    return polys


class VectorInput:
    """Class for reading vector data using GDAL."""

    # Returns from layer_type
    wkbUnknown = ogr.wkbUnknown  # noqa: N815
    wkbPoint = ogr.wkbPoint  # noqa: N815
    wkbLineString = ogr.wkbLineString  # noqa: N815
    wkbPolygon = ogr.wkbPolygon  # noqa: N815
    wkbMultiPoint = ogr.wkbMultiPoint  # noqa: N815
    wkbMultiLineString = ogr.wkbMultiLineString  # noqa: N815
    wkbMultiPolygon = ogr.wkbMultiPolygon  # noqa: N815
    wkbGeometryCollection = ogr.wkbGeometryCollection  # noqa: N815
    wkbCircularString = ogr.wkbCircularString  # noqa: N815
    wkbCompoundCurve = ogr.wkbCompoundCurve  # noqa: N815
    wkbCurvePolygon = ogr.wkbCurvePolygon  # noqa: N815
    wkbMultiCurve = ogr.wkbMultiCurve  # noqa: N815
    wkbMultiSurface = ogr.wkbMultiSurface  # noqa: N815
    wkbCurve = ogr.wkbCurve  # noqa: N815
    wkbSurface = ogr.wkbSurface  # noqa: N815
    wkbPolyhedralSurface = ogr.wkbPolyhedralSurface  # noqa: N815
    wkbTIN = ogr.wkbTIN  # noqa: N815
    wkbTriangle = ogr.wkbTriangle  # noqa: N815
    wkbNone = ogr.wkbNone  # noqa: N815
    wkbLinearRing = ogr.wkbLinearRing  # noqa: N815
    wkbCircularStringZ = ogr.wkbCircularStringZ  # noqa: N815
    wkbCompoundCurveZ = ogr.wkbCompoundCurveZ  # noqa: N815
    wkbCurvePolygonZ = ogr.wkbCurvePolygonZ  # noqa: N815
    wkbMultiCurveZ = ogr.wkbMultiCurveZ  # noqa: N815
    wkbMultiSurfaceZ = ogr.wkbMultiSurfaceZ  # noqa: N815
    wkbCurveZ = ogr.wkbCurveZ  # noqa: N815
    wkbSurfaceZ = ogr.wkbSurfaceZ  # noqa: N815
    wkbPolyhedralSurfaceZ = ogr.wkbPolyhedralSurfaceZ  # noqa: N815
    wkbTINZ = ogr.wkbTINZ  # noqa: N815
    wkbTriangleZ = ogr.wkbTriangleZ  # noqa: N815
    wkbPointM = ogr.wkbPointM  # noqa: N815
    wkbLineStringM = ogr.wkbLineStringM  # noqa: N815
    wkbPolygonM = ogr.wkbPolygonM  # noqa: N815
    wkbMultiPointM = ogr.wkbMultiPointM  # noqa: N815
    wkbMultiLineStringM = ogr.wkbMultiLineStringM  # noqa: N815
    wkbMultiPolygonM = ogr.wkbMultiPolygonM  # noqa: N815
    wkbGeometryCollectionM = ogr.wkbGeometryCollectionM  # noqa: N815
    wkbCircularStringM = ogr.wkbCircularStringM  # noqa: N815
    wkbCompoundCurveM = ogr.wkbCompoundCurveM  # noqa: N815
    wkbCurvePolygonM = ogr.wkbCurvePolygonM  # noqa: N815
    wkbMultiCurveM = ogr.wkbMultiCurveM  # noqa: N815
    wkbMultiSurfaceM = ogr.wkbMultiSurfaceM  # noqa: N815
    wkbCurveM = ogr.wkbCurveM  # noqa: N815
    wkbSurfaceM = ogr.wkbSurfaceM  # noqa: N815
    wkbPolyhedralSurfaceM = ogr.wkbPolyhedralSurfaceM  # noqa: N815
    wkbTINM = ogr.wkbTINM  # noqa: N815
    wkbTriangleM = ogr.wkbTriangleM  # noqa: N815
    wkbPointZM = ogr.wkbPointZM  # noqa: N815
    wkbLineStringZM = ogr.wkbLineStringZM  # noqa: N815
    wkbPolygonZM = ogr.wkbPolygonZM  # noqa: N815
    wkbMultiPointZM = ogr.wkbMultiPointZM  # noqa: N815
    wkbMultiLineStringZM = ogr.wkbMultiLineStringZM  # noqa: N815
    wkbMultiPolygonZM = ogr.wkbMultiPolygonZM  # noqa: N815
    wkbGeometryCollectionZM = ogr.wkbGeometryCollectionZM  # noqa: N815
    wkbCircularStringZM = ogr.wkbCircularStringZM  # noqa: N815
    wkbCompoundCurveZM = ogr.wkbCompoundCurveZM  # noqa: N815
    wkbCurvePolygonZM = ogr.wkbCurvePolygonZM  # noqa: N815
    wkbMultiCurveZM = ogr.wkbMultiCurveZM  # noqa: N815
    wkbMultiSurfaceZM = ogr.wkbMultiSurfaceZM  # noqa: N815
    wkbCurveZM = ogr.wkbCurveZM  # noqa: N815
    wkbSurfaceZM = ogr.wkbSurfaceZM  # noqa: N815
    wkbPolyhedralSurfaceZM = ogr.wkbPolyhedralSurfaceZM  # noqa: N815
    wkbTINZM = ogr.wkbTINZM  # noqa: N815
    wkbTriangleZM = ogr.wkbTriangleZM  # noqa: N815
    wkbPoint25D = ogr.wkbPoint25D  # noqa: N815
    wkbLineString25D = ogr.wkbLineString25D  # noqa: N815
    wkbPolygon25D = ogr.wkbPolygon25D  # noqa: N815
    wkbMultiPoint25D = ogr.wkbMultiPoint25D  # noqa: N815
    wkbMultiLineString25D = ogr.wkbMultiLineString25D  # noqa: N815
    wkbMultiPolygon25D = ogr.wkbMultiPolygon25D  # noqa: N815
    wkbGeometryCollection25D = ogr.wkbGeometryCollection25D  # noqa: N815

    def __init__(self, filename):
        """Initializes the class by opening the vector file read-only and computing its combined XY extent.

        Args:
            filename: Path of the vector file to open.

        Raises:
            ValueError: If GDAL cannot open the file as a vector dataset.
        """
        # NOTE(review): this is a process-wide GeoTIFF config option; confirm it is needed for vector reading.
        gdal.SetConfigOption('GTIFF_REPORT_COMPD_CS', 'TRUE')
        self._vec_ds = None
        # Turn GDAL exceptions on so an open failure surfaces as an exception. The caller's setting is restored
        # in `finally`, so it is not left changed when the open fails.
        gdal_use_exceptions = gdal.GetUseExceptions()
        gdal.UseExceptions()
        try:
            self._vec_ds = gdal.OpenEx(filename, gdal.OF_VECTOR | gdal.OF_READONLY)
        except Exception as err:
            raise ValueError(f'Unable to open vector file: {filename}') from err
        finally:
            if gdal_use_exceptions:
                gdal.UseExceptions()
            else:
                gdal.DontUseExceptions()
        # Combined XY extent over all layers; stays at +/- float max when the dataset has no layers.
        self._min_x = sys.float_info.max
        self._min_y = sys.float_info.max
        self._max_x = -sys.float_info.max
        self._max_y = -sys.float_info.max
        if self._dataset_is_valid():
            for layer_index in range(self._vec_ds.GetLayerCount()):
                layer = self._vec_ds.GetLayerByIndex(layer_index)
                # GetExtent() returns (min_x, max_x, min_y, max_y).
                extents = layer.GetExtent()
                self._min_x = min(self._min_x, extents[0])
                self._min_y = min(self._min_y, extents[2])
                self._max_x = max(self._max_x, extents[1])
                self._max_y = max(self._max_y, extents[3])

    @property
    def wkt(self):
        """Returns the well-known text of the first layer's spatial reference, or '' when unavailable."""
        if not self._dataset_is_valid():
            return ''
        spatial_ref = self._vec_ds.GetLayerByIndex(0).GetSpatialRef()
        return spatial_ref.ExportToWkt() if spatial_ref is not None else ''

    @property
    def layer_type(self):
        """Returns the geometry type of the first layer, or ogr.wkbNone when the dataset has no layers."""
        if self._dataset_is_valid():
            return self._vec_ds.GetLayerByIndex(0).GetGeomType()
        return ogr.wkbNone

    def _dataset_is_valid(self):
        """Returns whether the dataset was opened and has at least one layer."""
        return self._vec_ds is not None and self._vec_ds.GetLayerCount() > 0

    def get_point_features(self):
        """Reads the point features from the vector dataset.

        Returns:
            (list): List of tuples containing points in the vector data file.
        """
        if self._dataset_is_valid():
            return get_point_features(self._vec_ds)
        return []

    def get_line_features(self):
        """Reads the line features from the vector dataset.

        Returns:
            (list): 2D list of tuples containing points making up all the lines in the vector data file.
        """
        if self._dataset_is_valid():
            return get_line_features(self._vec_ds)
        return []

    def get_poly_features(self):
        """Reads the polygon features from the vector dataset.

        Returns:
            (list): 3D list of tuples containing points making up all the polygons in the vector data file.
        """
        if self._dataset_is_valid():
            return get_poly_features(self._vec_ds)
        return []

    def get_attributes(self):
        """Reads the attributes from the vector dataset.

        Each layer is read once, collecting every field of each feature in the same pass.

        Returns:
            (dict): Maps each lower-cased field name to the list of that field's values, one per feature. Fields
            whose names collide after lower-casing keep the values of the last such field. A name that also
            appears in an earlier layer is replaced by the later layer's values.
        """
        attributes = {}
        if not self._dataset_is_valid():
            return attributes
        for layer_index in range(self._vec_ds.GetLayerCount()):
            layer = self._vec_ds.GetLayerByIndex(layer_index)
            layer.ResetReading()
            layer_def = layer.GetLayerDefn()
            # Last field wins when two names lower-case to the same key.
            field_indices = {layer_def.GetFieldDefn(i).name.lower(): i for i in range(layer_def.GetFieldCount())}
            layer_values = {name: [] for name in field_indices}
            for feature in layer:
                for name, index in field_indices.items():
                    layer_values[name].append(feature.GetField(index))
            # update() keeps the position of keys already present, matching per-key reassignment.
            attributes.update(layer_values)
        return attributes

    def get_fields(self, layer_idx: int = 0) -> list[FieldDefn]:
        """Returns a list of the field definitions.

        Args:
            layer_idx: 0-based index of the layer to get the fields from, defaulted to 0.

        Returns:
            (list): A list of FieldDefn; empty if the dataset is invalid or the index is out of range.
        """
        fields = []
        if self._dataset_is_valid() and 0 <= layer_idx < self._vec_ds.GetLayerCount():
            layer = self._vec_ds.GetLayerByIndex(layer_idx)
            layer.ResetReading()
            layer_def = layer.GetLayerDefn()
            fields = [layer_def.GetFieldDefn(i) for i in range(layer_def.GetFieldCount())]
        return fields
