"""AdvectiveTimestep Algorithm."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from logging import Logger

# 2. Third party modules
import numpy

# 3. Aquaveo modules
from xms.constraint import Grid
from xms.constraint.ugrid_activity import CellToPointActivityCalculator
from xms.datasets.dataset_reader import DatasetReader
from xms.datasets.dataset_writer import DatasetWriter
from xms.mesher.meshing import mesh_utils

# 4. Local modules
from xms.tool.utilities.dataset_tool import set_builder_activity_flags


def advective_time_step(dataset_name: str, input_dset: DatasetReader, dset_ugrid: Grid, courant_number: float,
                        logger: Logger) -> DatasetWriter:
    """Computes advective time step dataset.

    For every time step of the input dataset, the magnitude of each value is computed as the square root of the sum
    of its squared components along the last axis. Wherever that magnitude is greater than zero the output value is
    ``courant_number * size / magnitude``, where ``size`` is taken from the edge-length size function of the grid.
    Every other location (zero or NaN magnitude) is written as NaN.

    Args:
        dataset_name: Name of the output dataset.
        input_dset: Dataset reader for the input dataset. Its values must have a trailing component axis (the
            magnitude is reduced over the last axis).
        dset_ugrid: The grid of the input dataset. Used to compute the size function and, when the value and
            activity shapes differ, to build the cell-to-point activity calculators.
        courant_number: The courant number used in the simulation.
        logger: Logger for user output.

    Returns:
        The DatasetWriter for the output dataset
    """
    # Create a place for the output dataset file, copying the metadata of the input dataset
    builder = DatasetWriter()
    builder.name = dataset_name
    builder.geom_uuid = input_dset.geom_uuid
    builder.ref_time = input_dset.ref_time
    builder.time_units = input_dset.time_units
    builder.null_value = input_dset.null_value

    # Nodal dataset values with cell activity: the value and activity shapes differ, so both the reader and the
    # writer need a calculator to convert between cell and point activity.
    if input_dset.activity and input_dset.values and input_dset.values.shape != input_dset.activity.shape:
        input_dset.activity_calculator = CellToPointActivityCalculator(dset_ugrid)
        builder.activity_calculator = CellToPointActivityCalculator(dset_ugrid)

    # Calculate the size function once; it depends only on the grid, not on the time step
    logger.info('Calculating size function...')
    size_func = mesh_utils.size_function_from_edge_lengths(dset_ugrid)
    size_func = numpy.array(size_func)

    # Loop on the timesteps
    time_count = len(input_dset.times)
    for tsidx in range(time_count):
        logger.info(f'Processing time step {tsidx + 1} of {time_count}...')
        data, activity = input_dset.timestep_with_activity(tsidx)
        set_builder_activity_flags(activity, builder)

        # Default every output value to NaN, then fill in only the locations with a positive magnitude. This avoids
        # dividing by zero and leaves NaN magnitudes as NaN (a NaN comparison is always False).
        magnitude = numpy.sqrt((data ** 2).sum(-1))
        new_data = numpy.full([len(data)], numpy.nan)
        # Integer indices (rather than a boolean mask) so size_func only has to cover the indices actually used
        positive = numpy.flatnonzero(magnitude > 0.0)
        new_data[positive] = courant_number * size_func[positive] / magnitude[positive]

        # Append the values to the time step
        builder.append_timestep(input_dset.times[tsidx], new_data, activity)

    # Write output advective time step dataset to XMDF file
    logger.info('Writing output advective time step dataset to XMDF file...')
    builder.appending_finished()

    return builder
