"""BoreHoleData Class."""
__copyright__ = "(C) Copyright Aquaveo 2020"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy
import sys

# 2. Third party modules

# 3. Aquaveo modules
from xms.FhwaVariable.core_data.calculator.calcdata import CalcData
from xms.FhwaVariable.core_data.calculator.calculator_list import CalcOrVarlist
from xms.FhwaVariable.core_data.calculator.plot.plot_options import PlotOptions
from xms.FhwaVariable.core_data.variables.variable import Variable

# 4. Local modules
from xms.HydraulicToolboxCalc.gradations.gradation_layer_data import GradationLayerData
# from xms.HydraulicToolboxCalc.util.intersection import find_intersection_and_closest_index
from xms.HydraulicToolboxCalc.hydraulics.bridge_scour.gradations.bore_hole_calc import BoreHoleCalc


class BoreHoleData(CalcData):
    """Borehole (soil set) data: a stack of soil/gradation layers plus plot options and result containers.

    The soil layers are held in a ``CalcOrVarlist`` of ``GradationLayerData`` items. The computation is delegated to
    ``BoreHoleCalc`` and its per-layer results are copied back into ``self.results`` by ``transfer_results``.
    """

    def __init__(self, app_data=None, model_name=None, project_uuid=None):
        """Initializes the borehole data.

        Creates the user inputs (name, centerline station, soil layer list, bottom elevation), the plot options for
        the 'Soil layers' and 'Shear stress vs elevation' plots, the intermediate values copied between
        computations, and the result containers.

        Args:
            app_data (AppData): The application data (settings).
            model_name (str): The name of the model.
            project_uuid (str): The project UUID.
        """
        super().__init__(app_data=app_data, model_name=model_name, project_uuid=project_uuid)

        self.name = 'Borehole'
        self.type = 'Borehole'

        self.calculator = BoreHoleCalc()

        # Input
        self.input['Name'] = Variable('Name', "string", 'New soil set', [])

        self.input['Centerline'] = Variable('Centerline station', 'float', 0.0, [],
                                            limits=(-sys.float_info.max, sys.float_info.max), precision=2,
                                            unit_type=['length'], native_unit='ft', us_units=self.us_mid_length,
                                            si_units=self.si_mid_length)

        self.input['Layers'] = \
            Variable('Soil layers', 'calc_list', CalcOrVarlist(GradationLayerData, default_name='Soil layer',
                                                               default_plural_name='Soil layers',
                                                               select_one=False,
                                                               min_number_of_items=1, show_define_btn=False,
                                                               app_data=app_data,
                                                               model_name=model_name, project_uuid=project_uuid))

        # Station / elevation series used by the plots below; they are attributes, not entries of self.input.
        self.station_var = Variable('Station', 'float_list', 0, [0.0], precision=2,
                                    limits=(-sys.float_info.max, sys.float_info.max), unit_type=['length'],
                                    native_unit='ft', us_units=self.us_mid_length, si_units=self.si_mid_length)
        self.elevation_var = Variable('Elevation', 'float_list', 0, [0.0], precision=2,
                                      limits=(-sys.float_info.max, sys.float_info.max), unit_type=['length'],
                                      native_unit='ft', us_units=self.us_mid_length, si_units=self.si_mid_length)

        self.input['Bottom elevation'] = Variable('Bottom elevation', 'float', 0.0, precision=2,
                                                  limits=(-sys.float_info.max, sys.float_info.max),
                                                  unit_type=['length'], native_unit='ft', us_units=self.us_mid_length,
                                                  si_units=self.si_mid_length)

        # Plot: soil layers (station vs. elevation)
        station_vs_elevation = {'Station': self.station_var, 'Elevation': self.elevation_var}
        input_dict = {'Soil layers': station_vs_elevation}
        self.input['Plot options'] = {}
        self.input['Plot options']['Soil layers'] = Variable('Plot options', 'class',
                                                             PlotOptions('Soil layers',
                                                                         input_dict=input_dict,
                                                                         show_series=True, app_data=app_data,
                                                                         model_name=model_name,
                                                                         project_uuid=project_uuid),
                                                             complexity=1)

        # Plot: critical shear stress vs. elevation (shares the elevation series with the soil layer plot)
        shear_vs_elevation = {}
        shear_vs_elevation['Shear'] = Variable(
            "Critical shear stress (τc)",
            'float_list',
            0.0,
            [],
            precision=4,
            unit_type=['stress'],
            native_unit='psf',
            us_units=[['psf']],
            si_units=[['pa']],
            note='The level of shear required to erode the soil layer')
        shear_vs_elevation['Elevation'] = self.elevation_var

        plot_name = 'Shear stress vs elevation'
        input_dict = {plot_name: shear_vs_elevation}
        self.input['Plot options'][plot_name] = Variable('Plot options', 'class',
                                                         PlotOptions(plot_name,
                                                                     input_dict=input_dict,
                                                                     show_series=True, app_data=self.app_data,
                                                                     model_name=self.model_name,
                                                                     project_uuid=self.project_uuid),
                                                         complexity=1)

        # Intermediate --
        self.compute_prep_functions.extend([self.check_name, self.update_plot_list])
        self.intermediate_to_copy.extend([
            'layers', 'current_layer', 'x_data', 'y_data', 'min_y', 'bcs_data', 'pier_data', 'bh_cl', 'max_elevation',
            'min_elevation', 'plot_names', 'plots',])
        # Assigned rather than extended: this replaces whatever finalize functions the base class set up.
        self.compute_finalize_functions = [self.transfer_results]
        self.layers = {}
        self.current_layer = 0

        self.x_data = []
        self.y_data = []
        self.min_y = sys.float_info.max

        self.bcs_data = None
        self.pier_data = None
        self.bh_cl = None

        self.max_elevation = 0.0
        self.min_elevation = 0.0

        self.warnings = []
        self.results = {}
        self.results['Cross-section information'] = {}
        self.results['Cross-section information']['Highest elevation'] = Variable(
            'Highest elevation', 'float', 0.0, [], precision=4, unit_type=['length'],
            limits=(-sys.float_info.max, sys.float_info.max),
            native_unit='ft', us_units=self.us_mid_length, si_units=self.si_mid_length)
        self.results['Cross-section information']['Lowest elevation'] = Variable(
            'Lowest elevation', 'float', 0.0, [], precision=4, unit_type=['length'],
            limits=(-sys.float_info.max, sys.float_info.max),
            native_unit='ft', us_units=self.us_mid_length, si_units=self.si_mid_length)

        self.plot_names = ['Soil layers', 'Shear stress vs elevation', ]
        self.plots = {}
        for name in self.plot_names:
            self.plots[name] = {}
            self.plots[name]['Plot name'] = name
            self.plots[name]['Legend'] = "best"

    def check_name(self):
        """Replace an empty or default ('New soil set') borehole name with ``self.name``.

        Registered as a compute-prep function in ``__init__``.
        """
        current_name = self.input['Name'].get_val()
        # Compare the stored value, not the Variable object itself, so that an empty name is actually detected.
        if not current_name or current_name == 'New soil set':
            self.input['Name'].set_val(self.name)

    def _get_can_compute(self):
        """Determine whether we have enough data to compute.

        Every layer is checked (no early exit) so that a warning is recorded for each layer that cannot compute.

        Returns:
            tuple: (True if every layer can compute, otherwise False; the list of warnings)
        """
        can_compute = True
        for layer in self.input['Layers'].value.item_list:
            layer_ok, _ = layer.get_can_compute()
            if not layer_ok:
                can_compute = False
                warning = f"Cannot compute {layer.name}"
                # self.warnings lives on the instance; don't repeat the same message on every check.
                if warning not in self.warnings:
                    self.warnings.append(warning)
        return can_compute, self.warnings

    def get_val(self):
        """Computes and returns the results.

        NOTE(review): nothing in this class creates self.results['result'], and transfer_results removes any result
        key that is not 'Cross-section information' or a layer name. Confirm where 'result' is populated; otherwise
        this raises KeyError.

        Returns:
            self.results['result'] (variable): result of the computations.
        """
        self.compute_data()
        return self.results['result']

    def update_plot_list(self):
        """Synchronize the 'Soil layers' plot data series with the number of soil layers.

        Registered as a compute-prep function in ``__init__``.
        """
        layer_list = self.input['Layers'].get_val()
        num_items = layer_list.input['Number of items'].get_val()
        data_series = self.input['Plot options']['Soil layers'].get_val().input['Data series'].get_val()

        data_series.input['Number of items'].set_val(num_items)

        layer_list.check_list_length()
        data_series.check_list_length()

    def get_input_group(self, unknown=None):
        """Get the input group (for user-input).

        Also enables the number/duplicate/delete controls of the layer list, pushes the layer count and the
        elevation display flag to each layer, and synchronizes the soil layer plot series before copying.

        Args:
            unknown (string): variable that is unknown (not used here)

        Returns:
            input_vars (dict): deep copy of the input variables; 'Bottom elevation' is omitted when there is exactly
                one soil layer
        """
        layer_list = self.input['Layers'].value
        layer_list.show_number = True
        layer_list.show_duplicate = True
        layer_list.show_delete = True

        num_layers = self.input['Layers'].get_val().input['Number of items'].get_val()
        for layer in layer_list.item_list:
            layer.show_elevations = True
            layer.num_layers = num_layers

        self.update_plot_list()

        input_vars = copy.deepcopy(self.input)
        if num_layers == 1:
            input_vars.pop('Bottom elevation')
        return input_vars

    def get_input_tab_group(self, unknown=None):
        """Returns the input group for the user interface.

        Args:
            unknown (string): variable that is unknown

        Returns:
            input_vars (dict): {'Data table': <inputs without plot options>, 'Plot options': <plots>}; plot options
                above the current 'Complexity' setting are dropped, and the 'Plot options' entry is omitted when
                none remain
        """
        data_table = self.get_input_group()
        # get_input_group already returns a deep copy, so the plot options can be moved out without copying again.
        plot_options = data_table.pop('Plot options')

        # Remove plot options that are too complex for the current complexity setting.
        sys_complexity = self.get_setting_var('Complexity')[1].value
        plot_options = {name: var for name, var in plot_options.items() if var.complexity <= sys_complexity}

        input_vars = {'Data table': data_table}
        if plot_options:
            input_vars['Plot options'] = plot_options
        return input_vars

    def get_embedment_depth(self):
        """Returns the embedment depth.

        NOTE(review): 'Embedment depth' is not created in this class's __init__; confirm the base class provides it,
        otherwise this raises KeyError.

        Returns:
            embedment depth (float): embedment depth
        """
        self.compute_data()
        return self.input['Embedment depth'].get_val()

    def get_results_tab_group(self, unknown=None):
        """Get the results tab group.

        Args:
            unknown (variable): unknown variable

        Returns:
            dict: {'Results': deep copy of the results}; 'Cross-section information' is removed unless bcs_data
                provides both an 'x var' and a 'y var'
        """
        results_vars = copy.deepcopy(self.results)

        if not self.bcs_data or not self.bcs_data['x var'] or not self.bcs_data['y var']:
            results_vars.pop('Cross-section information')

        return {'Results': results_vars}

    def _get_layer(self, index=None):
        """Return the soil layer at ``index`` (defaults to ``self.current_layer``), or None when out of range.

        Args:
            index (int): index of the gradation layer

        Returns:
            GradationLayerData or None: the requested layer
        """
        if index is None:
            index = self.current_layer
        layers = self.input['Layers'].value.item_list
        if 0 <= index < len(layers):
            return layers[index]
        return None

    def get_critical_shear_stress(self, index=None):
        """Returns the critical shear stress of the gradation.

        Args:
            index (int): index of the gradation layer; defaults to ``self.current_layer``

        Returns:
            critical shear stress (float): critical shear stress of the gradation, or None if the index is out of
                range
        """
        layer = self._get_layer(index)
        return None if layer is None else layer.get_critical_shear_stress()

    def get_d50(self, index=None):
        """Returns the D50 of the gradation.

        Args:
            index (int): index of the gradation layer; defaults to ``self.current_layer``

        Returns:
            D50 (float): D50 of the gradation, or None if the index is out of range
        """
        layer = self._get_layer(index)
        return None if layer is None else layer.get_d50()

    def get_d84(self):
        """Returns the D84 of the gradation.

        NOTE(review): ``allow_only_selected_gradations`` and the 'gradation entry' / 'D84' inputs are not created in
        this class; this looks carried over from a single-gradation class. Confirm the base class provides them, or
        delegate to the current layer as ``get_d50`` does.

        Returns:
            D84 (float): D84 when only the required gradations are entered; otherwise None
        """
        if self.allow_only_selected_gradations and self.input['gradation entry'].get_val() == \
                'enter only required gradations':
            return self.input['D84'].get_val()
        return None

    @staticmethod
    def _find_layer_for_result_key(key, layers):
        """Return the layer a calculator result key belongs to, or None.

        An exact name match wins; otherwise the first layer whose name is contained in ``key`` is used. The exact
        match prevents e.g. 'Soil layer 1' from claiming the results keyed 'Soil layer 10'.

        Args:
            key (str): calculator result key
            layers (list): the soil layers

        Returns:
            GradationLayerData or None: the matching layer
        """
        for layer in layers:
            if layer.name == key:
                return layer
        for layer in layers:
            if layer.name in key:
                return layer
        return None

    def transfer_results(self):
        """Transfer results from the calculator to the results variable.

        Registered as the compute-finalize function in ``__init__``. Result groups that no longer correspond to a
        layer are removed first. Then, for each calculator result key that maps to a layer, a deep copy of that
        layer's results is stored, stripped of bookkeeping entries, filled from the calculator results, and
        stripped of zero/empty entries. When the layer reports computed elevations, its 'Bottom elevation' input is
        updated from the computed value.
        """
        layers = self.input['Layers'].value.item_list
        valid_keys = {'Cross-section information'}
        valid_keys.update(layer.name for layer in layers)

        # Remove old results but don't apply (num layers changed, etc.)
        stale_keys = [key for key in self.results if key not in valid_keys]
        for key in stale_keys:
            self.results.pop(key)

        extra_results_to_pop = ['Prior classification', 'Gradation classification', 'Calculate texture',
                                'Elevations computed']
        results_to_pop_if_zero = ['D15', 'D85', 'D95', 'Cu',
                                  'Cc', 'Graded', 'Classification', 'Texture',
                                  'Sediment gradation coefficient (σ)', 'Dimensionless grain size diameter (D*)',
                                  'Shields number (ks)']

        for key in self.calculator.results:
            layer = self._find_layer_for_result_key(key, layers)
            if layer is None:
                continue
            self.results[key] = copy.deepcopy(layer.results)
            self._remove_extra_keys(self.results[key], extra_results_to_pop, if_zero=False)
            self._fill_results_dict_recursive(self.results[key], self.calculator.results[key])
            # Check the layer's own results: 'Elevations computed' was just stripped from the copy above.
            if 'Elevations' in layer.results and 'Elevations computed' in layer.results['Elevations'] and \
                    layer.results['Elevations']['Elevations computed'].get_val():
                layer.input['Bottom elevation'].set_val(layer.results['Elevations']['Bottom elevation'].get_val())
            self._remove_extra_keys(self.results[key], results_to_pop_if_zero, if_zero=True)

    def _remove_extra_keys(self, results_dict, key_dict, if_zero=False):
        """Remove the given keys from a (nested) results dictionary.

        Args:
            results_dict (dict): results dictionary to prune in place
            key_dict (list): keys to remove
            if_zero (bool): if True, only remove entries whose value is zero/empty (see
                ``_remove_extra_keys_recursive``)
        """
        zero_tol = self.get_setting('Zero tolerance')[1]

        for key in key_dict:
            self._remove_extra_keys_recursive(key, results_dict, if_zero, zero_tol)

    def _remove_extra_keys_recursive(self, key, results_dict, if_zero=False, zero_tol=0.0001):
        """Recursively remove ``key`` from ``results_dict`` and its nested dicts, then drop emptied sub-dicts.

        Args:
            key (str): result key to remove
            results_dict (dict): results dictionary to prune in place; non-dict values are ignored
            if_zero (bool): if True, only remove the key when its value is a 'float' <= zero_tol or a 'string'
                equal to '' (or False)
            zero_tol (float): tolerance for the float comparison; passed down to every nested level
        """
        if not isinstance(results_dict, dict):
            return
        if key in results_dict:
            if not if_zero:
                results_dict.pop(key)
            else:
                var = results_dict[key]
                if var.type == 'float':
                    if var.get_val() <= zero_tol:
                        results_dict.pop(key)
                elif var.type == 'string':
                    value = var.get_val()
                    if value == '' or value is False:
                        results_dict.pop(key)

        emptied_keys = []
        for subkey, subvalue in results_dict.items():
            # Propagate the configured tolerance instead of falling back to the default at nested levels.
            self._remove_extra_keys_recursive(key, subvalue, if_zero, zero_tol)
            if isinstance(subvalue, dict) and not subvalue:
                emptied_keys.append(subkey)
        for emptied_key in emptied_keys:
            results_dict.pop(emptied_key)
