"""VectorsFromScalarsTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint import Grid
from xms.constraint.ugrid_activity import CellToPointActivityCalculator
from xms.datasets import vectors as xmdv
from xms.datasets.dataset_reader import DatasetReader
from xms.datasets.dataset_writer import DatasetWriter

# 4. Local modules


def vectors_from_scalars(x_reader: DatasetReader, y_reader: DatasetReader, input_grid_x: Grid | None,
                         input_grid_y: Grid | None, input_type: str, output_dataset_name: str,
                         logger: logging.Logger) -> DatasetWriter:
    """Converts scalar datasets to a vector dataset.

    Each time step of the two scalar inputs is read, optionally converted from magnitude/direction to Vx/Vy,
    stacked into a two-component array, and appended to a new vector dataset.

    Args:
        x_reader (DatasetReader): The DatasetReader for the x components (or magnitudes).
        y_reader (DatasetReader): The DatasetReader for the y components (or directions).
        input_grid_x (Grid or None): The Grid for the x input data.
        input_grid_y (Grid or None): The Grid for the y input data.
        input_type (str): The type of the inputs (either: 'Vx and Vy' or 'Magnitude and Direction').
            Any value other than 'Vx and Vy' is treated as magnitude and direction.
        output_dataset_name (str): The name of the output dataset.
        logger (logging.Logger): The logger used to display progress output to the user.

    Returns:
        (DatasetWriter): The DatasetWriter for the resulting vector dataset.
    """
    builder = _setup_output_dataset_builder(x_reader, y_reader, input_grid_x, input_grid_y, output_dataset_name)

    # These decisions are the same for every time step, so make them once.
    # If the writer owns the cell-to-point activity calculator, read the stored (cell) activity arrays directly
    # and hand them to the writer; otherwise let each reader resolve its own activity.
    read_raw = builder.activity_calculator is not None
    convert_magdir = input_type != 'Vx and Vy'

    time_count = len(x_reader.times)
    for tsidx in range(time_count):
        logger.info(f'Processing time step {tsidx + 1} of {time_count}...')
        # x is the magnitude and y the direction when doing a mag/dir conversion.
        x_data, x_activity = _read_timestep(x_reader, tsidx, read_raw)
        y_data, y_activity = _read_timestep(y_reader, tsidx, read_raw)

        activity = _combine_activity(x_activity, y_activity)
        if activity is not None:
            builder.use_activity_as_null = True

        if convert_magdir:
            x_data, y_data = xmdv.magdir_to_vx_vy(x_data, y_data)

        # One row per location, columns are (vx, vy).
        data = np.vstack((x_data, y_data)).T
        builder.append_timestep(x_reader.times[tsidx], data, activity)

    logger.info('Writing output vector dataset to XMDF file...')
    builder.appending_finished()

    return builder


def _read_timestep(reader: DatasetReader, tsidx: int, read_raw: bool):
    """Read the values and activity for one time step of a dataset.

    Args:
        reader (DatasetReader): The reader to read from.
        tsidx (int): Zero-based time step index.
        read_raw (bool): If True, read the stored ``values`` and ``activity`` arrays directly, leaving any
            activity conversion to the caller. Otherwise use ``reader.timestep_with_activity``.

    Returns:
        tuple: ``(values, activity)`` for the time step. ``activity`` may be None when the reader has none.
    """
    if read_raw:
        return reader.values[tsidx][:], reader.activity[tsidx][:]
    return reader.timestep_with_activity(tsidx)


def _combine_activity(x_activity, y_activity):
    """Combine per-component activity arrays into one activity array for the vector.

    Args:
        x_activity: Activity for the x component (or magnitude), or None.
        y_activity: Activity for the y component (or direction), or None.

    Returns:
        The combined activity array, or None if neither component has activity.
    """
    if x_activity is None:
        # Use whichever component's activity is available, so its inactive values are not written as valid data.
        return y_activity
    if y_activity is None:
        return x_activity
    # A vector is only valid where both components are. Assuming nonzero activity means active (the XMDF
    # convention that use_activity_as_null relies on -- TODO confirm), that is the elementwise AND. An OR would
    # let one component's null value into the vector.
    return x_activity & y_activity


def _setup_output_dataset_builder(x_reader: DatasetReader, y_reader: DatasetReader, input_grid_x: Grid | None,
                                  input_grid_y: Grid | None, output_dataset_name: str) -> DatasetWriter:
    """Create the output vector dataset writer and attach any needed cell-to-point activity calculator.

    The writer takes its geometry UUID, reference time, time units, null value, and location from ``x_reader``.

    An input has "cell activity" when both its activity and values are present but their shapes differ. In that
    case a ``CellToPointActivityCalculator`` is created:

    - on the writer, when both inputs have cell activity;
    - otherwise on the one reader that has it.

    Args:
        x_reader (DatasetReader): The DatasetReader for the x components (or magnitudes).
        y_reader (DatasetReader): The DatasetReader for the y components (or directions).
        input_grid_x (Grid or None): The Grid for the x input data.
        input_grid_y (Grid or None): The Grid for the y input data.
        output_dataset_name (str): The name of the output dataset.

    Returns:
        (DatasetWriter): The writer for the output two-component dataset.

    Raises:
        ValueError: If an input has cell activity but the grid needed to convert it was not provided.
    """
    # Create a place for the output dataset file
    builder = DatasetWriter(
        name=output_dataset_name,
        geom_uuid=x_reader.geom_uuid,
        num_components=2,
        ref_time=x_reader.ref_time,
        time_units=x_reader.time_units,
        null_value=x_reader.null_value,
        location=x_reader.location
    )

    x_cell_activity = _has_cell_activity(x_reader)
    y_cell_activity = _has_cell_activity(y_reader)

    if x_cell_activity and y_cell_activity:
        # Nodal dataset values with cell activity on both inputs: the writer gets the calculator, so the
        # activity passed to it must be the stored cell activity.
        builder.activity_calculator = CellToPointActivityCalculator(_require_ugrid(input_grid_x, 'x'))
    elif x_cell_activity:
        x_reader.activity_calculator = CellToPointActivityCalculator(_require_ugrid(input_grid_x, 'x'))
    elif y_cell_activity:
        y_reader.activity_calculator = CellToPointActivityCalculator(_require_ugrid(input_grid_y, 'y'))

    return builder


def _has_cell_activity(reader: DatasetReader) -> bool:
    """Return True if the reader has activity and values whose shapes differ.

    Args:
        reader (DatasetReader): The reader to inspect.

    Returns:
        bool: True when activity is present and shaped differently from the values.
    """
    # NOTE(review): this relies on the truthiness of ``activity``/``values`` (an open h5py dataset is truthy).
    # If these can ever be multi-element numpy arrays, this must become an ``is not None`` check.
    if reader.activity and reader.values:
        return reader.values.shape != reader.activity.shape
    return False


def _require_ugrid(grid: Grid | None, label: str):
    """Return ``grid.ugrid``, raising a descriptive error when the grid is missing.

    Args:
        grid (Grid or None): The grid associated with an input dataset.
        label (str): Which input the grid belongs to ('x' or 'y'), used in the error message.

    Returns:
        The grid's ugrid.

    Raises:
        ValueError: If ``grid`` is None.
    """
    if grid is None:
        raise ValueError(f'The {label} input dataset has cell activity, but no grid was provided for it.')
    return grid.ugrid
