"""Module for converting high-level sources to the low-level writer's format and running the low-level writer."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"
__all__ = ['write_sources']

# 1. Standard Python modules
import itertools
import logging
from pathlib import Path
from typing import TextIO

# 2. Third party modules

# 3. Aquaveo modules
from xms.coverage.polygons.polygon_orienteer import get_poly_points, make_poly_counter_clockwise
from xms.data_objects.parameters import Arc as DoArc, Coverage, FilterLocation, Point as DoPoint, Polygon as DoPolygon
from xms.gmi.data.generic_model import Section
from xms.gmi.data_bases.coverage_base_data import CoverageBaseData
from xms.guipy.data.target_type import TargetType
from xms.guipy.time_format import string_to_datetime
from xms.ptmio.source.source_writer import write_sources as pio_write
from xms.ptmio.source.sources import (
    InstantMassInstruction, InstantMassSource, LineMassDatum, LineMassInstruction, LineMassSource, PointMassInstruction,
    PointMassSource, PolygonMassInstruction, PolygonMassSource, Sources as PioSources
)

# 4. Local modules
from xms.ptm.model.model import source_model

# Logger for the PTM interface; the _check_* helpers in this module report validation warnings through it.
logger = logging.getLogger('xms.ptm')


def write_sources(coverage: Coverage, data: CoverageBaseData, end_time: str, where: str | Path | TextIO):
    """
    Convert every source in a coverage to the low-level format and write them out.

    Points (instants, point sources, and vertical lines) are converted first, then arcs (horizontal lines), then
    polygons. Everything is collected into one low-level sources object and handed to the low-level writer.

    Args:
        coverage: Coverage containing geometry to write.
        data: Data manager with the attributes for the geometry.
        end_time: The simulation's end time, in the format used by the instruction dates.
        where: Where to write them.
    """
    collected = PioSources()

    for convert in (_convert_points, _convert_h_lines, _convert_polygons):
        convert(coverage, data, end_time, collected)

    pio_write(collected, where)


def _convert_points(coverage: Coverage, data: CoverageBaseData, end_time: str, sources: PioSources):
    """
    Convert every disjoint point in the coverage to the low-level writer's format.

    Groups are checked in the order instant, point, vertical line. Only the first active one is converted.
    Points with no active group are skipped.

    Args:
        coverage: Coverage containing the points.
        data: Data manager holding each point's attribute values.
        end_time: The simulation's end time, used to extend instruction tables.
        sources: Collection the converted sources are appended to.
    """
    disjoint_points = coverage.get_points(FilterLocation.PT_LOC_DISJOINT)
    section = source_model().point_parameters

    for do_point in disjoint_points:
        # One section object is reused; load this point's values into it before inspecting it.
        section.restore_values(data.feature_values(TargetType.point, feature_id=do_point.id))

        if section.group('instant').is_active:
            _fix_instant(section, end_time)
            _convert_instant(section, sources)
            continue

        if section.group('point').is_active:
            _fix_point(section, end_time)
            _convert_point(section, sources)
            continue

        if section.group('vertical_line').is_active:
            _convert_v_lines(do_point, section, end_time, sources)
            continue

        # None of the known groups matched, so nothing at all should be active.
        assert not section.active_group_names


def _fix_instant(attributes: Section, end_time: str):
    """
    Automatically repair an instant's instruction table.

    Users may enter the rows in any order, so the rows are sorted first. Column 0 is the date, so this gives date
    order, assuming the date strings sort chronologically. If needed, a copy of the final row is then appended so
    the table reaches the simulation's end time.

    Args:
        attributes: Point attributes whose 'instant' group is active.
        end_time: The simulation's end time.
    """
    group = attributes.group('instant')
    rows = group.parameter('instructions').value
    rows.sort()
    _patch_end_time(rows, end_time)
    group.parameter('instructions').value = rows


def _convert_instant(attributes: Section, sources: PioSources):
    """
    Convert an instant and its attributes to the low-level writer's format.

    Args:
        attributes: Point attributes whose 'instant' group is active.
        sources: Collection the new instant source is appended to.
    """
    group = attributes.group('instant')
    assert group.is_active

    source_id = group.parameter('id').value
    label = group.parameter('label').value
    rows = group.parameter('instructions').value

    source = InstantMassSource(source_id=source_id, label=label)

    # Row layout: date, x, y, z, mass, h radius, v radius, rate, grain size, stdev, density, velocity,
    # initiation, deposition.
    for row in rows:
        instruction = InstantMassInstruction(
            time=string_to_datetime(row[0]),
            location=(row[1], row[2], row[3]),
            parcel_mass=row[4],
            h_radius=row[5],
            v_radius=row[6],
            mass_rate=row[7],
            grain_size=row[8],
            stdev=row[9],
            density=row[10],
            velocity=row[11],
            initiation=row[12],
            deposition=row[13],
        )
        source.instructions.append(instruction)

    sources.instant_sources.append(source)


def _fix_point(attributes: Section, end_time: str):
    """
    Automatically repair a point source's instruction table.

    Users may enter the rows in any order, so the rows are sorted first. Column 0 is the date, so this gives date
    order, assuming the date strings sort chronologically. If needed, a copy of the final row is then appended so
    the table reaches the simulation's end time.

    Args:
        attributes: Point attributes whose 'point' group is active.
        end_time: The simulation's end time.
    """
    group = attributes.group('point')
    rows = group.parameter('instructions').value
    rows.sort()
    _patch_end_time(rows, end_time)
    group.parameter('instructions').value = rows


def _convert_point(attributes: Section, sources: PioSources):
    """
    Convert a point source and its attributes to the low-level writer's format.

    Args:
        attributes: Point attributes whose 'point' group is active.
        sources: Collection the new point source is appended to.
    """
    group = attributes.group('point')
    assert group.is_active

    source_id = group.parameter('id').value
    label = group.parameter('label').value
    rows = group.parameter('instructions').value
    rows.sort()  # Rows may be in any order; column 0 is the date, so sorting puts them in date order.

    source = PointMassSource(source_id=source_id, label=label)

    # Row layout: date, x, y, z, mass, h radius, v radius, rate, grain size, stdev, density, velocity,
    # initiation, deposition.
    for row in rows:
        instruction = PointMassInstruction(
            time=string_to_datetime(row[0]),
            location=(row[1], row[2], row[3]),
            parcel_mass=row[4],
            h_radius=row[5],
            v_radius=row[6],
            mass_rate=row[7],
            grain_size=row[8],
            stdev=row[9],
            density=row[10],
            velocity=row[11],
            initiation=row[12],
            deposition=row[13],
        )
        source.instructions.append(instruction)

    sources.point_sources.append(source)


def _convert_v_lines(point: DoPoint, attributes: Section, end_time: str, sources: PioSources):
    """
    Convert all the vertical lines defined at one point to the low-level writer's format.

    One point's instruction table can describe several vertical line segments, and column 1 holds the segment ID.
    The rows are split by segment. Each segment is then fixed, checked, and converted as its own source, using a
    copy of the point's attributes that holds only that segment's rows.

    Args:
        point: The point the vertical lines are attached to.
        attributes: Point attributes whose 'vertical_line' group is active.
        end_time: The simulation's end time.
        sources: Collection the new line sources are appended to.
    """
    rows = attributes.group('vertical_line').parameter('instructions').value

    # groupby only merges adjacent rows, so order by segment ID (then by the date in column 3) first.
    rows.sort(key=lambda r: (r[1], r[3]))
    for _segment_id, segment_rows in itertools.groupby(rows, key=lambda r: r[1]):
        segment_section = attributes.copy()
        segment_section.group('vertical_line').parameter('instructions').value = list(segment_rows)
        _fix_v_line(segment_section, end_time)
        _check_v_line(point, segment_section)
        _convert_v_line(segment_section, sources)


def _fix_v_line(attributes: Section, end_time: str):
    """
    Automatically repair one vertical line segment's instruction table.

    Users may enter the rows in any order, so the rows are put in date order. If needed, a copy of the final row is
    then appended so the table reaches the simulation's end time.

    In this table the date is not the first column: column 0 is the name and column 3 is the date. The rows must
    therefore be sorted on the date column explicitly. A plain row sort would order them by name first, breaking
    date order whenever a segment's rows use different names. The end-time patch would then extend the wrong row,
    and the name taken from the first row would not be the earliest instruction's.

    As elsewhere in this module, dates are compared as strings, which assumes the format sorts chronologically.

    Args:
        attributes: Point attributes whose 'vertical_line' group is active and holds a single segment's rows.
        end_time: The simulation's end time.
    """
    table = attributes.group('vertical_line').parameter('instructions').value
    # list.sort is stable, so rows with equal dates keep their existing relative order.
    table.sort(key=lambda row: row[3])
    _patch_end_time(table, end_time, time_column=3)
    attributes.group('vertical_line').parameter('instructions').value = table


def _check_v_line(point: DoPoint, attributes: Section):
    """
    Warn about problems with one vertical line segment that cannot be fixed automatically.

    Two things are checked:

    - whether the first instruction's x/y (columns 4 and 5) match the point's location;
    - whether every instruction uses the same name (column 0).

    Args:
        point: The point the segment is attached to.
        attributes: Point attributes whose 'vertical_line' group holds a single segment's rows.
    """
    rows = attributes.group('vertical_line').parameter('instructions').value
    first = rows[0]
    first_x, first_y = first[4], first[5]
    segment_id = first[1]

    if first_x != point.x or first_y != point.y:
        logger.warning(
            f"The location of point {point.id} does not match the first instruction for segment {segment_id}. "
            "The point's visual location may not match what PTM sees."
        )

    distinct_names = {row[0] for row in rows}
    if len(distinct_names) > 1:
        logger.warning(
            f'Instructions for vertical line segment {segment_id} on point {point.id} use multiple names. The name '
            'on the instruction with the earliest date will be written to the .source file.'
        )


def _convert_v_line(attributes: Section, sources: PioSources):
    """
    Convert one vertical line segment and its attributes to the low-level writer's format.

    The name, segment ID, and datum are read from the first row of the table. Every row then becomes an instruction
    for a vertical line from (x, y, bottom) to (x, y, top).

    Args:
        attributes: Point attributes whose 'vertical_line' group is active and holds a single segment's rows.
        sources: Collection the new line source is appended to.

    Raises:
        KeyError: If the first row's datum is not 'Bed Datum', 'Surface Datum', or 'Depth Distributed'.
    """
    group = attributes.group('vertical_line')
    assert group.is_active

    rows: list[list] = group.parameter('instructions').value

    datums = {
        'Bed Datum': LineMassDatum.bed_datum,
        'Surface Datum': LineMassDatum.surface_datum,
        'Depth Distributed': LineMassDatum.depth_distributed,
    }
    first = rows[0]
    source = LineMassSource(source_id=first[1], label=first[0], datum=datums[first[2]])

    # Row layout: name, segment ID, datum, date, x, y, bottom, top, mass, radius, rate, grain size, stdev, density,
    # velocity, initiation, deposition.
    for row in rows:
        when = string_to_datetime(row[3])
        x, y = row[4], row[5]
        radius = row[9]
        instruction = LineMassInstruction(
            time=when,
            start=(x, y, row[6]),
            end=(x, y, row[7]),
            parcel_mass=row[8],
            h_radius=radius,
            v_radius=radius,  # A single radius is used in both directions.
            mass_rate=row[10],
            grain_size=row[11],
            stdev=row[12],
            density=row[13],
            velocity=row[14],
            initiation=row[15],
            deposition=row[16],
        )
        source.instructions.append(instruction)

    sources.line_sources.append(source)


def _convert_h_lines(coverage: Coverage, data: CoverageBaseData, end_time: str, sources: PioSources):
    """
    Convert every arc in the coverage to a horizontal line source in the low-level writer's format.

    Arcs whose 'horizontal_line' group is inactive are skipped.

    Args:
        coverage: Coverage containing the arcs.
        data: Data manager holding each arc's attribute values.
        end_time: The simulation's end time, used to extend instruction tables.
        sources: Collection the converted sources are appended to.
    """
    do_arcs = coverage.arcs
    section = source_model().arc_parameters

    for do_arc in do_arcs:
        section.restore_values(data.feature_values(TargetType.arc, feature_id=do_arc.id))
        if not section.group('horizontal_line').is_active:
            # Nothing to write for this arc; no other group should be active either.
            assert not section.active_group_names
            continue

        _fix_h_line(section, end_time)
        _check_h_line(do_arc, section)
        _convert_h_line(section, sources)


def _fix_h_line(attributes: Section, end_time: str):
    """
    Automatically repair a horizontal line's instruction table.

    Users may enter the rows in any order, so the rows are sorted first. Column 0 is the date, so this gives date
    order, assuming the date strings sort chronologically. If needed, a copy of the final row is then appended so
    the table reaches the simulation's end time.

    Args:
        attributes: Arc attributes whose 'horizontal_line' group is active.
        end_time: The simulation's end time.
    """
    group = attributes.group('horizontal_line')
    rows = group.parameter('instructions').value
    rows.sort()
    _patch_end_time(rows, end_time)
    group.parameter('instructions').value = rows


def _check_h_line(arc: DoArc, attributes: Section):
    """
    Warn about problems with a horizontal line that cannot be fixed automatically.

    Two things are checked:

    - whether the arc has vertices, which the PTM interface does not support;
    - whether the arc's start/end node locations match the first instruction's endpoints and elevation.

    Args:
        arc: The arc the horizontal line is attached to.
        attributes: Arc attributes whose 'horizontal_line' group is active.
    """
    if arc.vertices:
        logger.warning(
            f'Arc {arc.id} has a vertex. The PTM interface does not support vertices. The vertex was discarded.'
        )

    table = attributes.group('horizontal_line').parameter('instructions').value
    table.sort()  # Column 0 is the date, so table[0] becomes the earliest instruction.

    # Columns 1-5 are x1, y1, x2, y2, and one elevation shared by both ends of the line.
    x1, y1, x2, y2, z = table[0][1:6]
    start = arc.start_node
    end = arc.end_node
    if x1 != start.x or y1 != start.y or z != start.z or x2 != end.x or y2 != end.y or z != end.z:
        # The two string pieces are concatenated, so the trailing space after "may" is required.
        logger.warning(
            f"Node locations on arc {arc.id} do not match the first instruction. The arc's visual location may "
            'not match what PTM sees.'
        )


def _convert_h_line(attributes: Section, sources: PioSources):
    """
    Convert a horizontal line and its attributes to the low-level writer's format.

    Each row becomes an instruction for a line from (x1, y1, z) to (x2, y2, z). The source is written with
    ``LineMassDatum.none``.

    Args:
        attributes: Arc attributes whose 'horizontal_line' group is active.
        sources: Collection the new line source is appended to.
    """
    group = attributes.group('horizontal_line')
    assert group.is_active

    source_id = group.parameter('id').value
    label = group.parameter('label').value
    rows = group.parameter('instructions').value
    rows.sort()  # Rows may be in any order; column 0 is the date, so sorting puts them in date order.

    source = LineMassSource(source_id=source_id, label=label, datum=LineMassDatum.none)

    for row in rows:
        when = string_to_datetime(row[0])
        x1, y1, x2, y2, z = row[1], row[2], row[3], row[4], row[5]
        mass, radius, rate, size = row[6], row[7], row[8], row[9]
        stdev, density, velocity = row[10], row[11], row[12]
        initiation, deposition = row[13], row[14]

        instruction = LineMassInstruction(
            time=when,
            start=(x1, y1, z),
            end=(x2, y2, z),
            parcel_mass=mass,
            h_radius=radius,
            v_radius=radius,  # A single radius is used in both directions.
            mass_rate=rate,
            grain_size=size,
            stdev=stdev,
            density=density,
            velocity=velocity,
            initiation=initiation,
            deposition=deposition,
        )
        source.instructions.append(instruction)

    sources.line_sources.append(source)


def _convert_polygons(coverage: Coverage, data: CoverageBaseData, end_time: str, sources: PioSources):
    """
    Convert every polygon in the coverage to the low-level writer's format.

    Polygons whose 'polygon' group is inactive are skipped.

    Args:
        coverage: Coverage containing the polygons.
        data: Data manager holding each polygon's attribute values.
        end_time: The simulation's end time, used to extend instruction tables.
        sources: Collection the converted sources are appended to.
    """
    do_polygons = coverage.polygons
    section = source_model().polygon_parameters

    for do_polygon in do_polygons:
        section.restore_values(data.feature_values(TargetType.polygon, feature_id=do_polygon.id))
        if not section.group('polygon').is_active:
            # Nothing to write for this polygon; no other group should be active either.
            assert not section.active_group_names
            continue

        _fix_polygon(section, end_time)
        _check_polygon(do_polygon, section)
        _convert_polygon(do_polygon, section, sources)


def _fix_polygon(attributes: Section, end_time: str):
    """
    Automatically repair a polygon's instruction table.

    Users may enter the rows in any order, so the rows are sorted first. Column 0 is the date, so this gives date
    order, assuming the date strings sort chronologically. If needed, a copy of the final row is then appended so
    the table reaches the simulation's end time.

    Args:
        attributes: Polygon attributes whose 'polygon' group is active.
        end_time: The simulation's end time.
    """
    group = attributes.group('polygon')
    rows = group.parameter('instructions').value
    rows.sort()
    _patch_end_time(rows, end_time)
    group.parameter('instructions').value = rows


def _check_polygon(polygon: DoPolygon, attributes: Section):
    """
    Warn about problems with a polygon that cannot be fixed automatically.

    Two things are checked:

    - whether the polygon has interior arcs (holes), which the PTM interface does not support;
    - whether every boundary point's elevation equals the first instruction's elevation (column 1).

    Args:
        polygon: The polygon to check.
        attributes: Polygon attributes whose 'polygon' group is active.
    """
    if polygon.interior_arcs:
        logger.warning(
            f'Polygon {polygon.id} has a hole. The PTM interface does not support polygons with holes. '
            'The hole was discarded.'
        )

    rows = attributes.group('polygon').parameter('instructions').value
    rows.sort()  # Column 0 is the date, so rows[0] becomes the earliest instruction.

    first_elevation = rows[0][1]
    boundary = get_poly_points(polygon.arcs, polygon.arc_directions)
    elevations_match = all(vertex[2] == first_elevation for vertex in boundary)
    if not elevations_match:
        logger.warning(
            f"Point elevations on polygon {polygon.id} do not match the first instruction. The polygon's visual "
            'location may not match what PTM sees.'
        )


def _convert_polygon(do_polygon: DoPolygon, attributes: Section, sources: PioSources):
    """
    Convert a polygon and its attributes to the low-level writer's format.

    The polygon's outer boundary is made counter-clockwise and its repeated closing point is dropped. Every
    instruction uses that outline, with each point's z replaced by the instruction's elevation.

    Args:
        do_polygon: The polygon whose outer boundary defines the source's shape.
        attributes: Polygon attributes whose 'polygon' group is active.
        sources: Collection the new polygon source is appended to.
    """
    group = attributes.group('polygon')
    assert group.is_active

    source_id = group.parameter('id').value
    label = group.parameter('label').value
    rows = group.parameter('instructions').value
    rows.sort()  # Rows may be in any order; column 0 is the date, so sorting puts them in date order.

    source = PolygonMassSource(source_id=source_id, label=label)

    outline = get_poly_points(do_polygon.arcs, do_polygon.arc_directions, inner=False)
    make_poly_counter_clockwise(outline)
    outline.pop()  # The last point duplicates the first; drop it.

    for row in rows:
        when = string_to_datetime(row[0])
        elevation, mass, radius, rate = row[1], row[2], row[3], row[4]
        size, stdev, density, velocity = row[5], row[6], row[7], row[8]
        initiation, deposition = row[9], row[10]

        flattened = [(px, py, elevation) for (px, py, _pz) in outline]

        instruction = PolygonMassInstruction(
            time=when,
            points=flattened,
            parcel_mass=mass,
            h_radius=radius,
            v_radius=radius,  # A single radius is used in both directions.
            mass_rate=rate,
            grain_size=size,
            stdev=stdev,
            density=density,
            velocity=velocity,
            initiation=initiation,
            deposition=deposition,
        )
        source.instructions.append(instruction)

    sources.polygon_sources.append(source)


def _patch_end_time(table: list, end_time: str, time_column: int = 0):
    """Ensure the last instruction is at or after the end of the simulation."""
    if table[-1][time_column] < end_time:
        extra_row = table[-1][:]
        extra_row[time_column] = end_time
        table.append(extra_row)
