# 1. Standard python modules

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.grid.ugrid import UGrid

# 4. Local modules
from . import _xmsconstraint


# Activity types returned by dataset_activity_type().
ACTIVITY_NULL_VALUE = 'null value'  # Entries equal to a dataset null value are inactive.
ACTIVITY_CELL_ARRAY = 'cells'  # Activity array has one entry per grid cell.
ACTIVITY_POINT_ARRAY = 'points'  # Activity array has one entry per grid point.
ACTIVITY_NONE = 'none'  # No usable activity information was found.

# Mapping types returned by dataset_mapping_type().
# NOTE: these share their string values with ACTIVITY_CELL_ARRAY / ACTIVITY_POINT_ARRAY; keep them
# in sync in case code compares a mapping type directly against an activity type.
MAPPING_CELLS = 'cells'  # Exactly one value per grid cell.
MAPPING_POINTS = 'points'  # Any other value count; treated as one value per grid point.


def active_points_from_cells(grid, cell_activity):
    """Derive per-point activity from per-cell activity using the native UGrid routine.

    Args:
        grid (UGrid): Grid whose points are evaluated.
        cell_activity (Iterable[bool]): Activity flag for each cell of ``grid``.

    Returns:
        (np.ndarray): Point activity flags as a new ``uint8`` array (0 is treated as inactive
        elsewhere in this module).
    """
    native_flags = _xmsconstraint.ugrid.ugActivePointsFromCells(grid._instance, cell_activity)
    return np.array(native_flags, dtype=np.uint8)


def active_cells_from_points(grid, point_activity):
    """Derive per-cell activity from per-point activity using the native UGrid routine.

    Args:
        grid (UGrid): Grid whose cells are evaluated.
        point_activity (Iterable[bool]): Activity flag for each point of ``grid``.

    Returns:
        (np.ndarray): Cell activity flags as a new ``int8`` array. Note this is a signed dtype,
        unlike the ``uint8`` array produced by ``active_points_from_cells``.
    """
    native_flags = _xmsconstraint.ugrid.ugActiveCellsFromPoints(grid._instance, point_activity)
    return np.array(native_flags, dtype=np.int8)


def dataset_mapping_type(grid, values):
    """Decide whether dataset values belong to cells or to points.

    Values are considered cell-mapped only when there is exactly one value per cell; every
    other length is reported as point-mapped.

    Args:
        grid (UGrid): The dataset's grid.
        values (Iterable[float]): The timestep values.

    Returns:
        (str): MAPPING_CELLS or MAPPING_POINTS.
    """
    one_value_per_cell = len(values) == grid.cell_count
    return MAPPING_CELLS if one_value_per_cell else MAPPING_POINTS


def dataset_activity_type(ugrid: UGrid, activity, null_value):
    """Classify how a dataset timestep expresses activity.

    A null value takes precedence over any activity array. Otherwise the array length decides:
    one entry per cell wins over one entry per point. A missing array, or one whose length
    matches neither count (e.g. an empty array), means no activity information.

    Args:
        ugrid (UGrid): The dataset's grid.
        activity (Optional[np.array]): The timestep activity.
        null_value (Optional[float]): The null value if using value activity.

    Returns:
        (str): ACTIVITY_NULL_VALUE, ACTIVITY_CELL_ARRAY, ACTIVITY_POINT_ARRAY, or ACTIVITY_NONE.
    """
    if null_value is not None:
        return ACTIVITY_NULL_VALUE
    if activity is None:
        return ACTIVITY_NONE

    entry_count = len(activity)
    if entry_count == ugrid.cell_count:
        return ACTIVITY_CELL_ARRAY
    if entry_count == ugrid.point_count:
        return ACTIVITY_POINT_ARRAY
    return ACTIVITY_NONE


def values_with_nans(ugrid, values, dataset_activity, null_value):
    """Return a copy of dataset values with inactive entries replaced by NaN.

    The activity kind is resolved with dataset_activity_type():

    * cell or point activity array: entries whose activity is 0 become NaN. Point-mapped
      values with cell activity are first converted to point activity.
    * null value: entries equal to ``null_value`` become NaN.
    * no activity: the (copied) values are returned unchanged.

    Integer and boolean inputs are promoted to float64 so NaN can be stored.

    Args:
        ugrid (UGrid): The dataset's UGrid.
        values (np.array[float]): The timestep values. The input is not modified.
        dataset_activity (Optional[np.array[int]]): The timestep activity from the dataset.
        null_value (Optional[float]): The null value if using value activity.

    Returns:
        (np.array[float]): The values with inactive entries set to NaN.

    Raises:
        ValueError: If the values are cell-mapped but the activity is per point.
    """
    values = np.copy(values)
    if values.dtype.kind in 'biu':
        # NaN cannot be stored in bool/integer arrays (the assignment raises ValueError),
        # so promote before marking inactive entries.
        values = values.astype(np.float64)

    activity_type = dataset_activity_type(ugrid, dataset_activity, null_value)
    mapping_type = dataset_mapping_type(ugrid, values)
    if activity_type in (ACTIVITY_CELL_ARRAY, ACTIVITY_POINT_ARRAY):
        # Compare "is it per cell?" explicitly instead of comparing a MAPPING_* string to an
        # ACTIVITY_* string, which only works because those constants happen to share values.
        activity_on_cells = activity_type == ACTIVITY_CELL_ARRAY
        values_on_cells = mapping_type == MAPPING_CELLS
        if activity_on_cells == values_on_cells:
            value_activity = dataset_activity
        elif activity_on_cells:
            # Point-mapped values with cell activity: derive point activity from the cells.
            value_activity = active_points_from_cells(ugrid, dataset_activity)
        else:
            raise ValueError('Error: Dataset should never have cell mapped values with point activity.')
        # np.asarray so a plain list of flags yields an element-wise mask rather than the
        # scalar False (which would silently mark nothing inactive).
        values[np.asarray(value_activity) == 0] = np.nan
    elif activity_type == ACTIVITY_NULL_VALUE:
        values[values == null_value] = np.nan
    return values


class CellToPointActivityCalculator:
    """Turn cell activity arrays into point activity arrays for one fixed UGrid."""

    def __init__(self, ugrid):
        """Store the grid used by every later calculation.

        Args:
            ugrid (UGrid): Geometry whose points receive the computed activity.
        """
        self.ugrid = ugrid

    def calc(self, cell_activity):
        """Compute point activity for the stored grid from a cell activity array.

        Args:
            cell_activity (np.ndarray): Activity flag for each cell of the stored grid.

        Returns:
            np.ndarray: Point activity as produced by active_points_from_cells().
        """
        grid = self.ugrid
        return active_points_from_cells(grid, cell_activity)
