"""TUFLOWFV DATV solution file importer."""
# 1. Standard python modules
import logging
import os
import struct
import uuid
import warnings

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.api.dmi import XmsEnvironment as XmEnv
from xms.constraint.ugrid_activity import active_points_from_cells
from xms.data_objects.parameters import julian_to_datetime
from xms.datasets.dataset_writer import DatasetWriter

# 4. Local modules
from xms.tuflowfv.file_io import io_util


class DatvSolutionReader:
    """Class for reading TUFLOWFV DATV solutions.

    Each DATV file is parsed into one dataset, which is written to a temporary XMDF (.h5) file via a DatasetWriter.
    """
    END_OF_HEADER = 7 * io_util.UINT_SIZE_BYTES  # Header is 7 unsigned integers
    # DATV Command enumeration
    DATV_BEGIN_SCALAR = 130
    DATV_BEGIN_VECTOR = 140
    DATV_VECTOR_TYPE = 150  # Not used, but its integer payload must be skipped
    DATV_OBJECT_ID = 160  # Not used, but its integer payload must be skipped
    DATV_NUMDATA = 170
    DATV_NUMCELLS = 180  # Not used, but its integer payload must be skipped
    DATV_NAME = 190
    # DATV_NULLVALUE = 185 UNUSED - TUFLOWFV writes a cell-based activity array
    DATV_REFERENCE_TIME = 240
    DATV_TIME_UNITS = 250
    DATV_TIME_STEP = 200
    DATV_END_OF_DATA = 210
    # DATV_ACTIVE_DATASET_AND_TIMESTEP = 220 UNUSED
    # DATV_MAPPED_DATASET = 230 UNUSED

    # Cards that (per the SMS binary dataset file format) are followed by a single 4-byte integer we don't need.
    # The payload must still be consumed, otherwise its value would be interpreted as the next command id.
    SKIPPED_INT_COMMANDS = (DATV_VECTOR_TYPE, DATV_OBJECT_ID, DATV_NUMCELLS)

    TIME_KEYWORDS = {
        0: 'Hours',
        1: 'Minutes',
        2: 'Seconds',
        3: 'None',
        4: 'Days',
    }

    def __init__(self, filenames, ugrid_item, xmugrid):
        """Constructor.

        Args:
            filenames (list[str]): Absolute paths to the DATV solution files
            ugrid_item (TreeNode): The linked Mesh2D or UGrid module geometry tree item
            xmugrid (UGrid): The linked Mesh2D or UGrid module geometry
        """
        self._logger = logging.getLogger('xms.tuflowfv')
        self._filenames = filenames
        self._xmugrid = xmugrid
        self._ugrid_item = ugrid_item
        self._ugrid_uuid = self._ugrid_item.uuid
        self._temp_folder = XmEnv.xms_environ_temp_directory()
        self._builders = []  # [DatasetWriter] - the imported solution datasets

    def _default_metadata(self):
        """Return a default metadata dict.

        Returns:
            (dict): default metadata, keyed by DATV command id
        """
        return {
            self.DATV_BEGIN_SCALAR: False,
            self.DATV_BEGIN_VECTOR: False,
            self.DATV_NUMDATA: 0,
            self.DATV_NAME: 'Dataset',
            self.DATV_REFERENCE_TIME: None,
            self.DATV_TIME_UNITS: 'Hours',
        }

    def _read_header(self, file):
        """Reads the header from a datv file.

        Args:
            file (opened binary file): the datv file
        """
        # Skip header for now, don't need it for reading TUFLOWFV DATV files.
        # Header contains 7 variables at 4 bytes:
        # version, object card, object type, float size, size float, flag size, size flag
        file.seek(self.END_OF_HEADER)

    def _read_metadata_command(self, file, command, metadata):
        """Check if a command is a metadata value, and read it (or skip its payload) if so.

        Args:
            file: Open file stream to the DATV file
            command (int): The previously parsed command integer from the stream
            metadata (dict): The pre-initialized dataset metadata dict

        Returns:
            bool: True if the command was a metadata command and its payload has been consumed from the stream
                (stored in metadata or discarded), False if it was not a metadata command
        """
        if command == self.DATV_BEGIN_SCALAR:
            metadata[self.DATV_BEGIN_SCALAR] = True
        elif command == self.DATV_BEGIN_VECTOR:
            metadata[self.DATV_BEGIN_VECTOR] = True
        elif command == self.DATV_NUMDATA:
            metadata[self.DATV_NUMDATA] = io_util.read_binary_into_int(file)
        elif command == self.DATV_NAME:
            metadata[self.DATV_NAME] = io_util.read_binary_into_str(file, io_util.MAX_STR_SIZE_BYTES).strip()
        elif command == self.DATV_REFERENCE_TIME:
            julian = io_util.read_binary_into_float(file, io_util.FLOAT_64_SIZE_BYTES)
            metadata[self.DATV_REFERENCE_TIME] = julian_to_datetime(julian)
        elif command == self.DATV_TIME_UNITS:
            unit_code = io_util.read_binary_into_int(file)
            metadata[self.DATV_TIME_UNITS] = self.TIME_KEYWORDS.get(unit_code, 'Hours')
        elif command in self.SKIPPED_INT_COMMANDS:
            io_util.read_binary_into_int(file)  # Discard the payload so the stream stays aligned on command ids
        else:
            return False
        return True

    def _read_scalar_timestep(self, file):
        """Reads a scalar timestep from a datv file.

        Args:
            file (opened binary file): the datv file

        Returns:
            tuple(float, np.ndarray, Sequence): The timestep offset, the per-point values (NaN at inactive points), and
                the cell activity array (None if the timestep has no activity array)
        """
        time, cell_activity, point_activity = self._read_time_and_activity(file)
        num_points = self._ugrid_item.num_points
        file_bytes = file.read(num_points * io_util.FLOAT_32_SIZE_BYTES)
        scalar_values = np.array(struct.unpack(f'<{num_points}f', file_bytes))
        if point_activity is not None:
            scalar_values[point_activity != 1] = np.nan
        return time, scalar_values, cell_activity

    def _read_vector_timestep(self, file):
        """Reads a vector timestep from a datv file.

        Args:
            file (opened binary file): the datv file

        Returns:
            tuple(float, np.ndarray, Sequence): The timestep offset, the (num_points, 2) array of [x, y] values (NaN at
                inactive points), and the cell activity array (None if the timestep has no activity array)
        """
        time, cell_activity, point_activity = self._read_time_and_activity(file)
        num_vals = self._ugrid_item.num_points * 2  # 2 32-bit floats per point
        file_bytes = file.read(num_vals * io_util.FLOAT_32_SIZE_BYTES)
        flat_array = np.array(struct.unpack(f'<{num_vals}f', file_bytes))
        # Values are interleaved: [x1, y1, ..., xN, yN] -> [[x1, y1], ..., [xN, yN]]
        vector_values = flat_array.reshape(-1, 2)
        if point_activity is not None:
            vector_values[point_activity != 1] = np.nan
        return time, vector_values, cell_activity

    def _read_time_and_activity(self, file):
        """Reads the timestep offset and activity array for a timestep.

        Args:
            file: Open file stream to the DATV solution file

        Returns:
            tuple(float, Sequence, Sequence): The timestep offset, the cell activity array (one bool per cell), and the
                point activity array derived from it via active_points_from_cells. Both arrays are None if the
                timestep's activity flag is not set.
        """
        active_flag = io_util.read_binary_into_bool(file)
        time = io_util.read_binary_into_float(file, io_util.FLOAT_32_SIZE_BYTES)
        cell_activity = None
        point_activity = None
        if active_flag:
            num_cells = self._ugrid_item.num_cells
            file_bytes = file.read(num_cells)  # One byte per cell
            cell_activity = struct.unpack(f'<{num_cells}?', file_bytes)
            point_activity = active_points_from_cells(self._xmugrid, cell_activity)
        return time, cell_activity, point_activity

    def _read_dataset(self, filename):
        """Read a dataset from a DATV file (we are assuming only one dataset per file).

        Args:
            filename (str): Absolute filepath to the DATV solution file

        Raises:
            RuntimeError: If the size of the file cannot be determined (e.g. it does not exist)
        """
        try:
            file_size = os.path.getsize(filename)
        except OSError as e:
            raise RuntimeError(f'Empty or nonexistent DATV file: {io_util.logging_filename(filename)}') from e

        times = []
        values = []
        cell_activity = []
        metadata = self._default_metadata()
        self._logger.info(f'Reading DATV solution file: "{io_util.logging_filename(filename)}"...')
        with open(filename, 'rb', buffering=io_util.READ_BUFFER_SIZE) as f:
            self._read_header(f)
            # Query the stream position on every iteration (including after metadata commands) so we never try to
            # read a command past the end of the file.
            while f.tell() < file_size:
                command = io_util.read_binary_into_int(f)
                if self._read_metadata_command(f, command, metadata):
                    continue
                if command == self.DATV_TIME_STEP:
                    if metadata[self.DATV_BEGIN_VECTOR]:  # Read a vector timestep
                        time, value, activity_mask = self._read_vector_timestep(f)
                    else:  # Read a scalar timestep
                        time, value, activity_mask = self._read_scalar_timestep(f)
                    times.append(time)
                    values.append(value)
                    if activity_mask is not None:
                        cell_activity.append(activity_mask)
                elif command == self.DATV_END_OF_DATA:
                    break  # Done reading the dataset. Assuming only one per file, which is the case with TUFLOWFV.
                # Any other command id is ignored.
        if times:
            self._write_datv_dataset_to_xmdf(metadata, times, values, cell_activity)
        else:
            self._logger.error(f'Empty or nonexistent DATV file: {io_util.logging_filename(filename)}')

    def _write_datv_dataset_to_xmdf(self, metadata, times, values, cell_activity):
        """Write a DATV dataset to XMDF format.

        Args:
            metadata (dict): The datasets metadata
            times (Sequence): The timestep offset dataset
            values (Sequence): The values dataset
            cell_activity (Sequence): The per-timestep cell activity arrays; may be empty if none were defined
        """
        self._logger.info('Writing imported dataset to XMDF file...')
        dset_uuid = str(uuid.uuid4())
        h5_filename = f'{os.path.join(self._temp_folder, dset_uuid)}.h5'
        # Must agree with the read path, which only reads 2-component values when the vector flag was set.
        num_components = 2 if metadata[self.DATV_BEGIN_VECTOR] else 1
        # SMS seems to have trouble if we leave the vector values at NaN. Doesn't seem to be a problem with scalars.
        null_value = 0.0 if num_components == 2 else None
        writer = DatasetWriter(h5_filename=h5_filename, name=metadata[self.DATV_NAME], dset_uuid=dset_uuid,
                               geom_uuid=self._ugrid_uuid, num_components=num_components, null_value=null_value,
                               ref_time=metadata[self.DATV_REFERENCE_TIME], time_units=metadata[self.DATV_TIME_UNITS])
        # If no activity defined, don't send in an empty list or dimensions will mismatch.
        activity = cell_activity if cell_activity else None
        writer.write_xmdf_dataset(times, values, activity=activity)
        self._builders.append(writer)

    def read(self):
        """Import all the TUFLOWFV DATV solution datasets.

        Note that this reader is specific to TUFLOWFV. We skip much of the header because we know the sizes and data
        types of the data TUFLOWFV writes to its solution files and it is unlikely to change. We also assume point
        based data with a cell-based activity array because that is what TUFLOWFV outputs. We also assume there is
        only one dataset per file.

        Returns:
            list[DatasetWriter]: The imported datasets
        """
        # Ignore numpy warning about all-nan slices (cells that are dry across all timesteps).
        with warnings.catch_warnings():
            warnings.filterwarnings(action='ignore', message='All-NaN slice encountered')
            for filename in self._filenames:
                try:  # If reading one barfs, continue reading other files.
                    self._read_dataset(filename)
                except Exception as e:
                    self._logger.error(f'Errors reading DATV file: {e}')
        return self._builders
