"""Writer for PTM program control files."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"
__all__ = ['write_control']

# 1. Standard Python modules
from datetime import datetime
from pathlib import Path
from typing import Any, Optional, TextIO

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules
from xms.ptmio.pcf.program_control import (
    AdvectionMethod, CentroidMethod, EulerianMethod, EulerianTransportMethod, FlowFormat, MeshFormat, NumericalScheme,
    ProgramControl, SedimentFormat, VelocityMethod
)


def write_control(control: ProgramControl, where: str | Path | TextIO):
    """
    Write a program control to a file path or to an already-open text stream.

    The output is a banner, then the Time, Files, Computation Methods, Computation Parameters,
    Particle Output, Mapping Output, and File Output blocks (in that order), then a closing
    ``:END_DATA`` line.

    Args:
        control: The program control to write.
        where: A path (``str`` or ``Path``) opened in ``'w'`` mode, or a writable text stream.
    """
    if isinstance(where, (str, Path)):
        # Open the path here and recurse with the resulting stream.
        with open(where, 'w') as stream:
            write_control(control, stream)
        return

    _write_file_header(where)

    block_writers = (
        _write_time_block,
        _write_files_block,
        _write_computation_methods_block,
        _write_computation_parameters_block,
        _write_particle_output_block,
        _write_mapping_output_block,
        _write_file_output_block,
    )
    for write_block in block_writers:
        write_block(control, where)

    where.write(':END_DATA\n')


def _write_file_header(where: TextIO):
    """
    Write the banner comment that opens the control file.

    Args:
        where: Where to write the output.
    """
    banner = '##########################################################'
    where.write(f'{banner}\n# PTM Simulation file written by SMS\n{banner}\n')


def _write_block_header(name: str, where: TextIO):
    """
    Write a block header.

    The header is framed by lone ``#`` lines. The title line is ``# `` followed by the name and a
    space, padded with dashes to 56 characters.

    Args:
        name: The block's display name.
        where: Where to write the output.
    """
    title = (name + ' ').ljust(56, '-')
    where.write(f'#\n# {title}\n#\n')


def _write_time_block(control: ProgramControl, where: TextIO):
    """
    Write the Time block.

    The start time is always written. The end of the run is given either by an explicit stop
    time or, when there is none, by a positive duration. If neither applies, no end card is
    written. The step cards follow.
    """
    _write_block_header('Time', where)
    _write_datetime('start_run', control.start_run, where)

    if control.stop_run is None:
        if control.duration > 0.0:
            _write_card('duration', control.duration, where)
    else:
        _write_datetime('stop_run', control.stop_run, where)

    # These cards share their names with the ProgramControl attributes they come from.
    for card in ('time_step', 'grid_update', 'flow_update'):
        _write_card(card, getattr(control, card), where)


def _write_files_block(control: ProgramControl, where: TextIO):
    """
    Write the Files block.

    The grain-size dataset paths (D35/D50/D90) are written only when the sediment format is
    an XMDF dataset. The flow start time is written only when one is set.
    """
    def card(name: str):
        # Write the card `name` using the ProgramControl attribute of the same name.
        _write_card(name, getattr(control, name), where)

    _write_block_header('Files', where)

    card('flow_file_xmdf')
    _write_flow_format(control.flow_format, where)
    card('xmdf_vel_path')
    card('xmdf_wse_path')
    if control.start_flow is not None:
        _write_datetime('start_flow', control.start_flow, where)

    card('bc_file')
    card('mesh_file')
    _write_mesh_format(control.mesh_format, where)
    card('neighbor_file')

    _write_sediment_format(control.sediment_format, where)
    card('sediment_file')
    if control.sediment_format == SedimentFormat.xmdf_dataset:
        for grain_path in ('xmdf_d35_path', 'xmdf_d50_path', 'xmdf_d90_path'):
            card(grain_path)

    card('source_file')
    card('trap_file')
    _write_datetime('start_trap', control.start_trap, where)
    _write_datetime('stop_trap', control.stop_trap, where)
    card('wave_step')
    _write_datetime('start_waves', control.start_waves, where)
    card('output_prefix')


def _write_computation_methods_block(control: ProgramControl, where: TextIO):
    """
    Write the Computation Methods block.

    Each entry pairs a ProgramControl attribute with the writer that emits its card. Entries are
    written in table order.
    """
    _write_block_header('Computation Methods', where)

    entries = (
        ('advection_method', _write_advection_method),
        ('centroid_method', _write_centroid_method),
        ('by_weight', lambda value, stream: _write_bool('by_weight', value, stream)),
        ('eulerian_method', _write_eulerian_method),
        ('velocity_method', _write_velocity_method),
        ('eulerian_transport_method', _write_eulerian_transport_method),
        ('numerical_scheme', _write_numerical_scheme),
    )
    for attribute, writer in entries:
        writer(getattr(control, attribute), where)


def _write_computation_parameters_block(control: ProgramControl, where: TextIO):
    """
    Write the Computation Parameters block.

    Numeric parameters come first, followed by the boolean switches. Some switches are written
    in negated ``no_...`` form: the card is emitted only when the corresponding ProgramControl
    attribute is False.
    """
    _write_block_header('Computation Parameters', where)

    numeric = ('bed_porosity', 'rhos', 'min_depth', 'temperature', 'ket', 'salinity', 'kew', 'kev', 'etmin', 'evmin')
    for name in numeric:
        _write_card(name, getattr(control, name), where)

    for name in ('currents', 'morphology', 'last_step_trap', 'neutrally_buoyant'):
        _write_bool(name, getattr(control, name), where)

    negated = (
        ('no_bedforms', 'bedforms'),
        ('no_hiding_exposure', 'hiding_exposure'),
        ('no_bed_interaction', 'bed_interaction'),
        ('no_turbulent_shear', 'turbulent_shear'),
    )
    for card, name in negated:
        _write_bool(card, not getattr(control, name), where)

    for name in ('residence_calc', 'source_to_datum'):
        _write_bool(name, getattr(control, name), where)

    _write_bool('no_wave_mass_transport', not control.wave_mass_transport, where)


def _write_particle_output_block(control: ProgramControl, where: TextIO):
    """
    Write the Particle Output block.

    The elevation and ID output switches are always written as on, regardless of the control.
    Every other switch follows its ProgramControl attribute.
    """
    def switch(card: str, enabled: bool):
        _write_bool(card, enabled, where)

    _write_block_header('Particle Output', where)
    _write_card('output_inc', control.output_inc, where)

    switch('tau_cr_output', control.tau_cr_output)
    switch('elevation_output', True)
    switch('fall_velocity_output', control.fall_velocity_output)
    switch('grain_size_output', control.grain_size_output)
    switch('height_output', control.height_output)
    switch('id_output', True)
    switch('density_output', control.density_output)
    switch('parcel_mass_output', control.parcel_mass_output)
    switch('mobility_output', control.mobility_output)
    switch('paths', control.paths)
    switch('source_output', control.source_output)
    switch('state_output', control.state_output)
    switch('flow_output', control.flow_output)


def _write_mapping_output_block(control: ProgramControl, where: TextIO):
    """
    Write the Mapping Output block.

    Every switch card here shares its name with the ProgramControl attribute it comes from.
    """
    _write_block_header('Mapping Output', where)
    _write_card('mapping_inc', control.mapping_inc, where)

    switches = (
        'bedform_mapping',
        'bed_level_mapping',
        'bed_level_change_mapping',
        'flow_mapping',
        'mobility_mapping',
        'transport_mapping',
        'shear_stress_mapping',
        'wave_mapping',
    )
    for card in switches:
        _write_bool(card, getattr(control, card), where)


def _write_file_output_block(control: ProgramControl, where: TextIO):
    """
    Write the File Output block.

    Each switch card shares its name with the ProgramControl attribute it comes from.
    """
    _write_block_header('File Output', where)
    for card in ('tecplot_maps', 'tecplot_parcels', 'population_record', 'xmdf_compressed'):
        _write_bool(card, getattr(control, card), where)


def _write_card(card: str, value: Any, where: TextIO):
    """
    Write a card and its value on one line.

    The line is ``:`` followed by the uppercased card, left-justified to 29 columns (longer
    names are not truncated), then a space and ``str(value)``. Values of consecutive cards
    therefore line up.

    Args:
        card: The card to write. Will be UPPERCASED on write.
        value: The card's value. Must not be a bool; use _write_bool for switches.
        where: Where to write the output.
    """
    assert not isinstance(value, bool)
    keyword = card.upper().ljust(29)
    where.write(f':{keyword} {value}\n')


def _write_bool(card: str, value: bool, where: TextIO):
    """
    Write a switch card, which is emitted only when switched on.

    Args:
        card: The card to write. Will be UPPERCASED on write.
        value: Whether the switch is on. Must be a real bool.
        where: Where to write the output.
    """
    assert isinstance(value, bool)
    if value:
        where.write(':' + card.upper() + '\n')


def _write_datetime(card: str, value: Optional[datetime], where: TextIO):
    """
    Write a card whose value is a date and time.

    The value is written as six space-separated fields: the year (width 4), then month, day,
    hour, minute, and second (each space-padded to width 2).

    Args:
        card: The card to write. Will be UPPERCASED on write.
        value: The date and time to write. If None, nothing is written.
        where: Where to write the output.
    """
    if value is None:
        # The signature accepts None; skip the card instead of failing on `None.year`.
        return

    formatted = (
        f'{value.year:4d} {value.month:2d} {value.day:2d} '
        f'{value.hour:2d} {value.minute:2d} {value.second:2d}'
    )
    _write_card(card, formatted, where)


def _write_advection_method(value: AdvectionMethod, where: TextIO):
    """Write the ADVECTION_METHOD keyword."""
    if value == AdvectionMethod.one_d:
        write = '1D'
    elif value == AdvectionMethod.two_d:
        write = '2D'
    elif value == AdvectionMethod.three_d:
        write = '3D'
    elif value == AdvectionMethod.q_three_d:
        write = 'Q3D'
    else:
        raise AssertionError('Unsupported value for AdvectionMethod')

    _write_card('ADVECTION_METHOD', write, where)


# def _write_bottom_flow_format(value: BottomFlowFormat, where: TextIO):
#     """Write the BOTTOM_FLOW_FORMAT keyword."""
#     if value == BottomFlowFormat.adcirc:
#         write = 'ADCIRC'
#     elif value == BottomFlowFormat.xmdf:
#         write = 'XMDF'
#     else:
#         raise AssertionError('Unsupported value for BottomFlowFormat')
#
#     _write_card('BOTTOM_FLOW_FORMAT', write, where)


def _write_centroid_method(value: CentroidMethod, where: TextIO):
    """Write the CENTROID_METHOD keyword."""
    if value == CentroidMethod.rouse:
        write = 'ROUSE'
    elif value == CentroidMethod.van_rijn:
        write = 'VAN_RIJN'
    else:
        raise AssertionError('Unsupported value for CentroidMethod')

    _write_card('CENTROID_METHOD', write, where)


def _write_eulerian_method(value: EulerianMethod, where: TextIO):
    """Write the EULERIAN_METHOD keyword."""
    if value == EulerianMethod.ptm:
        write = 'PTM'
    elif value == EulerianMethod.van_rijn:
        write = 'VAN_RIJN'
    else:
        raise AssertionError('Unsupported value for EulerianMethod')

    _write_card('EULERIAN_METHOD', write, where)


def _write_eulerian_transport_method(value: EulerianTransportMethod, where: TextIO):
    """Write the EULERIAN_SED_TRANS keyword."""
    if value == EulerianTransportMethod.soulsby_van_rijn:
        write = 'SOULSBY-VAN_RIJN'
    elif value == EulerianTransportMethod.van_rijn:
        write = 'VAN_RIJN'
    elif value == EulerianTransportMethod.lund:
        write = 'LUND'
    elif value == EulerianTransportMethod.camenen_larson:
        write = 'CAMENEN_LARSON'
    else:
        raise AssertionError('Unsupported value for EulerianTransportMethod')

    _write_card('EULERIAN_SED_TRANS', write, where)


def _write_flow_format(value: FlowFormat, where: TextIO):
    """Write the FLOW_FORMAT keyword."""
    if value == FlowFormat.adcirc_ascii:
        write = 'ADCIRC'
    elif value == FlowFormat.adcirc_xmdf:
        write = 'XMDF'
    elif value == FlowFormat.cmsflow_2d_single:
        write = 'CMS-2D'
    elif value == FlowFormat.cmsflow_2d_multi:
        write = 'CMS-2D-MULTI'
    elif value == FlowFormat.adh:
        write = 'ADH'
    else:
        raise AssertionError('Unsupported value for FlowFormat')

    _write_card('FLOW_FORMAT', write, where)


def _write_mesh_format(value: MeshFormat, where: TextIO):
    """Write the MESH_FORMAT keyword."""
    if value == MeshFormat.adcirc:
        write = 'ADCIRC'
    elif value == MeshFormat.cms_2d:
        write = 'CMS-2D'
    else:
        raise AssertionError('Unsupported value for MeshFormat')

    _write_card('MESH_FORMAT', write, where)


def _write_numerical_scheme(value: NumericalScheme, where: TextIO):
    """Write the NUMERICAL_SCHEME keyword."""
    if value == NumericalScheme.two:
        write = '2'
    elif value == NumericalScheme.four:
        write = '4'
    else:
        raise AssertionError('Unsupported value for NumericalScheme')

    _write_card('NUMERICAL_SCHEME', write, where)


def _write_sediment_format(value: SedimentFormat, where: TextIO):
    """Write the SEDIMENT_FORMAT keyword."""
    if value == SedimentFormat.adcirc:
        write = 'ADCIRC'
    elif value == SedimentFormat.xmdf_dataset:
        write = 'XMDF'
    elif value == SedimentFormat.m2d:
        write = 'CMS-2D'
    elif value == SedimentFormat.uniform:
        write = 'UNIFORM'
    else:
        raise AssertionError('Unsupported value for SedimentFormat')

    _write_card('SEDIMENT_FORMAT', write, where)


def _write_velocity_method(value: VelocityMethod, where: TextIO):
    """Write the VELOCITY_METHOD keyword."""
    if value == VelocityMethod.two_d_log:
        write = '2D (LOGARITHMIC)'
    elif value == VelocityMethod.two_d_two_point:
        write = 'TWO-POINT'
    elif value == VelocityMethod.two_d_uniform:
        write = 'UNIFORM'
    elif value == VelocityMethod.three_d_sigma:
        write = '3D'
    elif value == VelocityMethod.three_d_z:
        write = '3DZ'
    else:
        raise AssertionError('Unsupported value for VelocityMethod')

    _write_card('VELOCITY_METHOD', write, where)


# def _write_wave_format(value: WaveFormat, where: TextIO):
#     """Write the WAVE_FORMAT keyword."""
#     if value == WaveFormat.stwave:
#         write = 'STWAVE'
#     elif value == WaveFormat.wabed:
#         write = 'WABED'
#     elif value == WaveFormat.xmdf:
#         write = 'XMDF'
#     elif value == WaveFormat.cms_wave:
#         write = 'CMS-WAVE'
#     else:
#         raise AssertionError('Unsupported value for WaveFormat')
#
#     _write_card('WAVE_FORMAT', write, where)
