"""Module for the simulation exporter."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"
__all__ = ['export_simulation']

# 1. Standard Python modules
from itertools import pairwise
from pathlib import Path
from typing import Callable, cast, TypeAlias

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint import RectilinearGrid2d
from xms.data_objects.parameters import Arc, UGrid
from xms.gmi.data.generic_model import Section
from xms.guipy.dialogs.feedback_thread import ExpectedError
from xms.snap import SnapInteriorArc

# 4. Local modules

# An arc paired with the parameter section that came with it, before it has been snapped to the grid.
UnsnappedArc: TypeAlias = tuple[Arc, Section]
# The node indexes an arc snapped to, paired with that arc's original parameter section (produced by snap_arcs).
SnappedArc: TypeAlias = tuple[list[int], Section]


def export_simulation(
    model_control: Section, process_control: Section, grid: RectilinearGrid2d, arcs: list[UnsnappedArc],
    report_progress: Callable[[str], None], where: Path
):
    """
    Export the simulation into a directory.

    Writes the control file and the grid geometry files, then snaps the arcs to the grid. For each active process
    group (sand fences, waves, tides) it writes the mask grid and, where the process needs one, the table file.

    Args:
        model_control: The simulation's model control parameters.
        process_control: The process control parameters.
        grid: The grid geometry to export.
        arcs: The arcs and their parameters to export.
        report_progress: Function to use for reporting progress.
        where: Directory to export the files into.

    Raises:
        ExpectedError: If the grid cells are not all square, or if a mask or table file cannot be written.
    """
    def output_target(group, parameter_name):
        """Get the output path (inside `where`) and the label for a file-name parameter of a group."""
        parameter = group.parameter(parameter_name)
        return where / parameter.value, parameter.label

    report_progress('Checking grid cells are square...')
    if not all_cells_square(grid):
        raise ExpectedError('Grid cells must be square.')

    report_progress('Writing model control file...')
    write_control_file(model_control, process_control, where / 'aeolis.txt')

    report_progress('Writing geometry files...')
    grid_files = model_control.group('grid_files')
    # An empty file-name parameter falls back to the default name.
    x_path, y_path, z_path = (
        where / (grid_files.parameter(parameter_name).value or default_name)
        for parameter_name, default_name in (('xgrid_file', 'x.txt'), ('ygrid_file', 'y.txt'), ('bed_file', 'z.txt'))
    )
    write_grid(grid, x_path, y_path, z_path)

    report_progress('Snapping arcs...')
    snapped_arcs = snap_arcs(arcs, grid)

    fences = process_control.group('process_fences')
    if fences.is_active:
        report_progress('Writing sand fences...')
        fence_path, fence_label = output_target(fences, 'fence_file')
        write_mask_grid(grid, snapped_arcs, 'sand_fence', fence_path, fence_label)

    waves = process_control.group('process_wave')
    if waves.is_active:
        report_progress('Writing wave boundary conditions...')
        wave_mask_path, wave_mask_label = output_target(waves, 'wave_mask')
        write_mask_grid(grid, snapped_arcs, 'wave', wave_mask_path, wave_mask_label)

        wave_table_path, wave_table_label = output_target(waves, 'wave_file')
        write_table(waves.parameter('wave_table').value, wave_table_label, wave_table_path)

    tides = process_control.group('process_tide')
    if tides.is_active:
        report_progress('Writing tide boundary conditions...')
        tide_mask_path, tide_mask_label = output_target(tides, 'tide_mask')
        write_mask_grid(grid, snapped_arcs, 'tide', tide_mask_path, tide_mask_label)

        tide_table_path, tide_table_label = output_target(tides, 'tide_file')
        write_xy_table(tides.parameter('tide_series').value, tide_table_label, tide_table_path)

    report_progress('Done!')


def all_cells_square(grid: RectilinearGrid2d) -> bool:
    """
    Check that all the grid cells are square.

    A grid passes when its x locations are evenly spaced, its y locations are evenly spaced, and the x spacing equals
    the y spacing. All comparisons use a tolerance relative to the spacing (1e-10 times it), so tiny floating point
    differences are ignored.

    Args:
        grid: The grid to check.

    Returns:
        Whether all the cells are square. A grid with fewer than two locations along either axis, or with a
        (near-)zero spacing, is reported as not square rather than raising.
    """
    x_length = _uniform_spacing(grid.locations_x)
    if x_length is None:
        return False

    y_length = _uniform_spacing(grid.locations_y)
    if y_length is None:
        return False

    # Use the larger spacing so the x-vs-y comparison does not depend on which axis is checked first.
    tolerance = max(x_length, y_length) * 1e-10
    return abs(x_length - y_length) <= tolerance


def _uniform_spacing(locations) -> float | None:
    """
    Get the common spacing of a sequence of evenly spaced locations.

    Args:
        locations: Indexable sequence of coordinates along one axis.

    Returns:
        The distance between the first two locations if every consecutive pair is that far apart (within a relative
        tolerance of 1e-10). None if there are fewer than two locations, the spacing is below 1e-10, or the spacing
        is not uniform.
    """
    if len(locations) < 2:
        return None

    spacing = abs(locations[0] - locations[1])
    if spacing < 1e-10:
        return None

    tolerance = spacing * 1e-10
    if any(abs(abs(first - second) - spacing) > tolerance for first, second in pairwise(locations)):
        return None

    return spacing


def write_control_file(model_control: Section, process_control: Section, where: Path):
    """
    Write the model control file.

    Every group of both sections is written as a ``% <label>`` heading, then one ``name = value`` line per parameter,
    then a blank separator line. Model control groups come first, followed by process control groups. Each process
    control group also gets a ``<group name> = <is_active>`` line right after its heading. Its ``tide_series`` and
    ``wave_table`` parameters are left out.

    Args:
        model_control: Global parameters and values.
        process_control: Process parameters and values.
        where: Path of the control file to write.
    """
    # Process parameters holding table data; export_simulation writes these to their own files instead.
    table_parameters = {'tide_series', 'wave_table'}

    def write_groups(out, section, is_process):
        for group_name in section.group_names:
            group = section.group(group_name)
            out.write(f'% {group.label}\n')
            if is_process:
                # NOTE(review): this writes str(bool), i.e. 'True'/'False'; confirm the model reads that format.
                out.write(f'{group_name} = {group.is_active}\n')
            for parameter_name in group.parameter_names:
                if is_process and parameter_name in table_parameters:
                    continue
                out.write(f'{parameter_name} = {group.parameter(parameter_name).value}\n')
            out.write('\n')

    with open(where, 'w') as control_file:
        write_groups(control_file, model_control, is_process=False)
        write_groups(control_file, process_control, is_process=True)


def write_grid(grid: RectilinearGrid2d, x_file: Path, y_file: Path, z_file: Path):
    """
    Write out the main grid.

    Each of the three files holds an array with one row per y location and one column per x location.
    - The x file repeats the x locations on every row.
    - The y file holds a single y location across each row.
    - The elevation file holds the grid's point elevations, reshaped to that layout, with the row order reversed.

    Args:
        grid: The grid to write.
        x_file: Path to write the grid's x coordinates.
        y_file: Path to write the grid's y coordinates.
        z_file: Path to write the grid's elevation data.
    """
    x_coords = np.array(grid.locations_x, dtype=float)
    y_coords = np.array(grid.locations_y, dtype=float)

    # Default 'xy' indexing: both tables have shape (len(y_coords), len(x_coords)).
    x_table, y_table = np.meshgrid(x_coords, y_coords)
    np.savetxt(x_file, x_table)
    np.savetxt(y_file, y_table)

    # NOTE(review): only the elevations are flipped, not the y table. This presumably means the grid's points are
    # ordered starting from the last y row. Confirm against the grid's point ordering.
    elevations = np.array(grid.point_elevations, dtype=float).reshape(x_table.shape)
    np.savetxt(z_file, elevations[::-1])


def snap_arcs(arcs: list[UnsnappedArc], grid: RectilinearGrid2d) -> list[SnappedArc]:
    """
    Snap every arc to a grid.

    Args:
        arcs: Arcs paired with their parameter sections.
        grid: Grid to snap the arcs to.

    Returns:
        One ``(node_indexes, section)`` tuple per input arc, in input order. ``node_indexes`` is the ``'id'`` entry
        of the snapper's result for that arc. ``section`` is the section that accompanied the original unsnapped arc.
    """
    snapper = SnapInteriorArc()
    # target_cells=False presumably selects grid nodes rather than cells as snap targets; confirm in xms.snap.
    snapper.set_grid(cast(UGrid, grid), target_cells=False)
    return [(snapper.get_snapped_points(arc)['id'], section) for arc, section in arcs]


def write_mask_grid(grid: RectilinearGrid2d, arcs: list[SnappedArc], arc_type: str, where: Path, label: str):
    """
    Write out a mask grid.

    Every grid point that an arc of the requested type snapped to is written as 1; every other point is written as 0.
    The mask uses one row per y location and one column per x location, with the row order reversed. This is the same
    layout write_grid uses for the elevations.

    Args:
        grid: The grid to write a mask for.
        arcs: All the arcs that may potentially be written to the file.
        arc_type: Name of the group in the arcs' sections that is active if the arc should be written.
        where: Path to write the mask to.
        label: Label to include in error messages identifying the parameter defining the file name.

    Raises:
        ExpectedError: If the mask file cannot be created or written.
    """
    x_len = len(grid.locations_x)
    y_len = len(grid.locations_y)

    # Build the integer mask directly; there is no need for a float array followed by astype(int).
    mask = np.zeros(x_len * y_len, dtype=int)
    for nodes, section in arcs:
        # Only arcs whose section has the requested group switched on contribute to the mask.
        if section.group(arc_type).is_active:
            mask[list(nodes)] = 1

    rotated_mask_grid = np.flipud(mask.reshape((y_len, x_len)))

    try:
        np.savetxt(where, rotated_mask_grid, fmt='%i')
    except IOError as error:
        raise ExpectedError(f'Unable to create or write to file specified in {label} parameter.') from error


def write_xy_table(table, label: str, where: str):
    np_table = np.array(table, dtype=float)
    np_table = np_table.transpose()

    try:
        np.savetxt(where, np_table, fmt='%f')
    except IOError:
        raise ExpectedError(f'Unable to create or write to file specified in {label} parameter.')


def write_table(table, label: str, where: str):
    np_table = np.array(table, dtype=float)

    try:
        np.savetxt(where, np_table, fmt='%f')
    except IOError:
        raise ExpectedError(f'Unable to create or write to file specified in {label} parameter.')
