"""Module for reading Program Control Files."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"
__all__ = ['read_control']

# 1. Standard Python modules
from datetime import datetime
from pathlib import Path
from typing import TextIO

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules
from xms.ptmio.file_reader import ParseError
from xms.ptmio.pcf.program_control import (
    AdvectionMethod, CentroidMethod, EulerianMethod, EulerianTransportMethod, FlowFormat, MeshFormat, NumericalScheme,
    ProgramControl, SedimentFormat, VelocityMethod
)


def read_control(file: str | Path | TextIO) -> ProgramControl:
    """
    Read a control file.

    Only low-level validation is performed: the file must be syntactically correct, every keyword must be recognized,
    and keywords that appear must carry valid values. Mandatory keywords are not checked for presence. Keywords with
    enumerated values are validated against their allowed values; open-ended values (paths, numbers, etc.) are accepted
    as-is without checking whether they make sense.

    Args:
        file: The file to read. Either a path (opened with the platform's default text encoding) or an already
            opened text stream.

    Returns:
        The ProgramControl populated from the file's contents.
    """
    if not isinstance(file, (str, Path)):
        control = ProgramControl()
        for number, text in enumerate(file, start=1):
            # _parse_line returns False once the end card is reached; anything after it is ignored.
            if not _parse_line(text, number, control):
                break
        return control

    with open(file) as opened:
        return read_control(opened)


def _parse_line(line: str, line_number: int, control: ProgramControl) -> bool:
    """
    Parse a line in the file and assign the parsed value into a ProgramControl.

    A line consists of a keyword (matched case-insensitively), optionally followed by whitespace and a value. Lines
    starting with '#' are comments and are skipped. Trailing whitespace, including a stray carriage return, is
    stripped from the value so it neither leaks into path strings nor breaks enumerated values.

    Args:
        line: The line to parse.
        line_number: Line number in the file. Used for error reporting.
        control: Where to store the parsed result.

    Returns:
        Whether to continue reading the file. Will be True in most cases, but may be
        False if the end card is parsed.

    Raises:
        ParseError: If the line is blank, the keyword is unrecognized, or the value is invalid for the keyword.
    """
    if line.startswith('#'):
        return True

    pieces = line.split(maxsplit=1)
    if not pieces:
        raise ParseError(line_number, 1, 'Line was blank.')

    # split() already discards surrounding whitespace from the keyword.
    key = pieces[0].lower()
    value = pieces[1] if len(pieces) > 1 else ''
    # Strip all trailing whitespace, not just '\n': a trailing space or '\r' would otherwise end up in paths and make
    # enumerated values such as '2D ' fail to match.
    value = value.rstrip()
    if key == ':bedform_mapping':
        control.bedform_mapping = _parse_bool(line_number, value)
    elif key == ':bed_level_change_mapping':
        control.bed_level_change_mapping = _parse_bool(line_number, value)
    elif key == ':bed_level_mapping':
        control.bed_level_mapping = _parse_bool(line_number, value)
    elif key == ':by_weight':
        control.by_weight = _parse_bool(line_number, value)
    elif key == ':currents':
        control.currents = _parse_bool(line_number, value)
    # elif key == ':debug':
    #     control.debug = _parse_bool(line_number, value)
    elif key == ':density_output':
        control.density_output = _parse_bool(line_number, value)
    elif key == ':elevation_output':
        pass  # We hard-code this to True since SMS doesn't support anything else.
    # elif key == ':ensim':
    #     control.ensim = _parse_bool(line_number, value)
    # elif key == ':ensim_maps':
    #     control.ensim_maps = _parse_bool(line_number, value)
    # elif key == ':ensim_parcels':
    #     control.ensim_parcels = _parse_bool(line_number, value)
    # elif key == ':exceed_waves':
    #     control.exceed_waves = _parse_bool(line_number, value)
    elif key == ':fall_velocity_output':
        control.fall_velocity_output = _parse_bool(line_number, value)
    elif key == ':flow_mapping':
        control.flow_mapping = _parse_bool(line_number, value)
    elif key == ':flow_output':
        control.flow_output = _parse_bool(line_number, value)
    elif key == ':grain_size_output':
        control.grain_size_output = _parse_bool(line_number, value)
    elif key == ':height_output':
        control.height_output = _parse_bool(line_number, value)
    elif key == ':id_output':
        pass  # We hard-code this to True since XMS doesn't support anything else.
    elif key == ':last_step_trap':
        control.last_step_trap = _parse_bool(line_number, value)
    elif key == ':mobility_mapping':
        control.mobility_mapping = _parse_bool(line_number, value)
    elif key == ':mobility_output':
        control.mobility_output = _parse_bool(line_number, value)
    elif key == ':morphology':
        control.morphology = _parse_bool(line_number, value)
    # elif key == ':nested':
    #     control.nested = _parse_bool(line_number, value)
    elif key == ':neutrally_buoyant':
        control.neutrally_buoyant = _parse_bool(line_number, value)
    # The :NO_* keywords are negative flags; their presence turns the corresponding setting off.
    elif key == ':no_bedforms':
        control.bedforms = not _parse_bool(line_number, value)
    elif key == ':no_bed_interaction':
        control.bed_interaction = not _parse_bool(line_number, value)
    elif key == ':no_hiding_exposure':
        control.hiding_exposure = not _parse_bool(line_number, value)
    # elif key == ':no_parcels':
    #     control.parcels = not _parse_bool(line_number, value)
    elif key == ':no_turbulent_shear':
        control.turbulent_shear = not _parse_bool(line_number, value)
    elif key == ':no_wave_mass_transport':
        control.wave_mass_transport = not _parse_bool(line_number, value)
    # elif key == ':no_xmdf_maps':
    #     raise ParseError(line_number, 1, ':NO_XMDF_MAPS keyword is unsupported.')
    # elif key == ':no_xmdf_parcels':
    #     raise ParseError(line_number, 1, ':NO_XMDF_PARCELS keyword is unsupported.')
    elif key == ':parcel_mass_output':
        control.parcel_mass_output = _parse_bool(line_number, value)
    elif key == ':paths':
        control.paths = _parse_bool(line_number, value)
    elif key == ':population_record':
        control.population_record = _parse_bool(line_number, value)
    elif key == ':residence_calc':
        control.residence_calc = _parse_bool(line_number, value)
    elif key == ':shear_stress_mapping':
        control.shear_stress_mapping = _parse_bool(line_number, value)
    elif key == ':source_output':
        control.source_output = _parse_bool(line_number, value)
    elif key == ':source_to_datum':
        control.source_to_datum = _parse_bool(line_number, value)
    elif key == ':state_output':
        control.state_output = _parse_bool(line_number, value)
    elif key == ':tau_cr_output':
        control.tau_cr_output = _parse_bool(line_number, value)
    elif key == ':tecplot_maps':
        control.tecplot_maps = _parse_bool(line_number, value)
    elif key == ':tecplot_parcels':
        control.tecplot_parcels = _parse_bool(line_number, value)
    elif key == ':transport_mapping':
        control.transport_mapping = _parse_bool(line_number, value)
    # elif key == ':wave_breaking':
    #     control.wave_breaking = _parse_bool(line_number, value)
    elif key == ':wave_mapping':
        control.wave_mapping = _parse_bool(line_number, value)
    # elif key == ':wave_mass_transport':
    #     control.wave_mass_transport = _parse_bool(line_number, value)
    # elif key == ':waves':
    #     control.waves = _parse_bool(line_number, value)
    elif key == ':xmdf_compressed':
        control.xmdf_compressed = _parse_bool(line_number, value)
    elif key == ':flow_update':
        control.flow_update = _parse_int(line_number, value)
    elif key == ':grid_update':
        control.grid_update = _parse_int(line_number, value)
    elif key == ':mapping_inc':
        control.mapping_inc = _parse_int(line_number, value)
    elif key == ':output_inc':
        control.output_inc = _parse_int(line_number, value)
    # elif key == ':wave_files':
    #     control.wave_files = _parse_int(line_number, value)
    # elif key == ':wave_frames':
    #     control.wave_frames = _parse_int(line_number, value)
    elif key == ':bed_porosity':
        control.bed_porosity = _parse_float(line_number, value)
    # elif key == ':bottom_flow_height':
    #     control.bottom_flow_height = _parse_float(line_number, value)
    elif key == ':duration':
        control.duration = _parse_float(line_number, value)
    elif key == ':etmin':
        control.etmin = _parse_float(line_number, value)
    elif key == ':evmin':
        control.evmin = _parse_float(line_number, value)
    elif key == ':ket':
        control.ket = _parse_float(line_number, value)
    elif key == ':kev':
        control.kev = _parse_float(line_number, value)
    elif key == ':kew':
        control.kew = _parse_float(line_number, value)
    # elif key == ':m2d_step':
    #     control.m2d_step = _parse_float(line_number, value)
    elif key == ':min_depth':
        control.min_depth = _parse_float(line_number, value)
    elif key == ':rhos':
        control.rhos = _parse_float(line_number, value)
    elif key == ':salinity':
        control.salinity = _parse_float(line_number, value)
    elif key == ':temperature':
        control.temperature = _parse_float(line_number, value)
    elif key == ':time_step':
        control.time_step = _parse_float(line_number, value)
    # elif key == ':wave_grid_angle':
    #     control.wave_grid_angle = _parse_float(line_number, value)
    # elif key == ':wave_grid_angle_2':
    #     control.wave_grid_angle_2 = _parse_float(line_number, value)
    elif key == ':wave_step':
        control.wave_step = _parse_float(line_number, value)
    # elif key == ':wave_x_origin':
    #     control.wave_x_origin = _parse_float(line_number, value)
    # elif key == ':wave_x_origin_2':
    #     control.wave_x_origin_2 = _parse_float(line_number, value)
    # elif key == ':wave_y_origin':
    #     control.wave_y_origin = _parse_float(line_number, value)
    # elif key == ':wave_y_origin_2':
    #     control.wave_y_origin_2 = _parse_float(line_number, value)
    elif key == ':bc_file':
        control.bc_file = _parse_str(line_number, value)
    # elif key == ':bottom_flow_file':
    #     control.bottom_flow_file = _parse_str(line_number, value)
    # elif key == ':bottom_mask_file':
    #     control.bottom_mask_file = _parse_str(line_number, value)
    # elif key == ':flow_file_dst':
    #     control.flow_file_dst = _parse_str(line_number, value)
    # elif key == ':flow_file_uv':
    #     control.flow_file_uv = _parse_str(line_number, value)
    elif key == ':flow_file_xmdf':
        control.flow_file_xmdf = _parse_str(line_number, value)
    # elif key == ':flow_file_z':
    #     control.flow_file_z = _parse_str(line_number, value)
    # elif key == ':m2d_control_file':
    #     control.m2d_control_file = _parse_str(line_number, value)
    elif key == ':mesh_file':
        control.mesh_file = _parse_str(line_number, value)
    elif key == ':neighbor_file':
        control.neighbor_file = _parse_str(line_number, value)
    elif key == ':output_prefix':
        control.output_prefix = _parse_str(line_number, value)
    elif key == ':sediment_file':
        control.sediment_file = _parse_str(line_number, value)
    elif key == ':source_file':
        control.source_file = _parse_str(line_number, value)
    elif key == ':trap_file':
        control.trap_file = _parse_str(line_number, value)
    # elif key == ':wave_file_2_xmdf':
    #     control.wave_file_2_xmdf = _parse_str(line_number, value)
    # elif key == ':wave_file_xmdf':
    #     control.wave_file_xmdf = _parse_str(line_number, value)
    # elif key == ':xmdf_bot_path':
    #     control.xmdf_bot_path = _parse_str(line_number, value)
    elif key == ':xmdf_d35_path':
        control.xmdf_d35_path = _parse_str(line_number, value)
    elif key == ':xmdf_d50_path':
        control.xmdf_d50_path = _parse_str(line_number, value)
    elif key == ':xmdf_d90_path':
        control.xmdf_d90_path = _parse_str(line_number, value)
    # elif key == ':xmdf_dep_path':
    #     control.xmdf_dep_path = _parse_str(line_number, value)
    # elif key == ':xmdf_grid_path':
    #     control.xmdf_grid_path = _parse_str(line_number, value)
    # elif key == ':xmdf_m3d_path':
    #     control.xmdf_m3d_path = _parse_str(line_number, value)
    # elif key == ':xmdf_sal_path':
    #     control.xmdf_sal_path = _parse_str(line_number, value)
    # elif key == ':xmdf_temp_path':
    #     control.xmdf_temp_path = _parse_str(line_number, value)
    # elif key == ':xmdf_u_path':
    #     control.xmdf_u_path = _parse_str(line_number, value)
    elif key == ':xmdf_vel_path':
        control.xmdf_vel_path = _parse_str(line_number, value)
    # elif key == ':xmdf_v_path':
    #     control.xmdf_v_path = _parse_str(line_number, value)
    elif key == ':xmdf_wse_path':
        control.xmdf_wse_path = _parse_str(line_number, value)
    # elif key == ':xmdf_w_path':
    #     control.xmdf_w_path = _parse_str(line_number, value)
    # elif key == ':bottom_start_flow':
    #     control.bottom_start_flow = _parse_datetime(line_number, value)
    elif key == ':start_flow':
        control.start_flow = _parse_datetime(line_number, value)
    elif key == ':start_run':
        control.start_run = _parse_datetime(line_number, value)
    elif key == ':start_trap':
        control.start_trap = _parse_datetime(line_number, value)
    elif key == ':start_waves':
        control.start_waves = _parse_datetime(line_number, value)
    elif key == ':stop_run':
        control.stop_run = _parse_datetime(line_number, value)
    elif key == ':stop_trap':
        control.stop_trap = _parse_datetime(line_number, value)
    elif key == ':advection_method':
        control.advection_method = _parse_advection_method(line_number, value)
    # elif key == ':bottom_flow_format':
    #     control.bottom_flow_format = _parse_bottom_flow_format(line_number, value)
    elif key == ':centroid_method':
        control.centroid_method = _parse_centroid_method(line_number, value)
    elif key == ':eulerian_method':
        control.eulerian_method = _parse_eulerian_method(line_number, value)
    elif key == ':eulerian_sed_trans':
        control.eulerian_transport_method = _parse_eulerian_transport_method(line_number, value)
    elif key == ':flow_format':
        control.flow_format = _parse_flow_format(line_number, value)
    elif key == ':mesh_format':
        control.mesh_format = _parse_mesh_format(line_number, value)
    elif key == ':numerical_scheme':
        control.numerical_scheme = _parse_numerical_scheme(line_number, value)
    elif key == ':sediment_format':
        control.sediment_format = _parse_sediment_format(line_number, value)
    elif key == ':velocity_method':
        control.velocity_method = _parse_velocity_method(line_number, value)
    # elif key == ':wave_format':
    #     control.wave_format = _parse_wave_format(line_number, value)
    elif key.startswith(':emrl'):
        pass  # Obsolete keyword. Both PTM and XMS ignore it now.
    elif key == ':end_data':
        # End card: tell the caller to stop reading.
        return False
    else:
        raise ParseError(line_number, 1, f'Unrecognized keyword: {key}')

    return True


def _parse_bool(line_number: int, value: str) -> bool:
    """Parse a bool."""
    if not value:
        return True

    raise ParseError(line_number, 1, 'Flag keyword has value.')


def _parse_float(line_number: int, value: str) -> float:
    """Parse a float."""
    if not value:
        raise ParseError(line_number, 1, 'Real keyword needs value.')

    try:
        return float(value)
    except ValueError:
        raise ParseError(line_number, 1, 'Real keyword has non-real value.')


def _parse_int(line_number: int, value: str) -> int:
    """Parse an int."""
    if not value:
        raise ParseError(line_number, 1, 'Integer keyword needs value.')

    try:
        return int(value)
    except ValueError:
        raise ParseError(line_number, 1, 'Integer keyword has non-integer value.')


def _parse_str(line_number: int, value: str) -> str:
    """Parse a string."""
    if not value:
        raise ParseError(line_number, 1, 'Character parameter needs value.')

    return value


def _parse_datetime(line_number: int, value: str) -> datetime:
    """Parse a datetime."""
    try:
        pieces = [int(piece) for piece in value.split()]
        year, month, day, hour, minute, second = pieces
        return datetime(year=year, month=month, day=day, hour=hour, minute=minute, second=second)
    except (ValueError, TypeError):
        raise ParseError(line_number, 1, 'Date/time parameter needs a date and time in `YYYY MM DD HH MM SS` format.')


def _parse_advection_method(line_number: int, value: str) -> AdvectionMethod:
    """
    Parse the value of the :ADVECTION_METHOD keyword.

    Matching is case-insensitive.

    Args:
        line_number: Line number in the file. Used for error reporting.
        value: The text following the keyword.

    Returns:
        The matching AdvectionMethod.

    Raises:
        ParseError: If the value is not 1D, 2D, 3D, or Q3D.
    """
    value = value.lower()
    if value == '1d':
        return AdvectionMethod.one_d
    elif value == '2d':
        return AdvectionMethod.two_d
    elif value == '3d':
        return AdvectionMethod.three_d
    elif value == 'q3d':
        return AdvectionMethod.q_three_d
    else:
        # List every accepted value, including Q3D, so the message matches what the parser actually allows.
        raise ParseError(line_number, 1, ':ADVECTION_METHOD keyword value must be either 1D, 2D, 3D, or Q3D.')


# def _parse_bottom_flow_format(line_number: int, value: str) -> BottomFlowFormat:
#     """Parse a BottomFlowFormat enum."""
#     value = value.lower()
#     if value == 'adcirc':
#         return BottomFlowFormat.adcirc
#     elif value == 'xmdf':
#         return BottomFlowFormat.xmdf
#     else:
#         raise ParseError(line_number, ':BOTTOM_FLOW_FORMAT keyword value must be either ADCIRC or XMDF.')


def _parse_centroid_method(line_number: int, value: str) -> CentroidMethod:
    """
    Parse the value of the :CENTROID_METHOD keyword.

    Matching is case-insensitive.

    Args:
        line_number: Line number in the file. Used for error reporting.
        value: The text following the keyword.

    Returns:
        The matching CentroidMethod.

    Raises:
        ParseError: If the value is not ROUSE, PIE, or VAN_RIJN.
    """
    methods = {
        'rouse': CentroidMethod.rouse,
        # NOTE(review): PIE is mapped to the ROUSE method here. This may be a deliberate alias or a copy-paste slip;
        # confirm against CentroidMethod.
        'pie': CentroidMethod.rouse,
        'van_rijn': CentroidMethod.van_rijn,
    }
    method = methods.get(value.lower())
    if method is None:
        raise ParseError(line_number, 1, ':CENTROID_METHOD keyword value must be either ROUSE, PIE, or VAN_RIJN.')
    return method


def _parse_eulerian_method(line_number: int, value: str) -> EulerianMethod:
    """
    Parse the value of the :EULERIAN_METHOD keyword.

    Matching is case-insensitive.

    Args:
        line_number: Line number in the file. Used for error reporting.
        value: The text following the keyword.

    Returns:
        The matching EulerianMethod.

    Raises:
        ParseError: If the value is not PTM or VAN_RIJN.
    """
    methods = {
        'ptm': EulerianMethod.ptm,
        'van_rijn': EulerianMethod.van_rijn,
    }
    method = methods.get(value.lower())
    if method is None:
        raise ParseError(line_number, 1, ':EULERIAN_METHOD keyword value must be either PTM or VAN_RIJN.')
    return method


def _parse_eulerian_transport_method(line_number: int, value: str) -> EulerianTransportMethod:
    """
    Parse the value of the :EULERIAN_SED_TRANS keyword.

    Matching is case-insensitive.

    Args:
        line_number: Line number in the file. Used for error reporting.
        value: The text following the keyword.

    Returns:
        The matching EulerianTransportMethod.

    Raises:
        ParseError: If the value is not soulsby-van_rijn, van_rijn, lund, or camenen_larson.
    """
    value = value.lower()
    if value == 'soulsby-van_rijn':
        return EulerianTransportMethod.soulsby_van_rijn
    elif value == 'van_rijn':
        return EulerianTransportMethod.van_rijn
    elif value == 'lund':
        return EulerianTransportMethod.lund
    elif value == 'camenen_larson':
        return EulerianTransportMethod.camenen_larson
    else:
        # The message spells out the exact accepted values ('van_rijn', not 'van_rign') so users can copy them.
        raise ParseError(
            line_number, 1,
            ':EULERIAN_SED_TRANS keyword value must be either soulsby-van_rijn, van_rijn, lund, or camenen_larson.'
        )


def _parse_flow_format(line_number: int, value: str) -> FlowFormat:
    """
    Parse the value of the :FLOW_FORMAT keyword.

    Matching is case-insensitive. CMS-2D, CMS-M2D, and M2D are all accepted as spellings of the single-file
    CMS-Flow 2D format.

    Args:
        line_number: Line number in the file. Used for error reporting.
        value: The text following the keyword.

    Returns:
        The matching FlowFormat.

    Raises:
        ParseError: If the value is not a recognized flow format.
    """
    cms_single = FlowFormat.cmsflow_2d_single
    formats = {
        'adcirc': FlowFormat.adcirc_ascii,
        'xmdf': FlowFormat.adcirc_xmdf,
        'cms-2d': cms_single,
        'cms-m2d': cms_single,
        'm2d': cms_single,
        'cms-2d-multi': FlowFormat.cmsflow_2d_multi,
        'adh': FlowFormat.adh,
    }
    flow_format = formats.get(value.lower())
    if flow_format is None:
        raise ParseError(
            line_number, 1,
            ':FLOW_FORMAT keyword value must be either ADCIRC, XMDF, CMS-2D, CMS-M2d, M2D, CMS-2D-MULTI, or ADH.'
        )
    return flow_format


def _parse_mesh_format(line_number: int, value: str) -> MeshFormat:
    """
    Parse the value of the :MESH_FORMAT keyword.

    Matching is case-insensitive. CMS-M2D and CMS-2D are both accepted as spellings of the CMS 2D format.

    Args:
        line_number: Line number in the file. Used for error reporting.
        value: The text following the keyword.

    Returns:
        The matching MeshFormat.

    Raises:
        ParseError: If the value is not a recognized mesh format.
    """
    formats = {
        'adcirc': MeshFormat.adcirc,
        'cms-m2d': MeshFormat.cms_2d,
        'cms-2d': MeshFormat.cms_2d,
    }
    mesh_format = formats.get(value.lower())
    if mesh_format is None:
        raise ParseError(line_number, 1, ':MESH_FORMAT keyword value must be either ADCIRC, CMS-M2d, or CMS-2D.')
    return mesh_format


def _parse_numerical_scheme(line_number: int, value: str) -> NumericalScheme:
    """
    Parse the value of the :NUMERICAL_SCHEME keyword.

    Args:
        line_number: Line number in the file. Used for error reporting.
        value: The text following the keyword.

    Returns:
        The matching NumericalScheme.

    Raises:
        ParseError: If the value is not 2 or 4.
    """
    schemes = {
        '2': NumericalScheme.two,
        '4': NumericalScheme.four,
    }
    scheme = schemes.get(value.lower())
    if scheme is None:
        raise ParseError(line_number, 1, ':NUMERICAL_SCHEME keyword value must be either 2 or 4.')
    return scheme


def _parse_sediment_format(line_number: int, value: str) -> SedimentFormat:
    """
    Parse the value of the :SEDIMENT_FORMAT keyword.

    Matching is case-insensitive. XMDF_PROPERTY, XMDF_DATASET, and XMDF all map to the XMDF dataset format, and
    M2D and CMS-2D both map to the M2D format.

    Args:
        line_number: Line number in the file. Used for error reporting.
        value: The text following the keyword.

    Returns:
        The matching SedimentFormat.

    Raises:
        ParseError: If the value is not a recognized sediment format.
    """
    xmdf = SedimentFormat.xmdf_dataset
    m2d = SedimentFormat.m2d
    formats = {
        'adcirc': SedimentFormat.adcirc,
        'xmdf_property': xmdf,
        'xmdf_dataset': xmdf,
        'xmdf': xmdf,
        'm2d': m2d,
        'cms-2d': m2d,
        'uniform': SedimentFormat.uniform,
    }
    sediment_format = formats.get(value.lower())
    if sediment_format is not None:
        return sediment_format

    message = (
        ':SEDIMENT_FORMAT keyword value must be either '
        'ADCIRC, XMDF_PROPERTY, XMDF_DATASET, XMDF, M2D, CMS-2D, or UNIFORM.'
    )
    raise ParseError(line_number, 1, message)


def _parse_velocity_method(line_number: int, value: str) -> VelocityMethod:
    """
    Parse the value of the :VELOCITY_METHOD keyword.

    Matching is case-insensitive, and several alternative spellings are accepted for each method.

    Args:
        line_number: Line number in the file. Used for error reporting.
        value: The text following the keyword.

    Returns:
        The matching VelocityMethod.

    Raises:
        ParseError: If the value is not a recognized spelling of any velocity method.
    """
    log = VelocityMethod.two_d_log
    two_point = VelocityMethod.two_d_two_point
    uniform = VelocityMethod.two_d_uniform
    sigma = VelocityMethod.three_d_sigma
    spellings = {
        '2d (log)': log,
        '2d (logarithmic)': log,
        'logarithmic': log,
        'two-point': two_point,
        'two_point': two_point,
        '2d (two-point)': two_point,
        '2d (two_point)': two_point,
        '2d (uniform)': uniform,
        'uniform': uniform,
        '3d': sigma,
        '3ds': sigma,
        'fully-3d': sigma,
        '3dz': VelocityMethod.three_d_z,
    }
    method = spellings.get(value.lower())
    if method is None:
        raise ParseError(
            line_number, 1, ':VELOCITY_METHOD keyword value must be either LOGARITHMIC, TWO-POINT, UNIFORM, 3D, or 3DZ.'
        )
    return method


# def _parse_wave_format(line_number: int, value: str) -> WaveFormat:
#     """Parse a WaveFormat enum."""
#     value = value.lower()
#     if value == 'stwave':
#         return WaveFormat.stwave
#     elif value == 'wabed':
#         return WaveFormat.wabed
#     elif value == 'xmdf':
#         return WaveFormat.xmdf
#     elif value == 'cms-wave':
#         return WaveFormat.cms_wave
#     else:
#         raise ParseError(line_number, ':WAVE_FORMAT keyword value must be either STWAVE, WABED, XMDF, or CMS-WAVE.')
