"""Helper class to get the XMS data for the fort.22 script."""

# 1. Standard Python modules
import datetime
import os

# 2. Third party modules
import orjson

# 3. Aquaveo modules
from xms.api.tree import tree_util
from xms.core.filesystem import filesystem as xfs
from xms.coverage.windCoverage import WindCoverage
from xms.datasets.dataset_reader import DatasetReader

# 4. Local modules
from xms.adcirc.data.sim_data import SimData
from xms.adcirc.feedback.xmlog import XmLog

# Name of the hand-off JSON file, written by write_fort22_data_json() and read (then deleted) by
# read_fort22_data_json(). Both resolve it relative to the current working directory.
FORT22_EXPORT_ARGS_JSON = 'fort22_input_data.json'


class Fort22DataGetter:
    """Helper class to get the XMS data for the fort.22 script."""
    def __init__(self, query, xms_data):
        """Constructor.

        Args:
            query (:obj:`Query`): The XMS interprocess communicator
            xms_data (:obj:`dict`): Dict of the XMS data to fill. It may already contain 'sim_data' (full simulation
                export); if it does not, this is treated as a partial export started from the simulation component.
        """
        self._query = query
        self._xms_data = xms_data

    def _use_existing_fort22(self):
        """Check whether the simulation is set to use an existing fort.22 file instead of XMS wind data.

        Returns:
            (:obj:`bool`): True if the simulation's wind 'use_existing' attribute is set
        """
        return bool(self._xms_data['sim_data'].wind.attrs['use_existing'])

    def _retrieve_sim_data(self):
        """Retrieve the simulation-level data.

        Fills 'global_time', 'sim_node', 'radstress' (None unless needed), 'error', and, for partial exports,
        'sim_data'. 'error' is set to True before any exception is raised.

        Raises:
            RuntimeError: If any of the simulation-level data could not be retrieved.
        """
        self._xms_data['radstress'] = None
        self._xms_data['error'] = False
        sim_export = True
        # Get the simulation's component data.
        try:
            self._xms_data['global_time'] = self._query.global_time  # Most types need the global time, so throw it in
            # If we don't already have sim_data, implies this is a partial export command, starting Context is at the
            # simulation component level.
            if 'sim_data' not in self._xms_data:
                self._xms_data['sim_data'] = SimData(self._query.current_item().main_file)
                sim_export = False
        except Exception as err:
            self._xms_data['error'] = True
            raise RuntimeError(
                'Unable to retrieve ADCIRC simulation data from SMS. fort.22 was not exported.'
            ) from err

        # Get the simulation tree node. For a partial export the current item is the simulation's component, so the
        # simulation is its parent.
        try:
            if sim_export:
                sim_uuid = self._query.current_item_uuid()
            else:
                sim_uuid = self._query.parent_item_uuid()
            self._xms_data['sim_node'] = tree_util.find_tree_node_by_uuid(self._query.project_tree, sim_uuid)
        except Exception as err:
            self._xms_data['error'] = True
            raise RuntimeError(
                'Unable to retrieve ADCIRC simulation data from SMS. fort.22 was not exported.'
            ) from err

        # Get the radstress dataset (written to fort.23) if applicable. Wrapped like the other lookups so the
        # 'error' flag is always set when retrieval fails.
        wind_attrs = self._xms_data['sim_data'].wind.attrs
        if wind_attrs['wave_radiation'] == 100:
            try:
                self._xms_data['radstress'] = self._query.item_with_uuid(wind_attrs['fort23_uuid'])
            except Exception as err:
                self._xms_data['error'] = True
                raise RuntimeError('Unable to retrieve radiation stress dataset to write fort.23.') from err
            if not self._xms_data['radstress']:
                self._xms_data['error'] = True
                raise RuntimeError('Unable to retrieve radiation stress dataset to write fort.23.')

    def _retrieve_data_for_nws_mesh(self):
        """Get XMS data for the mesh dataset types (NWS=1,2,4,5).

        Fills 'pressure' and 'vector' with the datasets selected in the simulation's wind attributes.

        Raises:
            RuntimeError: If either dataset could not be retrieved.
        """
        XmLog().instance.info('Retrieving the mesh wind datasets from XMS...')
        try:
            # Select the datasets by UUID.
            wind_attrs = self._xms_data['sim_data'].wind.attrs
            self._xms_data['pressure'] = self._query.item_with_uuid(wind_attrs['mesh_pressure'])
            self._xms_data['vector'] = self._query.item_with_uuid(wind_attrs['mesh_wind'])
            if not self._xms_data['pressure'] or not self._xms_data['vector']:
                raise RuntimeError()
        except Exception as err:
            self._xms_data['error'] = True
            raise RuntimeError(
                'Unable to retrieve domain mesh wind datasets from SMS. fort22 file was not exported.'
            ) from err

    def _retrieve_data_for_nws_grid(self):
        """Get XMS data for NWS=3 type as well as some of the data for NWS=6,7.

        Two options are supported: the user has selected a wind grid with datasets to define the fort.22, or the
        user has specified to use an existing fort.22 file. Nothing is retrieved in the latter case. Otherwise fills
        'wind_grid' (unless already present) and 'velocity'.

        Raises:
            RuntimeError: If the wind grid or velocity dataset could not be retrieved.
        """
        XmLog().instance.info('Retrieving the wind grid and velocity dataset from SMS...')
        try:
            if not self._use_existing_fort22():
                # Get the wind grid if we haven't already
                if not self._xms_data.get('wind_grid'):
                    wind_grid_item = tree_util.descendants_of_type(
                        self._xms_data['sim_node'], xms_types=['TI_CGRID2D_PTR'], allow_pointers=True, only_first=True
                    )
                    do_ugrid = self._query.item_with_uuid(wind_grid_item.uuid)
                    self._xms_data['wind_grid'] = do_ugrid.cogrid_file
                velocity_uuid = self._xms_data['sim_data'].wind.attrs['grid_wind']
                self._xms_data['velocity'] = self._query.item_with_uuid(velocity_uuid)
                # Problem - we couldn't find the wind grid or its velocity dataset
                if not self._xms_data['wind_grid'] or not self._xms_data['velocity']:
                    raise RuntimeError()
        except Exception as err:
            self._xms_data['error'] = True
            raise RuntimeError(
                'Unable to retrieve wind grid and velocity dataset from SMS. fort22 file was not exported.'
            ) from err

    def _retrieve_data_for_nws_6_7_grid(self):
        """Retrieve data for the NWS=6,7 types that is not applicable to NWS=3.

        Raises:
            RuntimeError: If the wind grid, velocity, or pressure dataset could not be retrieved.
        """
        self._retrieve_data_for_nws_grid()  # Base will get most of the data we need.
        XmLog().instance.info('Retrieving the wind pressure dataset from SMS...')
        try:
            # If we are not using an existing file, grab the additional wind pressure dataset not used by NWS=3 types.
            if not self._use_existing_fort22():
                pressure_uuid = self._xms_data['sim_data'].wind.attrs['grid_pressure']
                self._xms_data['pressure'] = self._query.item_with_uuid(pressure_uuid)
                if not self._xms_data['pressure']:
                    raise RuntimeError()
        except Exception as err:
            self._xms_data['error'] = True
            raise RuntimeError('Unable to retrieve wind grid pressure from SMS. fort22 file was not exported.') from err

    def _retrieve_storm_track_data(self):
        """Retrieve data for the storm track NWS types (8, 19, 20).

        Fills 'wind_cov' with the first WIND coverage under the simulation, unless an existing fort.22 is used.

        Raises:
            RuntimeError: If the wind track coverage could not be retrieved.
        """
        if self._use_existing_fort22():
            return  # No storm track coverage if using an existing fort.22
        XmLog().instance.info('Retrieving the wind track coverage from SMS...')
        try:
            wind_cov_item = tree_util.descendants_of_type(
                self._xms_data['sim_node'],
                xms_types=['TI_COVER_PTR'],
                allow_pointers=True,
                only_first=True,
                recurse=True,
                coverage_type='WIND'
            )
            self._xms_data['wind_cov'] = self._query.item_with_uuid(wind_cov_item.uuid, generic_coverage=True)
            if not self._xms_data['wind_cov']:
                raise RuntimeError()
        except Exception as err:
            self._xms_data['error'] = True
            raise RuntimeError('Unable to retrieve wind track coverage from SMS. fort22 file was not exported.') from err

    def retrieve_data(self):
        """Get all the data from XMS to export the fort.22 for any NWS type.

        NWS values not handled below only get the simulation-level data.

        Returns:
            (:obj:`dict`): The XMS data needed to export the fort.22 for a simulation's particular NWS type

        Raises:
            RuntimeError: If any required XMS data could not be retrieved ('error' is set to True first).
        """
        self._retrieve_sim_data()
        nws_simple = self._xms_data['sim_data'].wind.attrs['NWS']
        if nws_simple == 0:  # No wind, nothing to retrieve
            return self._xms_data

        # Get data specific to each NWS type
        if nws_simple in [1, 2, 4, 5]:
            self._retrieve_data_for_nws_mesh()
        elif nws_simple == 3:
            self._retrieve_data_for_nws_grid()
        elif nws_simple in [6, 7]:
            self._retrieve_data_for_nws_6_7_grid()
        elif nws_simple in [8, 19, 20]:
            self._retrieve_storm_track_data()
        return self._xms_data


def write_fort22_data_json(xms_data):
    """Serialize the data the fort.22 export needs to a JSON file in the current working directory.

    Only applicable to full simulation exports. Datasets are stored as (H5 filename, group path) pairs so the
    fort.22 script can rebuild them after reading the file back.

    Args:
        xms_data (:obj:`dict`): The XMS data required to export the fort.22

    Returns:
        (:obj:`bool`): False if an error was flagged while retrieving the XMS data (nothing is written), else True
    """
    if xms_data.get('error'):
        return False  # Retrieval already failed; there is nothing valid to hand off

    def _h5_ref(dset):
        """Get the (H5 filename, group path) pair that identifies a dataset."""
        return dset.h5_filename, dset.group_path

    sim_data = xms_data['sim_data']
    wind_attrs = sim_data.wind.attrs
    nws_simple = wind_attrs['NWS']

    # Every NWS type gets the simulation data, and any of them may use a radstress dataset (which needs global time).
    radstress = xms_data.get('radstress')
    radstress_ref = None if radstress is None else _h5_ref(radstress)
    json_data = {
        'sim_data': sim_data._filename,
        'radstress': radstress_ref,
        'global_time': xms_data.get('global_time'),
    }

    # Add the data specific to the NWS type.
    if nws_simple in (1, 2, 4, 5):  # Datasets on the domain mesh
        json_data['vector'] = _h5_ref(xms_data['vector'])
        json_data['pressure'] = _h5_ref(xms_data['pressure'])
    elif nws_simple in (3, 6, 7):  # Datasets on a wind grid, unless an existing fort.22 is being used
        if not wind_attrs['use_existing']:
            json_data['wind_grid'] = xms_data['wind_grid']
            json_data['velocity'] = _h5_ref(xms_data['velocity'])
            if nws_simple != 3:  # Only NWS=6,7 have the pressure dataset
                json_data['pressure'] = _h5_ref(xms_data['pressure'])
    elif nws_simple in (8, 19, 20):  # Storm track types
        wind_cov = xms_data.get('wind_cov')
        if wind_cov is not None:
            json_data['wind_cov'] = wind_cov.m_fileName

    json_path = os.path.join(os.getcwd(), FORT22_EXPORT_ARGS_JSON)
    with open(json_path, 'wb') as json_file:
        json_file.write(orjson.dumps(json_data))
    return True


def _parse_global_time(time_str):
    """Convert the global time string from the fort.22 JSON file back to a datetime.

    orjson dumps datetimes in RFC 3339 format, with fractional seconds only when there are microseconds.
    datetime.datetime.fromisoformat() would be preferable here, but it was not added until Python 3.7.

    Args:
        time_str (:obj:`Optional[str]`): The serialized global time, or None if no global time was written

    Returns:
        (:obj:`Optional[datetime.datetime]`): The parsed global time, or None if time_str is None

    Raises:
        ValueError: If time_str matches neither supported format.
    """
    # NOTE(review): a timezone-aware global time would be written with a UTC offset and fail both formats - confirm
    # that XMS always provides a naive datetime.
    if time_str is None:  # The writer uses xms_data.get('global_time'), so None can be serialized
        return None
    try:  # Don't think most will have fractional seconds
        return datetime.datetime.strptime(time_str, '%Y-%m-%dT%H:%M:%S')
    except ValueError:  # But if that fails, try parsing again with fractional seconds specifier
        return datetime.datetime.strptime(time_str, '%Y-%m-%dT%H:%M:%S.%f')


def read_fort22_data_json():
    """Read the JSON file of XMS input data written by the fort.15 script during a full simulation export.

    The JSON file is deleted once it has been read.

    Returns:
        (:obj:`dict`): The XMS data required to export the fort.22
    """
    filename = os.path.join(os.getcwd(), FORT22_EXPORT_ARGS_JSON)
    with open(filename, 'rb') as file:
        xms_data = orjson.loads(file.read())
    xfs.removefile(filename)  # One-shot hand-off file; clean it up once loaded
    # Convert the global time str back to a datetime (None if no global time was written).
    xms_data['global_time'] = _parse_global_time(xms_data.get('global_time'))
    # Reconstruct the simulation data
    xms_data['sim_data'] = SimData(xms_data['sim_data'])
    # Reconstruct the DatasetReaders from the (filename, group path) pairs in the JSON file.
    for dataset in ('radstress', 'vector', 'velocity', 'pressure'):
        h5_info = xms_data.get(dataset)
        if h5_info is not None:
            xms_data[dataset] = DatasetReader(h5_filename=h5_info[0], group_path=h5_info[1])
    # Reconstruct the WindCoverage if there is one
    wind_cov = xms_data.get('wind_cov')
    if wind_cov:
        xms_data['wind_cov'] = WindCoverage(wind_cov)
    return xms_data
