"""StructureData class."""
# 1. Standard python modules
import os

# 2. Third party modules
import numpy as np
import pkg_resources
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase
from xms.core.filesystem import filesystem as io_util
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.tuflowfv.components.tuflowfv_component import UNINITIALIZED_COMP_ID
from xms.tuflowfv.data.tuflowfv_data import check_for_object_strings_dumb
from xms.tuflowfv.file_io import culvert_csv_io as cci

STRUCTURE_MAIN_FILE = 'struct_comp.nc'
# Flux function names offered for structures.
FLUX_FUNCTION_TYPES = (
    'Weir',
    'Culvert',
    'Nlswe',
)
# Flux function names offered for structure sets (pairs of features).
FLUX_FUNCTION_SET_TYPES = (
    'Culvert',
)
# For each flux function, the struct_type values it cannot be used with.
FLUX_FUNCTION_UNSUPPORTED_STRUCTS = {
    'Weir': ('linked zones', 'linked nodestrings'),
    'Culvert': (),
    'Nlswe': ('linked zones', 'linked nodestrings'),
}
DEFAULT_FLUX_FUNCTION_ARC = 'Weir'
DEFAULT_FLUX_FUNCTION_POLY = 'unassigned'
STRUCTURE_STRING_VARIABLES = {  # Variables that have a str dtype, have to make fixed size before serializing to NetCDF
    'struct_type',
    'flux_function',
    'name',
    'connection',  # cscott: I don't think this is used.
    'energy_loss_function',
    'type',
    'blockage_file',
    'width_file',
    'energy_loss_file',
    'flux_file',
}

# Variables holding file paths that are made relative/copied on project save. Currently empty, so no structure
# file paths are processed. Note: ``{}`` would be an empty dict, so ``set()`` is used to match FILE_VARIABLES_OLD.
FILE_VARIABLES = set()
# The previous list of file path variables, kept for reference.
FILE_VARIABLES_OLD = {
    'blockage_file',
    'width_file',
    'energy_loss_file',
    'flux_file',
}


def _add_culvert_defaults(defaults, dtypes):
    """Append the culvert attributes to the structure defaults.

    Args:
        defaults (dict): Default dict to update
        dtypes (list): dtype list to update
    """
    # Add the culvert attribute names
    defaults.update({name.lower(): data[1] for name, data in cci.CULVERT_COLUMNS.items()})
    dtypes.extend([val[0] for val in cci.CULVERT_COLUMNS.values()])


def get_default_structure_data_dict():
    """Returns the default attribute values and their dtypes for a structure arc or polygon.

    The returned dict and list are parallel: the i-th dtype applies to the i-th key of the dict (insertion order).
    After the built-in attributes are assembled, the culvert column defaults are merged in via
    ``_add_culvert_defaults``.

    Returns:
        tuple(dict, list): The default data and their dtypes
    """
    # One table of (variable name, default value, dtype) so the names, defaults and dtypes cannot drift apart.
    fields = (
        ('struct_type', 'nodestring', object),  # Only used when importing.
        ('flux_function', 'Weir', object),
        ('name', '', object),
        ('connection', '', object),  # String names. Only used when importing.
        ('upstream', 1, np.int32),  # 0 == downstream, 1 == upstream
        ('weir_z', 0.0, np.float64),  # Required
        ('elevation_is_dz', 1, np.int32),
        ('define_weir', 0, np.int32),
        ('weir_cw', 1.705, np.float64),
        ('weir_ex', 1.5, np.float64),
        ('weir_a', 8.55, np.float64),
        ('weir_b', 0.556, np.float64),
        ('weir_csfm', 0.7, np.float64),
        ('energy_loss_function', 'Coefficient', object),
        ('width_or_blockage', 'None', object),
        ('form_loss_coefficient', 0.0, np.float64),
        ('blockage_file', '', object),
        ('width_file', '', object),
        ('energy_loss_file', '', object),
        # Only linked nodestrings
        ('zone_inlet_orientation', 0.0, np.float64),
        ('zone_outlet_orientation', 0.0, np.float64),
        ('flux_file', '', object),  # Matrix only
        # Culvert parameters
        ('type', 1, np.int32),  # 1=circular, 2=rectangular, 4=gated circular (unidirectional), 5=gated rectangular
        ('ignore', 0, np.int32),
        ('ucs', -1, np.int32),
        ('len_or_ana', 0.0, np.float64),
        ('n_or_n_f', 0.0, np.float64),
        ('us_invert', 0.0, np.float64),
        ('ds_invert', 0.0, np.float64),
        ('form_loss', 0.0, np.float64),
        ('pblockage', 0, np.int32),
        ('inlet_type', -1, np.int32),
        ('conn_2d', -1, np.int32),
        ('conn_no', -1, np.int32),
        ('width_or_dia', 0.0, np.float64),
        # The default was previously the dtype class itself (np.float64) rather than a numeric value.
        ('height_or_wf', 0.0, np.float64),
        ('number_of', 0, np.int32),
        ('height_cont', 0.6, np.float64),
        ('width_cont', 1.0, np.float64),
        ('entry_loss', 0.5, np.float64),
        ('exit_loss', 1.0, np.float64),
    )
    defaults = {name: value for name, value, _ in fields}
    dtypes = [dtype for _, _, dtype in fields]
    _add_culvert_defaults(defaults, dtypes)
    return defaults, dtypes


def get_default_structure_data(fill):
    """Build the data_vars mapping for a structure arc or polygon Dataset.

    Args:
        fill (bool): True to give every variable a single default row, False to make every variable empty

    Returns:
        dict: Variable name -> ('comp_id', numpy array) tuple, ready to pass as ``data_vars`` to xarray.Dataset
    """
    default_values, value_dtypes = get_default_structure_data_dict()
    table = {}
    # Values and dtypes are positionally paired with the variables, in insertion order.
    for (var_name, default), var_dtype in zip(default_values.items(), value_dtypes):
        rows = [default] if fill else []
        table[var_name] = ('comp_id', np.array(rows, dtype=var_dtype))
    return table


def get_default_structure_dataset(fill, comp_id=None):
    """Returns a default xarray Dataset for a structure arc or polygon.

    Args:
        fill (bool): True if Dataset should be initialized with a default row, False if it should be empty
        comp_id (int): The comp_id to assign to the default row. Required if fill=True. Ignored if fill=False,
            since an empty Dataset has no row to assign it to.

    Returns:
        xarray.Dataset: The default dataset

    Raises:
        AttributeError: If fill is True and comp_id is None.
    """
    if fill and comp_id is None:
        raise AttributeError('If fill=True, comp_id must be provided')

    table = get_default_structure_data(fill=fill)
    # The coordinate length must match the data variables: one row when filled, zero otherwise. Sizing it from
    # comp_id alone made fill=False with a comp_id raise a dimension-size conflict in xarray.
    comp_ids = [comp_id] if fill else []
    coords = {'comp_id': np.array(comp_ids, np.int32)}
    return xr.Dataset(data_vars=table, coords=coords)


class StructureData(XarrayBase):
    """Manages data file for the structures coverage hidden component."""

    def __init__(self, data_file):
        """Initializes the data class.

        Args:
            data_file (str): The netcdf file (with path) associated with this instance data. Probably the owning
                component's main file.
        """
        # Initialize member variables before calling super so they are available for commit() call
        self._filename = data_file
        self._arcs = None
        self._polygons = None
        self._sets = None
        super().__init__(data_file)
        # Build in-memory default datasets (only happens if the data file does not exist on disk).
        self._get_default_datasets()

    @staticmethod
    def _get_info_attrs():
        """Returns the default attributes of the info dataset."""
        return {
            'FILE_TYPE': 'TUFLOWFV_STRUCTURES',
            'GENERIC_MODEL': 'TUFLOWFV',
            'VERSION': pkg_resources.get_distribution('xmstuflowfv').version,
            'arc_display_uuid': '',
            'poly_display_uuid': '',
            'cov_uuid': '',
            'next_comp_id': 0,
            'next_set_id': 0,
            'proj_dir': '',
            'export_format': 'Shapefile',
        }

    def _get_default_datasets(self):
        """Create empty info, arcs, polygons, and sets datasets if the data file does not exist yet."""
        if not os.path.isfile(self._filename):
            self._info = xr.Dataset(attrs=self._get_info_attrs())
            # Give each dataset its own coordinate array instead of sharing one object between them.
            self._arcs = xr.Dataset(
                data_vars=get_default_structure_data(fill=False),
                coords={'comp_id': np.array([], np.int32)},
            )
            self._polygons = xr.Dataset(
                data_vars=get_default_structure_data(fill=False),
                coords={'comp_id': np.array([], np.int32)},
            )
            self._sets = self.add_linked_structure(comp1_id=None, comp2_id=None)

    @staticmethod
    def _update_variable(comp_id, var_name, dset, new_atts):
        """Update the value of a single variable.

        Args:
            comp_id (int): comp_id coord of the row to update
            var_name (str): Name of the variable to update
            dset (xr.Dataset): The old attribute dataset
            new_atts (xr.Dataset): The new dataset values
        """
        # Make sure to convert unicode dtypes to object before assignment to make sure we don't cut off text
        if var_name in STRUCTURE_STRING_VARIABLES:
            dset[var_name] = dset[var_name].astype(np.object_)
        dset[var_name].loc[dict(comp_id=[comp_id])] = new_atts[var_name].values[0]

    @property
    def arcs(self):
        """Load the arcs dataset from disk.

        Returns:
            xarray.Dataset: Dataset interface to the arcs dataset in the main file
        """
        if self._arcs is None:
            self._arcs = self.get_dataset('arcs', False)
        return self._arcs

    @arcs.setter
    def arcs(self, dset):
        """Setter for the arcs dataset. None is ignored."""
        # Compare to None explicitly: a Dataset's truthiness reflects whether it has data variables, not identity.
        if dset is not None:
            self._arcs = dset

    @property
    def polygons(self):
        """Load the polygons dataset from disk.

        Returns:
            xarray.Dataset: Dataset interface to the polygons dataset in the main file
        """
        if self._polygons is None:
            self._polygons = self.get_dataset('polygons', False)
        return self._polygons

    @polygons.setter
    def polygons(self, dset):
        """Setter for the polygons dataset. None is ignored."""
        # Compare to None explicitly: a Dataset's truthiness reflects whether it has data variables, not identity.
        if dset is not None:
            self._polygons = dset

    @property
    def sets(self):
        """Load the sets dataset from disk.

        Returns:
            xarray.Dataset: Dataset interface to the sets dataset in the main file
        """
        if self._sets is None:
            self._sets = self.get_dataset('sets', False)
        return self._sets

    @sets.setter
    def sets(self, dset):
        """Setter for the sets dataset. None is ignored."""
        # Compare to None explicitly: a Dataset's truthiness reflects whether it has data variables, not identity.
        if dset is not None:
            self._sets = dset

    def update_structure(self, target_type, comp_id, new_atts):
        """Update the structure attributes of a hydraulic structure.

        Args:
            target_type (TargetType): The location of the feature
            comp_id (int): Component id of the structure to update
            new_atts (xarray.Dataset): The new attributes for the structure
        """
        dset = self.arcs if target_type == TargetType.arc else self.polygons
        var_names = get_default_structure_data_dict()[0].keys()
        for var_name in var_names:
            self._update_variable(comp_id, var_name, dset, new_atts)

    def update_struct_type(self, target_type, is_set, comp_id):
        """Updates the struct_type attribute of a structure dataset based on the given parameters.

        Args:
            target_type (TargetType): The target type of the structure dataset (arc or polygon).
            is_set (bool): Indicates if the structure dataset is a set or not.
            comp_id (int): The component ID of the structure.
        """
        struct_type = ''
        if is_set:
            struct_type += 'linked '
        if target_type == TargetType.arc:
            struct_type += 'nodestrings' if is_set else 'nodestring'
        else:  # TargetType.polygon
            struct_type += 'zones'  # As far as we know, the only polygon structure has to be a pair.
        dset = self.arcs if target_type == TargetType.arc else self.polygons
        # Convert to object dtype first (as _update_variable does). If the variable holds a fixed-width unicode
        # dtype (e.g. after loading from disk), a longer value such as 'linked nodestrings' would be truncated.
        dset['struct_type'] = dset['struct_type'].astype(np.object_)
        comp_id = float(comp_id)
        dset['struct_type'].loc[dict(comp_id=[comp_id])] = struct_type

    def add_structure_atts(self, target_type, dset=None):
        """Add the structure attribute dataset for an arc/polygon.

        Args:
            target_type (TargetType): The feature type
            dset (xarray.Dataset): The attribute dataset to concatenate. If not provided, a new Dataset of default
                attributes will be generated.

        Returns:
            int: The newly generated component id, or UNINITIALIZED_COMP_ID if anything fails
        """
        try:
            new_comp_id = self.info.attrs['next_comp_id']
            self.info.attrs['next_comp_id'] += 1  # Increment the unique XMS component id.
            if dset is None:  # Generate a new default Dataset
                dset = get_default_structure_dataset(fill=True, comp_id=new_comp_id)
            else:  # Update the component id of an existing Dataset
                dset.coords['comp_id'] = [new_comp_id for _ in dset.coords['comp_id']]
            if target_type == TargetType.arc:
                self._arcs = xr.concat([self.arcs, dset], 'comp_id')
            elif target_type == TargetType.polygon:
                self._polygons = xr.concat([self.polygons, dset], 'comp_id')
            return new_comp_id
        except Exception:  # Deliberate best-effort: callers check for UNINITIALIZED_COMP_ID.
            return UNINITIALIZED_COMP_ID

    def add_linked_structure(self, comp1_id, comp2_id):
        """Create an xarray Dataset for a new set relationship.

        Notes:
            If the optional comp ids are provided, the set will be appended to StructureData.

        Args:
            comp1_id (Optional[int]): Component ID of the first feature in the set
            comp2_id (Optional[int]): Component ID of the second feature in the set

        Returns:
            xr.Dataset: The set Dataset
        """
        if comp1_id is None:
            coords = {'set_id': np.array([], np.int32)}
        else:
            coords = {'set_id': np.array([self.info.attrs['next_set_id']], np.int32)}
            self.info.attrs['next_set_id'] += 1
        data_vars = {
            'comp1_id': ('set_id', np.array([comp1_id] if comp1_id is not None else [], dtype=np.int32)),
            'comp2_id': ('set_id', np.array([comp2_id] if comp2_id is not None else [], dtype=np.int32)),
        }
        ds = xr.Dataset(data_vars=data_vars, coords=coords)
        if comp1_id is not None:  # Append the new row to the dataset if the comp ids have been provided
            self._sets = xr.concat([self.sets, ds], 'set_id')
        return ds

    def find_partner(self, comp_id):
        """Find the partner component ID of a given component ID.

        Args:
            comp_id (int): The component ID of the target component.

        Returns:
            Tuple[int, bool]: A tuple containing the partner component ID and a boolean indicating whether the
                components need to be swapped.
        """
        swap_comps = False
        # Check if the first feature in the set.
        mask = self.sets['comp1_id'] == comp_id
        ds = self.sets.where(mask, drop=True)
        if ds.sizes['set_id'] == 1:
            # Use int cast because sometimes it is a numpy data type
            partner_id = int(ds.comp2_id.data[0])
        else:  # Check if this is the second feature in the set.
            swap_comps = True  # Selected feature is the "second" feature in the set. Make sure GUI reflects this.
            mask = self.sets['comp2_id'] == comp_id
            ds = self.sets.where(mask, drop=True)
            # Use int cast because sometimes it is a numpy data type
            if ds.sizes['set_id'] == 1:
                partner_id = int(ds.comp1_id.data[0])
            else:
                partner_id = UNINITIALIZED_COMP_ID
        return partner_id, swap_comps

    def load_all_data(self):
        """Load all data from disk into memory."""
        _ = self.info
        _ = self.arcs
        _ = self.polygons
        _ = self.sets
        self._info.close()
        self._arcs.close()
        self._polygons.close()
        self._sets.close()

    def commit(self):
        """Save current in-memory component parameters to data file."""
        super().commit()  # Recreates the NetCDF file if vacuuming
        if self._arcs is not None:
            check_for_object_strings_dumb(self._arcs, STRUCTURE_STRING_VARIABLES)
            self._arcs.close()
            self._drop_h5_groups(['arcs'])
            # Convert the dataset to a dataframe and then back again, so we can write it out as NetCDF
            # Yes, this is stupid, and we should probably figure out why this is necessary, but for now we
            # just need something that works.
            df = self._arcs.to_dataframe()
            self._arcs = df.to_xarray()
            self._arcs.to_netcdf(self._filename, group='arcs', mode='a')
        if self._polygons is not None:
            check_for_object_strings_dumb(self._polygons, STRUCTURE_STRING_VARIABLES)
            self._polygons.close()
            self._drop_h5_groups(['polygons'])
            self._polygons.to_netcdf(self._filename, group='polygons', mode='a')
        if self._sets is not None:
            self._sets.close()
            self._drop_h5_groups(['sets'])
            self._sets.to_netcdf(self._filename, group='sets', mode='a')

    def vacuum(self):
        """Rewrite all data to a new file to reclaim disk space.

        All BC datasets that need to be written to the file must be loaded into memory before calling this method.
        """
        self.load_all_data()
        io_util.removefile(self._filename)
        self.commit()

    def update_file_paths(self):
        """Called before resaving an existing project.

        All referenced filepaths should be converted to relative from the project directory. Should already be stored
        in the component main file since this is a resave operation.

        Returns:
            (str): Message on failure, empty string on success
        """
        proj_dir = self.info.attrs['proj_dir']
        if not os.path.exists(proj_dir):
            return 'Unable to update selected file paths to relative from the project directory.'

        flush_data = False
        if self.arcs.sizes['comp_id'] > 0:
            for var in FILE_VARIABLES:
                flush_data |= self._update_table_files(self.arcs.variables[var], proj_dir, '')
        if self.polygons.sizes['comp_id'] > 0:
            for var in FILE_VARIABLES:
                flush_data |= self._update_table_files(self.polygons.variables[var], proj_dir, '')
        if flush_data:
            self.commit()
        return ''

    def copy_external_files(self):
        """Called when saving a project as a package. All components need to copy referenced files to the save location.

        Returns:
            (str): Message on failure, empty string on success
        """
        old_proj_dir = self.info.attrs['proj_dir']
        # Target directory for package save is three directories above the component UUID folder.
        target_dir = os.path.normpath(os.path.join(os.path.dirname(self._filename), '../../..'))
        # Copy the files referenced in tables
        for var in FILE_VARIABLES:
            self._copy_table_files(self.arcs.variables[var], old_proj_dir, target_dir)
            self._copy_table_files(self.polygons.variables[var], old_proj_dir, target_dir)
        # Wipe the stored project directory. Paths will be absolute in GUI until resave.
        self.info.attrs['proj_dir'] = ''
        return ''

    def update_proj_dir(self):
        """Called when saving a project for the first time or saving a project to a new location.

        All referenced filepaths should be converted to relative from the new project location. If the file path is
        already relative, it is relative to the old project directory. After updating file paths, update the project
        directory in the main file.

        Returns:
            (str): Message on failure, empty string on success.
        """
        try:
            old_proj_dir = self.info.attrs['proj_dir']  # Empty if first project save
            # Get the new project location. Three directories above the component UUID folder.
            comp_folder = os.path.dirname(self._filename)
            new_proj_dir = os.path.normpath(os.path.join(comp_folder, '../../..'))
            self.info.attrs['proj_dir'] = new_proj_dir
        except Exception:
            return 'There was a problem updating file paths to be relative from the project directory. Any selected ' \
                   'file paths will remain absolute.\n'

        for var in FILE_VARIABLES:
            self._update_table_files(self.arcs.variables[var], new_proj_dir, old_proj_dir)
        for var in FILE_VARIABLES:
            self._update_table_files(self.polygons.variables[var], new_proj_dir, old_proj_dir)
        self.commit()  # Save the updated project directory and referenced filepaths.
        return ''  # Don't report errors, leave that to model checks.