"""MappedBcData class."""

# 1. Standard Python modules
import os

# 2. Third party modules
import numpy as np
# from packaging import version
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase
from xms.core.filesystem import filesystem as xfs

# 4. Local modules
from xms.adcirc.__version__ import version
from xms.adcirc.data import bc_data as bcd
from xms.adcirc.file_io.grid_crc import compute_grid_crc

# File name of the mapped BC component's NetCDF data file (presumably the component main file - confirm with callers).
MAPPED_BC_MAIN_FILE = 'mapped_bc_comp.nc'


class MappedBcData(XarrayBase):
    """Manages data file for the boundary conditions coverage hidden component.

    In addition to the base class `info` dataset, the data file holds one NetCDF group for each of the `nodestrings`,
    `nodes`, `levees` and `river_flows` datasets. The datasets are exposed as lazily loaded properties and are written
    back to the file by `commit()`.
    """
    # NetCDF group names of the datasets owned by this class, in the order commit() writes them. Each dataset is
    # cached on the instance attribute with the same name prefixed by an underscore (e.g. `_nodes`).
    _BC_GROUPS = ('nodestrings', 'nodes', 'levees', 'river_flows')

    def __init__(self, data_file):
        """Initializes the data class.

        Args:
            data_file (:obj:`str`): The netcdf file (with path) associated with this instance data. Probably the owning
                component's main file.

        """
        # Initialize member variables before calling super so they are available for commit() call
        self._filename = data_file
        self._info = None
        self._nodestrings = None
        self._nodes = None
        self._levees = None
        # Unlike the other datasets, river flows are read eagerly: _check_for_simple_migration() tests this attribute
        # for None to decide whether an empty default dataset must be created.
        self._river_flows = self.get_dataset('river_flows', False)
        # Create the default file before calling super because we have our own attributes to write.
        self._get_default_datasets()
        super().__init__(self._filename)
        # This needs to happen after super call because we need to migrate stuff in the base `info` dataset.
        self._check_for_simple_migration()
        self.source_data = bcd.BcData(os.path.join(os.path.dirname(self._filename), bcd.BC_MAIN_FILE))

    def update_levee(self, comp_id, new_atts):
        """Update the levee attributes of a boundary condition (outflows or pairs).

        Args:
            comp_id (:obj:`int`): Component id of the levee to update
            new_atts (:obj:`xarray.Dataset`): The new attributes for the levee

        """
        # Only the editable attribute columns are copied; the 'Node1 Id'/'Node2 Id' columns are left untouched.
        editable_columns = (
            'Zcrest (m)',
            'Subcritical __new_line__ Flow Coef',
            'Supercritical __new_line__ Flow Coef',
            'Pipe',
            'Zpipe (m)',
            'Pipe __new_line__ Diameter (m)',
            'Bulk __new_line__ Coefficient',
        )
        for column in editable_columns:
            self.levees[column].loc[comp_id] = new_atts[column]

    @property
    def nodestrings(self):
        """Load the nodestrings dataset from disk.

        Returns:
            (:obj:`xarray.Dataset`): Dataset interface to the nodestrings datasets in the main file

        """
        return self._lazy_dataset('nodestrings')

    @nodestrings.setter
    def nodestrings(self, dset):
        """Setter for the nodestrings attribute.

        Falsy values are ignored. For an xarray.Dataset that means a dataset without any data variables.
        """
        if dset:
            self._nodestrings = dset

    @property
    def nodes(self):
        """Load the nodes dataset from disk.

        Returns:
            (:obj:`xarray.Dataset`): Dataset interface to the nodes dataset in the main file

        """
        return self._lazy_dataset('nodes')

    @nodes.setter
    def nodes(self, dset):
        """Setter for the nodes attribute.

        Falsy values are ignored. For an xarray.Dataset that means a dataset without any data variables.
        """
        if dset:
            self._nodes = dset

    @property
    def levees(self):
        """Load the levees dataset from disk.

        Returns:
            (:obj:`xarray.Dataset`): Dataset interface to the levees dataset in the main file

        """
        return self._lazy_dataset('levees')

    @levees.setter
    def levees(self, dset):
        """Setter for the levees attribute.

        Falsy values are ignored. For an xarray.Dataset that means a dataset without any data variables.
        """
        if dset:
            self._levees = dset

    @property
    def river_flows(self):
        """Load the river_flows dataset from disk.

        Returns:
            (:obj:`xarray.Dataset`): Dataset interface to the river_flows dataset in the main file

        """
        return self._lazy_dataset('river_flows')

    @river_flows.setter
    def river_flows(self, dset):
        """Setter for the river_flows attribute.

        Falsy values are ignored. For an xarray.Dataset that means a dataset without any data variables.
        """
        if dset:
            self._river_flows = dset

    def _lazy_dataset(self, group):
        """Return the cached dataset of a group, reading it from the data file on first access.

        Args:
            group (:obj:`str`): One of the group names in `_BC_GROUPS`

        Returns:
            (:obj:`xarray.Dataset`): The cached dataset, or whatever `get_dataset()` returned for the group

        """
        attr_name = f'_{group}'
        dataset = getattr(self, attr_name)
        if dataset is None:
            dataset = self.get_dataset(group, False)
            setattr(self, attr_name, dataset)
        return dataset

    @staticmethod
    def _empty_river_flows(comp_id_dtype):
        """Build a river flows dataset that contains no nodes and no time steps.

        Args:
            comp_id_dtype: NumPy dtype of the empty `comp_id` coordinate

        Returns:
            (:obj:`xarray.Dataset`): Dataset with an empty 'Flow' variable dimensioned by ('Node Id', 'TS')

        """
        return xr.Dataset(
            {
                'Flow': (('Node Id', 'TS'), np.empty((0, 0)))
            },
            coords={
                'Node Id': np.array([], dtype=int),
                'TS': np.array([], dtype=np.float64),
                'comp_id': ('Node Id', np.array([], dtype=comp_id_dtype))
            }
        )

    def _get_default_datasets(self):
        """Create default datasets if needed.

        Does nothing when the data file already exists. Otherwise builds empty datasets and commits them so the file
        is created.
        """
        if os.path.isfile(self._filename):
            return

        info = {
            'FILE_TYPE': 'ADCIRC_MAPPED_BC',
            'VERSION': version,
            'display_uuid': '',
            'cov_uuid': '',
            'grid_crc': '',
            'wkt': '',  # Display projection at time of mapping. Must match domain's native projection.
        }
        nodestring_data = {
            'comp_id': xr.DataArray(data=np.array([], dtype=np.int32)),
            'partner_id': xr.DataArray(data=np.array([], dtype=np.int32)),
            'nodes_start_idx': xr.DataArray(data=np.array([], dtype=np.int32)),
            'node_count': xr.DataArray(data=np.array([], dtype=np.int32)),
        }
        node_data = {
            'id': xr.DataArray(data=np.array([], dtype=np.int32)),
        }
        levee_data = {
            'Node1 Id': ('comp_id', np.array([], dtype=np.int32)),
            'Node2 Id': ('comp_id', np.array([], dtype=np.int32)),
            'Zcrest (m)': ('comp_id', np.array([], dtype=np.float64)),
            'Subcritical __new_line__ Flow Coef': ('comp_id', np.array([], dtype=np.float64)),
            'Supercritical __new_line__ Flow Coef': ('comp_id', np.array([], dtype=np.float64)),
            'Pipe': ('comp_id', np.array([], dtype=np.int32)),
            'Zpipe (m)': ('comp_id', np.array([], dtype=np.float64)),
            'Pipe __new_line__ Diameter (m)': ('comp_id', np.array([], dtype=np.float64)),
            'Bulk __new_line__ Coefficient': ('comp_id', np.array([], dtype=np.float64)),
        }
        self._info = xr.Dataset(attrs=info)
        self._nodestrings = xr.Dataset(data_vars=nodestring_data)
        self._nodes = xr.Dataset(data_vars=node_data)
        self._levees = xr.Dataset(data_vars=levee_data, coords={'comp_id': []})
        # NOTE(review): the default uses a float64 comp_id coordinate while the migration path uses int. Both are
        # preserved as they were - confirm whether the difference is intentional.
        self._river_flows = self._empty_river_flows(np.float64)
        self.commit()

    def _check_for_simple_migration(self):
        """Migrate things that we don't need an XML migration for."""
        commit = False
        grid_file = os.path.join(os.path.dirname(self._filename), self.info.attrs.get('grid_file', ''))
        if os.path.isfile(grid_file):
            if 'grid_crc' not in self.info.attrs:
                # If the `grid_crc` attr does not exist, this means the data was created back when we used to store the
                # CoGrid file with the component data.
                self.info.attrs['grid_crc'] = compute_grid_crc(grid_file)
                commit = True  # Need to commit because we won't get a second chance since the CoGrid file is gone
            # Now delete the CoGrid file because they can take up significant disk space. We use the existence of the
            # file as the check here because we need to delete it from temp and the project.
            xfs.removefile(grid_file)

        if self._river_flows is None:
            # No river flows could be read from the file, so start with an empty dataset. This alone does not trigger
            # a commit.
            self._river_flows = self._empty_river_flows(int)
        if commit:
            self.commit()

    def set_levee_pair(self, nodestring1, nodestring2):
        """Define a levee pair association by nodestring id.

        Args:
            nodestring1 (:obj:`int`): Index of the first nodestring in the pair
            nodestring2 (:obj:`int`): Index of the second nodestring in the pair

        """
        partner_ids = self.nodestrings['partner_id']
        partner_ids[nodestring1] = nodestring2
        partner_ids[nodestring2] = nodestring1

    def _gather_node_ids(self, mask):
        """Collect the node ids of the selected nodestrings.

        Args:
            mask (:obj:`xarray.DataArray`): Boolean mask over the nodestrings selecting the ones to gather

        Returns:
            (:obj:`tuple(list,list)`): Flat list of the node ids of the selected nodestrings in nodestring order, list
            of the selected nodestrings' node counts.

        """
        nodestrings = self.nodestrings
        node_starts = nodestrings['nodes_start_idx'][mask].data.tolist()
        node_counts = nodestrings['node_count'][mask].data.tolist()
        # Slice the plain NumPy array once instead of iterating 0-d DataArrays for every node.
        all_node_ids = self.nodes['id'].values
        node_ids = []
        for node_start, node_count in zip(node_starts, node_counts):
            node_ids.extend(all_node_ids[node_start:node_start + node_count].tolist())
        return node_ids, node_counts

    def get_nonlevee_land_node_ids(self):
        """Get the ids of the non-levee land nodestring nodes in the order they will be written to the fort.14.

        Arcs of type ocean, levee, levee outflow and unassigned are excluded.

        Returns:
            (:obj:`tuple(list,list,list)`): List of all the non-levee land nodestring node ids,
            list of all the nodestrings' node counts, list of nodestring component id. Lists are parallel.

        """
        arcs = self.source_data.arcs
        arc_types = arcs['type']
        is_land = (
            (arc_types != bcd.OCEAN_INDEX) & (arc_types != bcd.LEVEE_INDEX)  # noqa, W503
            & (arc_types != bcd.LEVEE_OUTFLOW_INDEX)  # noqa, W503
            & (arc_types != bcd.UNASSIGNED_INDEX)  # noqa, W503
        )
        land_arcs = arcs.where(is_land, drop=True)
        mask = self.nodestrings['comp_id'].isin(land_arcs['comp_id'])
        land_nodes, node_counts = self._gather_node_ids(mask)
        comp_ids = self.nodestrings['comp_id'][mask].data.tolist()
        return land_nodes, node_counts, comp_ids

    def get_ocean_node_ids(self):
        """Get the ids of the ocean nodestring nodes in the order they will be written to the fort.14.

        Returns:
            (:obj:`tuple(list,list)`): List of all the ocean nodestring node ids, list of all the ocean
            nodestrings' node counts.

        """
        arcs = self.source_data.arcs
        ocean_arcs = arcs.where(arcs['type'] == bcd.OCEAN_INDEX, drop=True)
        mask = self.nodestrings['comp_id'].isin(ocean_arcs['comp_id'])
        return self._gather_node_ids(mask)

    def get_river_node_ids(self):
        """Get the ids of the river nodestring nodes in the order they will be written to the fort.14.

        Both river arcs and flow-and-radiation arcs are included.

        Returns:
            (:obj:`list`): List of all the river nodestring node ids

        """
        arcs = self.source_data.arcs
        arc_types = arcs['type']
        river_arcs = arcs.where(
            (arc_types == bcd.RIVER_INDEX) | (arc_types == bcd.FLOW_AND_RADIATION_INDEX),  # noqa, W503
            drop=True
        )
        mask = self.nodestrings['comp_id'].isin(river_arcs['comp_id'])
        river_nodes, _ = self._gather_node_ids(mask)
        return river_nodes

    def set_river_flow_data(self, dt, flows_by_node, river_node_ids):
        """Store time-varying river flows for the nodes of the river nodestrings.

        Resets `constant_q` to -999.0 on all source arcs, then rebuilds the river flows dataset from the flows of every
        node on a river arc. The river flows dataset is left unchanged if no river nodes are found.

        Args:
            dt (:obj:`float`): Time between consecutive flow values. The time of flow value `i` is `i * dt`.
            flows_by_node (:obj:`list`): Flow time series for each node, parallel with `river_node_ids`. The number of
                time steps is taken from the first entry.
            river_node_ids (:obj:`list`): Node ids the entries of `flows_by_node` belong to

        Raises:
            ValueError: If a node of a river nodestring is not in `river_node_ids`

        """
        # nothing will use constant q
        arcs = self.source_data.arcs
        arcs['constant_q'].values[:] = -999.0

        # get all of the river arcs
        river_arcs = arcs.where(arcs['type'] == bcd.RIVER_INDEX, drop=True)
        all_river_comp_ids = river_arcs['comp_id'].values.tolist()

        # Guard against empty input so having nothing to store is not an IndexError.
        num_times = len(flows_by_node[0]) if len(flows_by_node) > 0 else 0
        flow_times = [i * dt for i in range(num_times)]

        # Row lookup built once instead of a linear search for every node. setdefault keeps the first occurrence of a
        # duplicated node id, which matches list.index().
        row_by_node = {}
        for row, node_id in enumerate(river_node_ids):
            row_by_node.setdefault(node_id, row)

        river_comp_ids_for_ds = []
        river_node_ids_for_ds = []
        river_flows_for_ds = []
        for comp_id in all_river_comp_ids:
            mask = self.nodestrings['comp_id'].isin(comp_id)
            river_nodes, _ = self._gather_node_ids(mask)
            for node in river_nodes:
                if node not in row_by_node:
                    raise ValueError(f'{node} is not in list')  # Same error type list.index() raises
                river_flows_for_ds.append(flows_by_node[row_by_node[node]])
                # Stored node ids are one less than the nodestring node ids (presumably 1-based to 0-based - confirm).
                river_node_ids_for_ds.append(node - 1)
                river_comp_ids_for_ds.append(comp_id)

        # create the dataset
        if len(river_comp_ids_for_ds) > 0:
            flow_array = np.array(river_flows_for_ds)
            self.river_flows = xr.Dataset(
                {
                    'Flow': (('Node Id', 'TS'), flow_array)
                },
                coords={
                    'Node Id': river_node_ids_for_ds,
                    'TS': flow_times,
                    'comp_id': ('Node Id', river_comp_ids_for_ds)
                }
            )

    def get_generic_node_ids(self):
        """Get the ids of the generic nodestring nodes in the order they will be written to the fort.14.

        Generic nodestrings are the ones whose arc type is unassigned.

        Returns:
            (:obj:`tuple(list,list)`): List of all the generic nodestring node ids, list of all the generic
            nodestrings' node counts.

        """
        arcs = self.source_data.arcs
        generic_arcs = arcs.where(arcs['type'] == bcd.UNASSIGNED_INDEX, drop=True)
        mask = self.nodestrings['comp_id'].isin(generic_arcs['comp_id'])
        return self._gather_node_ids(mask)

    def commit(self):
        """Save current in-memory component parameters to data file.

        Only datasets that are currently loaded in memory are rewritten. Groups of datasets that were never loaded are
        left as they are in the file.
        """
        super().commit()  # Recreates the NetCDF file if vacuuming
        for group in self._BC_GROUPS:
            dataset = getattr(self, f'_{group}')
            if dataset is None:
                continue
            dataset.close()
            # Remove the old group before appending the current in-memory version of it.
            self._drop_h5_groups([group])
            dataset.to_netcdf(self._filename, group=group, mode='a')

    def vacuum(self):
        """Rewrite all datasets to a new/wiped file to reclaim disk space.

        Any dataset that is not in memory yet is loaded first, because the existing file is deleted before
        everything is written again.

        """
        if self._info is None:
            self._info = self.get_dataset('info', False)
        for group in self._BC_GROUPS:
            self._lazy_dataset(group)
        xfs.removefile(self._filename)  # Delete the existing NetCDF file
        self.commit()  # Rewrite all datasets