"""Module for model definitions."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"
__all__ = ['simulation_model', 'source_model', 'trap_model', 'StopMode']

# 1. Standard Python modules
from datetime import datetime

# 2. Third party modules

# 3. Aquaveo modules
from xms.gmi.data.generic_model import GenericModel, Section
from xms.guipy.time_format import datetime_to_string
from xms.tool_core.table_definition import (
    DateTimeColumnType, FloatColumnType, IntColumnType, StringColumnType, TableDefinition
)

# 4. Local modules


class StopMode:
    """Option strings for how the simulation's stop time is specified.

    These values are the choices offered by the 'stop_mode' option in the simulation model control.
    """

    # The run ends at an explicit end date/time.
    date_time: str = 'Date/Time'
    # The run ends after a given duration.
    duration: str = 'Duration'


def simulation_model() -> Section:
    """Build the parameters used for the simulation model control.

    The parameters live on the global parameters section of a fresh GenericModel and are organized into
    the groups 'Time', 'Files', 'Computations', and 'Output' (added in that order).

    Returns:
        The global parameters section holding every model-control parameter.
    """
    model = GenericModel()
    section = model.global_parameters

    _add_time_group(section)
    _add_files_group(section)
    _add_computations_group(section)
    _add_output_group(section)

    return section


def _add_time_group(section: Section) -> None:
    """Add the 'Time' group: run window, time steps, output increments, and flow/trap timing."""
    group = section.add_group('time', 'Time')
    group.add_date_time('start_run', 'Simulation start', default=None)
    stop_modes = [StopMode.date_time, StopMode.duration]
    stop_mode = group.add_option('stop_mode', 'Stop based on', default=stop_modes[0], options=stop_modes)
    # 'stop_run' applies in Date/Time mode and 'duration' applies in Duration mode.
    stop_run = group.add_date_time('stop_run', 'Simulation end', default=None)
    stop_run.add_dependency(stop_mode, {StopMode.date_time: True, StopMode.duration: False})
    duration = group.add_float('duration', 'Duration', default=0.0)
    duration.add_dependency(stop_mode, {StopMode.date_time: False, StopMode.duration: True})
    group.add_float('time_step', 'Time step (seconds)', default=1.0)
    group.add_date_time('start_waves', 'Wave start time', default=None)
    group.add_float('wave_step', 'Wave time step (seconds)', default=0.0)
    group.add_integer('output_inc', 'Particle Output (time steps)', default=1)
    group.add_integer('mapping_inc', 'Mapping Output (time steps)', default=1)
    provide_hydro = group.add_boolean('have_start_flow', 'Provide hydrodynamic start time', default=False)
    hydro_start = group.add_date_time('start_flow', 'Hydrodynamic start time', default=None)
    # The hydrodynamic start time only applies when the user opts to provide one.
    hydro_start.add_dependency(provide_hydro, {True: True, False: False})
    group.add_integer('grid_update', 'Shears, Bedforms, and Mobility Update (time steps)', default=1)
    group.add_integer('flow_update', 'Flow and elevation update (time steps)', default=1)
    group.add_date_time('start_trap', 'Traps start time', default=None)
    group.add_date_time('stop_trap', 'Traps end time', default=None)
    group.add_boolean('last_step_trap', 'Last time step trap', default=False)


def _add_files_group(section: Section) -> None:
    """Add the 'Files' group: input/output file names, file formats, and XMDF dataset paths."""
    group = section.add_group('files', 'Files')
    group.add_text('flow_file_xmdf', 'XMDF Flow File', default='FLOW_FILE')
    flow_formats = ['ADCIRC ASCII', 'ADCIRC XMDF', 'CMS-Flow 2D Single', 'CMS-Flow 2D Multi', 'ADH']
    group.add_option('flow_format', 'Flow File Format', default=flow_formats[0], options=flow_formats)
    group.add_text('xmdf_vel_path', 'XMDF Velocity Dataset Path', default='Datasets/Velocity')
    group.add_text('xmdf_wse_path', 'XMDF WSE Dataset Path', default='Datasets/Water Surface Elevation')
    group.add_text('bc_file', 'Boundary Conditions', default='boundary conditions')
    group.add_text('mesh_file', 'Mesh File Name', default='MESH')
    mesh_formats = ['ADCIRC', 'CMS-Flow 2D', 'XMDF Dataset', 'Uniform']
    group.add_option('mesh_format', 'Mesh Format', default=mesh_formats[0], options=mesh_formats)
    group.add_text('neighbor_file', 'Neighbor File Name', default='neighbors')
    sediment_formats = ['ADCIRC', 'M2D', 'XMDF Dataset', 'Uniform']
    group.add_option('sediment_format', 'Sediment Format', default=sediment_formats[0], options=sediment_formats)
    group.add_text('sediment_file', 'Sediment File Name', default='sediments')
    group.add_text('source_file', 'Source File Name', default='sources')
    group.add_text('trap_file', 'Trap File Name', default='traps')
    group.add_text('output_prefix', 'Output Prefix', default='Output')


def _add_computations_group(section: Section) -> None:
    """Add the 'Computations' group: method choices, physical constants, and feature toggles."""
    group = section.add_group('computations', 'Computations')
    options = ['1D', '2D', '3D', 'Q3D']
    group.add_option('advection_method', 'Advection', default=options[0], options=options)
    options = ['Rouse', 'Van Rijn']
    group.add_option('centroid_method', 'Centroid', default=options[0], options=options)
    options = ['By grain size', 'By weight']
    group.add_option('distribution', 'Distribution', default=options[0], options=options)
    options = ['PTM', 'Van Rijn']
    group.add_option('eulerian_method', 'Eulerian Method', default=options[0], options=options)
    options = ['2D Logarithmic', '2D Uniform', '2D Two-Point', '3DS', '3DZ']
    group.add_option('velocity_method', 'Velocity Method', default=options[0], options=options)
    options = ['Soulsby-Van Rijn', 'Van Rijn', 'Lund', 'Camenen Larson']
    group.add_option('eulerian_transport_method', 'Eulerian Transport Method', default=options[0], options=options)
    options = ['2', '4']
    group.add_option('numerical_scheme', 'Numerical Scheme', default=options[0], options=options)
    group.add_float('bed_porosity', 'Bed porosity', default=0.0)
    group.add_float('rhos', 'Bed density', default=0.0)
    group.add_float('min_depth', 'Minimum depth (meters)', default=0.0)
    group.add_float('temperature', 'Temperature (℃)', default=0.0)
    group.add_float('salinity', 'Salinity (ppt)', default=0.0)
    group.add_float('etmin', 'Horizontal minimum diffusion coefficient', default=0.0)
    group.add_float('ket', 'Horizontal turbulent diffusion scalar', default=0.0)
    group.add_float('evmin', 'Vertical minimum diffusion coefficient', default=0.0)
    group.add_float('kev', 'Vertical turbulent diffusion scalar', default=0.0)
    group.add_float('kew', 'Wave diffusion scalar', default=0.0)
    group.add_boolean('currents', 'Currents', default=False)
    group.add_boolean('morphology', 'Morphology', default=False)
    group.add_boolean('neutrally_buoyant', 'Neutrally buoyant particles', default=False)
    group.add_boolean('bedforms', 'Bedforms', default=False)
    group.add_boolean('hiding_exposure', 'Hiding and exposure', default=False)
    group.add_boolean('bed_interaction', 'Particle-bed interaction', default=False)
    group.add_boolean('turbulent_shear', 'Turbulent shear', default=False)
    group.add_boolean('source_to_datum', 'Source and trap Z-value relative to datum', default=False)
    group.add_boolean('residence_calc', 'Residency (polygon trap required)', default=False)
    group.add_boolean('wave_mass_transport', 'Wave mass transport', default=False)


def _add_output_group(section: Section) -> None:
    """Add the 'Output' group: output file toggles, mapping toggles, and output-variable toggles."""
    group = section.add_group('output', 'Output')
    group.add_boolean('xmdf_compressed', 'Compress XMDF files', default=False)
    group.add_boolean('tecplot_maps', 'Tecplot© map data', default=False)
    group.add_boolean('tecplot_parcels', 'Tecplot© particle data', default=False)
    group.add_boolean('paths', 'Tecplot© path data', default=False)
    group.add_boolean('population_record', 'Population record', default=False)
    group.add_boolean('bed_level_mapping', 'Bed evolution', default=False)
    group.add_boolean('bed_level_change_mapping', 'Bed level change', default=False)
    group.add_boolean('flow_mapping', 'Flow conditions', default=False)
    group.add_boolean('mobility_mapping', 'Mobility of native sediments', default=False)
    group.add_boolean('bedform_mapping', 'Native sediment bedforms', default=False)
    group.add_boolean('transport_mapping', 'Potential sediment transport rate', default=False)
    group.add_boolean('shear_stress_mapping', 'Shear stress', default=False)
    group.add_boolean('wave_mapping', 'Wave parameters', default=False)
    # Output-variable toggles, listed alphabetically by label; each label must match its parameter name.
    group.add_boolean('tau_cr_output', 'Critical shear', default=False)
    group.add_boolean('density_output', 'Density', default=False)
    group.add_boolean('fall_velocity_output', 'Fall velocity', default=False)
    group.add_boolean('grain_size_output', 'Grain size', default=False)
    group.add_boolean('height_output', 'Height above the bed', default=False)
    group.add_boolean('parcel_mass_output', 'Mass', default=False)
    group.add_boolean('mobility_output', 'Mobility', default=False)
    group.add_boolean('source_output', 'Source', default=False)
    group.add_boolean('state_output', 'State', default=False)
    group.add_boolean('flow_output', 'Velocity components', default=False)


def source_model() -> GenericModel:
    """Build the model used for the Sources coverage.

    Point source types are 'instant', 'point', and 'vertical_line'. The arc source type is
    'horizontal_line' and the polygon source type is 'polygon'. Each type has an 'instructions' table
    pre-filled with one row of its column defaults.

    Returns:
        The sources GenericModel. It is created with exclusive point and arc conditions; presumably this
        limits each feature to a single source type (confirm against GenericModel docs).
    """
    model = GenericModel(exclusive_point_conditions=True, exclusive_arc_conditions=True)

    # Point parameters
    section = model.point_parameters

    group = section.add_group('instant', 'Instant Mass Source')
    group.add_integer('id', 'Source ID', default=0, low=0)
    group.add_text('label', 'Name', default='Instant Mass Source', required=True)
    definition = [
        _source_date_time_column(),
        FloatColumnType(header='X\n(m)', default=0.0),
        FloatColumnType(header='Y\n(m)', default=0.0),
        FloatColumnType(header='Elevation\n(m)', default=0.0),
        FloatColumnType(header='Parcel Mass\n(kg)', default=2.0),
        FloatColumnType(header='Horiz. Radius\n(m)', default=1.0),
        FloatColumnType(header='Vert. Radius\n(m)', default=1.0),
        FloatColumnType(header='Mass\n(kg)', default=0.05),
        *_source_sediment_columns(),
    ]
    _add_source_instructions_table(group, definition)

    group = section.add_group('point', 'Point Mass Source')
    group.add_integer('id', 'Source ID', default=0, low=0)
    group.add_text('label', 'Name', default='Point Mass Source', required=True)
    definition = [
        _source_date_time_column(),
        FloatColumnType(header='X\n(m)', default=0.0),
        FloatColumnType(header='Y\n(m)', default=0.0),
        FloatColumnType(header='Elevation\n(m)', default=0.0),
        FloatColumnType(header='Parcel Mass\n(kg)', default=2.0),
        FloatColumnType(header='Horiz. Radius\n(m)', default=1.0),
        FloatColumnType(header='Vert. Radius\n(m)', default=1.0),
        FloatColumnType(header='Rate\n(kg)', default=0.05),
        *_source_sediment_columns(),
    ]
    _add_source_instructions_table(group, definition)

    # Unlike the other source types, vertical line sources keep segment name/ID as table columns
    # rather than as separate group parameters.
    group = section.add_group('vertical_line', 'Vertical Line Mass Source')
    choices = ['Bed Datum', 'Depth Distributed', 'Source Datum']
    definition = [
        StringColumnType(header='Segment name', default='Vertical Line'),
        IntColumnType(header='Segment ID', default=1),
        StringColumnType(header='Datum', choices=choices, default='Bed Datum'),
        _source_date_time_column(),
        FloatColumnType(header='X\n(m)', default=0.0),
        FloatColumnType(header='Y\n(m)', default=0.0),
        FloatColumnType(header='Bottom Elev.\n(m)', default=2.0),
        FloatColumnType(header='Top Elev.\n(m)', default=2.0),
        FloatColumnType(header='Parcel Mass\n(kg)', default=1.0),
        FloatColumnType(header='Radius\n(m)', default=1.0),
        FloatColumnType(header='Rate\n(kg)', default=0.05),
        *_source_sediment_columns(),
    ]
    _add_source_instructions_table(group, definition)

    # Arc parameters
    section = model.arc_parameters

    group = section.add_group('horizontal_line', 'Horizontal Line Source')
    group.add_integer('id', 'Source ID', default=0, low=0)
    group.add_text('label', 'Name', default='Horizontal Line Source', required=True)
    definition = [
        _source_date_time_column(),
        FloatColumnType(header='X1\n(m)', default=0.0),
        FloatColumnType(header='Y1\n(m)', default=0.0),
        FloatColumnType(header='X2\n(m)', default=0.0),
        FloatColumnType(header='Y2\n(m)', default=0.0),
        FloatColumnType(header='Elevation\n(m)', default=0.0),
        FloatColumnType(header='Parcel Mass\n(kg)', default=2.0),
        FloatColumnType(header='Radius\n(m)', default=1.0),
        FloatColumnType(header='Rate\n(kg/(m*s))', default=0.05),
        *_source_sediment_columns(),
    ]
    _add_source_instructions_table(group, definition)

    # Polygon parameters
    section = model.polygon_parameters

    group = section.add_group('polygon', 'Horizontal Polygon Source')
    group.add_integer('id', 'Source ID', default=0, low=0)
    group.add_text('label', 'Name', default='Horizontal Polygon Source', required=True)
    definition = [
        _source_date_time_column(),
        FloatColumnType(header='Elevation\n(m)', default=5.0),
        FloatColumnType(header='Parcel Mass\n(kg)', default=200.0),
        FloatColumnType(header='Vert. Radius\n(m)', default=1.0),
        FloatColumnType(header='Rate\n(kg/(m*s))', default=0.05),
        *_source_sediment_columns(median_grain_size=0.0001),
    ]
    _add_source_instructions_table(group, definition)

    return model


def _source_date_time_column() -> DateTimeColumnType:
    """Build the 'Date/Time' column that every source instructions table carries."""
    return DateTimeColumnType(header='Date/Time', default=datetime(year=1950, month=1, day=1))


def _source_sediment_columns(median_grain_size: float = 0.1) -> list:
    """Build the sediment-property columns that end every source instructions table.

    Args:
        median_grain_size: Default value for the 'Median Grain Size (mm)' column.

    Returns:
        A fresh list of column definitions, in table order.
    """
    return [
        FloatColumnType(header='Median Grain Size\n(mm)', default=median_grain_size),
        FloatColumnType(header='Standard Deviation\n(Phi-units)', default=0.8),
        # NOTE(review): this header lacks its closing ')'. It is left unchanged because the header text may
        # be used as a column key by other code or saved data; confirm before correcting it.
        FloatColumnType(header='Density\n(kg/m³', default=2650.0),
        FloatColumnType(header='Fall Velocity\n(m/s)', default=-1.0),
        FloatColumnType(header='Critical Shear Initiation\n(N/m²)', default=-1.0),
        FloatColumnType(header='Critical Shear Deposition\n(N/m²)', default=-1.0),
    ]


def _add_source_instructions_table(group, definition: list) -> None:
    """Add an 'Instructions' table to a source group, pre-filled with one row of column defaults.

    Args:
        group: The source group to add the table to.
        definition: The table's column definitions, in order.
    """
    # Date/time column defaults are stored as strings in the default row; other columns use their raw default.
    row = [
        column.default if column.dtype != 'datetime64[ns]' else datetime_to_string(column.default)
        for column in definition
    ]
    table_def = TableDefinition(definition)
    group.add_table('instructions', 'Instructions', default=[row], table_definition=table_def, display_table=True)


def trap_model() -> GenericModel:
    """Build the generic model behind the Traps coverage.

    Arcs carry a 'horizontal_line' trap group and polygons carry a 'polygon' trap group. Both groups share
    the same ID, name, bottom/top, and open/single parameters. The line trap group also has a 'direction'
    option that defaults to 'Increasing x-coordinate'.

    Returns:
        The traps GenericModel with arc and polygon parameters defined.
    """

    def add_shared_trap_parameters(trap_group):
        # Parameters common to every trap type, in display order.
        trap_group.add_integer('id', 'ID', default=1, low=1)
        trap_group.add_text('name', 'Name', required=True)
        trap_group.add_float('bottom', 'Bottom', default=0.0)
        trap_group.add_float('top', 'Top', default=0.0)
        trap_group.add_boolean('is_open', 'Open Trap (particles can leave trap)', default=False)
        trap_group.add_boolean('is_single', 'Single Trap (count particles once per simulation)', default=True)

    traps = GenericModel()

    line_trap = traps.arc_parameters.add_group('horizontal_line', 'Horizontal Line Trap')
    add_shared_trap_parameters(line_trap)
    directions = ['Decreasing x-coordinate', 'Either Direction', 'Increasing x-coordinate']
    line_trap.add_option('direction', 'Trap Direction', options=directions, default=directions[-1])

    polygon_trap = traps.polygon_parameters.add_group('polygon', 'Horizontal Polygon Trap')
    add_shared_trap_parameters(polygon_trap)

    return traps
