import logging
import xarray as xr
import param

from .boundary_conditions import BoundaryConditions
from .model_control import ModelControl
from .mesh import Mesh
from .dat_reader import read_hot_start_file
from .vessels import VesselList

log = logging.getLogger('adhmodel.simulation')


class AdhSimulation(param.Parameterized):
    """Container grouping the components of an AdH simulation.

    Holds the mesh, boundary conditions, model control settings, result and
    hot-start datasets, and the vessel list as ``param`` parameters.
    """
    boundary_conditions = param.ClassSelector(default=BoundaryConditions(), class_=BoundaryConditions)
    model_control = param.ClassSelector(default=ModelControl(), class_=ModelControl)
    mesh = param.ClassSelector(default=Mesh(), class_=Mesh)
    results = param.ClassSelector(default=xr.Dataset(), class_=xr.Dataset)
    hotstart = param.ClassSelector(default=xr.Dataset(), class_=xr.Dataset)
    vessels = param.ClassSelector(default=VesselList(), class_=VesselList)

    def __init__(self, **params):
        """Create a simulation, giving it its own component objects.

        Any component passed as a keyword argument (e.g. ``mesh=my_mesh``) is
        used as given. Every component not passed gets a freshly constructed
        object, so instances never share the class-level default objects,
        which are created once at import time.
        """
        # Order matches the original construction order of these components.
        factories = (
            ('mesh', Mesh),
            ('boundary_conditions', BoundaryConditions),
            ('model_control', ModelControl),
            ('results', xr.Dataset),
            ('hotstart', xr.Dataset),
            ('vessels', VesselList),
        )
        for name, factory in factories:
            if name not in params:
                params[name] = factory()
        # Must run Parameterized.__init__ itself. The previous
        # super(param.Parameterized, self) call skipped it and went
        # straight to object.__init__.
        super().__init__(**params)

    @property
    def vessel_list(self):
        """Alias for ``vessels``, kept so code using ``vessel_list`` still works."""
        return self.vessels

    @vessel_list.setter
    def vessel_list(self, value):
        self.vessels = value

    def read_bc(self, *args, **kwargs):
        """Read boundary conditions via ``self.boundary_conditions.read``.

        All positional and keyword arguments are forwarded unchanged.
        """
        self.boundary_conditions.read(*args, **kwargs)

    def write_bc(self, *args, **kwargs):
        """Write boundary conditions via ``self.boundary_conditions.write``.

        All positional and keyword arguments are forwarded unchanged.
        """
        self.boundary_conditions.write(*args, **kwargs)

    def read_hotstart(self, path, fmt='nc'):
        """Load hot-start data from ``path`` into ``self.hotstart``.

        Args:
            path: File to read.
            fmt: ``'nc'`` opens the file with ``xarray.open_dataset`` and
                keeps only the data variables that have both the
                ``init_time`` and ``nodes_ids`` dimensions. ``'ascii'``
                parses the file with ``read_hot_start_file``.

        Raises:
            ValueError: If ``fmt`` is neither ``'nc'`` nor ``'ascii'``.
                A ``ValueError`` raised while parsing an ASCII file is
                logged as a warning instead, and ``self.hotstart`` is left
                unchanged.
        """
        if fmt == 'nc':
            with xr.open_dataset(path) as model_xr:
                subset = self._subset_xarray_vars_by_coords(model_xr, ('init_time', 'nodes_ids'))
                # Pull the data into memory before the context manager
                # closes the file; otherwise the lazy arrays would
                # reference a closed file.
                self.hotstart = subset.load()
        elif fmt == 'ascii':
            try:
                self.hotstart = read_hot_start_file(path)
            except ValueError as e:
                log.warning(e)
        else:
            raise ValueError("Unsupported hot start format %r; expected 'nc' or 'ascii'" % (fmt,))

    def _subset_xarray_vars_by_coords(self, dataset, coords):
        """Return the data variables of ``dataset`` that span all of ``coords``.

        Despite the name, ``coords`` is compared against each variable's
        *dimension* names. A variable is kept only if every name in
        ``coords`` appears among its dimensions.

        Args:
            dataset: ``xarray.Dataset`` to filter.
            coords: Iterable of dimension names that must all be present.

        Returns:
            ``xarray.Dataset`` containing only the matching data variables.
        """
        required = set(coords)
        matching = [name for name in dataset.data_vars
                    if required.issubset(dataset[name].dims)]
        return dataset[matching]
