"""CalcData for performing Pier Scour calculations."""
__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy
import sys

# 2. Third party modules

# 3. Aquaveo modules
from xms.FhwaVariable.core_data.calculator.calculator import Calculator
# from xms.FhwaVariable.core_data.calculator.calcdata import CalcData
# from xms.FhwaVariable.core_data.variables.variable import Variable

# 4. Local modules


class ScourBaseCalc(Calculator):
    """Base calculator with the soil-gradation and shear-decay logic shared by the scour calculators.

    It resolves the contracted and approach D50 values, the critical shear stress and the borehole soil
    layers from the gradation calculator (``self.gradations``). It also computes scour hole geometry and
    fills the shear decay plot data. The code references pier, abutment and contraction scour data, so it is
    presumably used as a base for those specific calculators.
    """

    def __init__(self):
        """Initialize the scour base calculator with empty gradation, borehole and result state."""
        super().__init__()

        # Gradation calculator providing the soil data; never assigned in this class, so it is set by the owner
        self.gradations = None

        self.approach_gradation_dict = {}
        self.contracted_gradation_dict = {}

        self.d50 = 0.0
        self.d50_ft = 0.0
        self.upstream_d50 = 0.0
        self.upstream_d50_ft = 0.0
        self.critical_shear_stress = 0.0

        # Intermediate
        self.scour_depth = 0.0

        # Borehole uuid -> display name ('None' -> 'None'), rebuilt by set_closest_borehole()
        self.bh_dict_names = {}

        self.centerline = 0.0
        self.centerline_streambed = 0.0
        self.channel_location = 'main'
        self.scour_reference_point = None
        # Borehole uuid -> borehole input dict, rebuilt by set_closest_borehole()
        self.bh_dict = {}
        self.bh_uuid = None
        self.bh = None
        # Layer dict of the contracted borehole (keyed by layer index), set by set_layers()
        self.bh_layers = None
        self.layer_index = None

        self.surface_d50 = 0.0

        self.is_computing_shear_decay = False

        # Resolved by update_gradation_lists(); initialized here so helpers that read them are safe to call first
        self.contracted_bh = None
        self.contracted_bhs = None
        self.approach_gradation = None

    def set_intermediate_d50(self):
        """Record the approach D50, contracted D50 and critical shear stress in ``self.results['Results']``."""
        intermediate = self.results.setdefault('Results', {})
        intermediate['Approach D50'] = self.upstream_d50
        intermediate['Contracted D50'] = self.d50
        intermediate['Critical shear stress (τc)'] = self.critical_shear_stress

    def update_gradation_lists(self):
        """Resolve the contracted and approach soil data from the gradation calculator.

        The contracted borehole(s) or single D50 and the approach gradation or single D50 are selected from
        the gradation input options. The borehole layers, D50 values and critical shear stress are then
        updated and recorded in the results.

        Returns:
            tuple of float: (contracted D50, approach D50, critical shear stress). The tuple is returned even
            when no gradations are assigned.
        """
        if self.gradations is None:
            self.set_intermediate_d50()
            return self.d50, self.upstream_d50, self.critical_shear_stress

        if not hasattr(self.gradations, 'input_dict'):
            self.gradations.input_dict, self.gradations.plot_dict = self.gradations.prepare_input_dict()
        grad_calc_data = self.gradations.input_dict['calc_data']

        # Contracted gradation
        self.contracted_bhs = None
        self.contracted_bh = None
        contracted_option = grad_calc_data['Contracted gradation input options']
        if contracted_option == 'Single borehole':
            self.contracted_bh = grad_calc_data['Single borehole']['calculator']
            self.contracted_bh.input_dict = grad_calc_data['Single borehole']
            single_bh_plot_dict = self._get_single_borehole_plot_dict()
            if single_bh_plot_dict is not None:
                self.contracted_bh.plot_dict = single_bh_plot_dict
        elif contracted_option == 'Multiple boreholes':
            # The specific borehole (self.contracted_bh) is chosen by set_closest_borehole()
            self.contracted_bhs = grad_calc_data['Boreholes']['calculator']
        else:  # 'Single D50'
            # Stored under 'calc_data', where the borehole paths also write 'Contracted D50'
            self.input_dict['calc_data']['Contracted D50'] = grad_calc_data['Contracted D50']
            self.d50 = self.input_dict['calc_data']['Contracted D50']
            self.critical_shear_stress = 0.0

        # Approach gradation
        self.approach_gradation = None
        approach_option = grad_calc_data['Approach gradation input options']
        if approach_option == 'Single D50':
            self.input_dict['calc_data']['Approach D50'] = grad_calc_data['Approach D50']
            self.upstream_d50 = self.input_dict['calc_data']['Approach D50']
        elif approach_option == 'Gradations':
            approach_gradations = grad_calc_data['Approach gradations']
            if approach_gradations['Approach locations'] == 'Single gradation for entire cross-section':
                layer_key = 'Approach layer'
            else:
                location = self.input_dict['calc_data']['Selected approach location']
                if location == 'left overbank':
                    layer_key = 'Left overbank layer'
                elif location == 'right overbank':
                    layer_key = 'Right overbank layer'
                else:
                    layer_key = 'Approach layer'
            self.approach_gradation = approach_gradations[layer_key]
        else:  # 'None': the user chose not to provide an approach gradation
            self.input_dict['calc_data']['Approach D50'] = 0.0
            # Keep the attribute in sync with the input so the reported 'Approach D50' is not stale
            self.upstream_d50 = 0.0

        self.set_closest_borehole()
        self.set_layers()
        self.set_d50_for_approach_gradation()
        self.set_intermediate_d50()

        return self.d50, self.upstream_d50, self.critical_shear_stress

    def _get_single_borehole_plot_dict(self):
        """Return the 'Single borehole' plot dict of the gradations, or None when it is not present."""
        gradation_plots = self.gradations.plot_dict
        if 'Gradations' in gradation_plots and 'Single borehole' in gradation_plots['Gradations']:
            return gradation_plots['Gradations']['Single borehole']
        if 'Single borehole' in gradation_plots:
            return gradation_plots['Single borehole']
        return None

    def set_closest_borehole(self):
        """Select the contracted borehole when multiple boreholes are defined.

        If the user-selected borehole (by calculator uuid) is valid, it is used. In that case the surface D50,
        contracted D50 and critical shear stress are read from it. Otherwise the borehole whose 'Centerline'
        is closest to ``self.centerline`` is selected and given a copy of this calculator's input.
        """
        if self.contracted_bhs is None:
            self.input_dict['calc_data']['Selected borehole'] = None
            return

        boreholes = self.gradations.input_dict['calc_data']['Boreholes']['Boreholes']

        # Rebuild the lookups of selectable boreholes, keyed by calculator uuid
        self.bh_dict = {}
        self.bh_dict_names = {'None': 'None'}
        for bh_index, borehole in boreholes.items():
            if 'Gradations' in self.gradations.plot_dict:
                borehole['calculator'].plot_dict = \
                    self.gradations.plot_dict['Gradations']['Boreholes']['Boreholes'][bh_index]
            bh_uuid = borehole['calculator'].uuid
            self.bh_dict[bh_uuid] = borehole
            self.bh_dict_names[bh_uuid] = borehole['Name']
        # TODO: Fix the borehole selection in scour calculators
        # (the selectable options of 'Selected borehole' should come from self.bh_dict_names)

        calc_data = self.input_dict['calc_data']
        if calc_data['Selected borehole'] not in self.bh_dict:
            calc_data['Selected borehole'] = 'None'

        if calc_data['Selected borehole'] != 'None':
            self.contracted_bh = self.bh_dict[calc_data['Selected borehole']]['calculator']
            # NOTE(review): self.bh_layers still holds the layers from the previous set_layers() call here
            layer_index, _, _, _ = self.get_layer_index_from_elevation(self.centerline_streambed)
            if layer_index is None:
                layer_index = 0
            self.contracted_bh.current_layer = layer_index
            # Get the D50 at the surface (to set as the approach D50 for contraction scour)
            self.surface_d50 = self.contracted_bh.get_d50(0)
            # Get the D50 at the eroded surface for starting D50 on Contracted section
            self.d50 = self.contracted_bh.get_d50(0)
            calc_data['Contracted D50'] = self.contracted_bh.get_d50()
            self.critical_shear_stress = self.contracted_bh.get_critical_shear_stress()
            calc_data['Critical shear stress (τc)'] = self.critical_shear_stress
            return

        # No valid user selection: pick the borehole nearest to this calculator's centerline
        closest_borehole = None
        closest_distance = sys.float_info.max
        for borehole in boreholes.values():
            distance = abs(self.centerline - borehole['Centerline'])
            if distance < closest_distance:
                closest_distance = distance
                closest_borehole = borehole
        if closest_borehole is None:
            return

        selected_uuid = closest_borehole['calculator'].uuid
        calc_data['Selected borehole'] = selected_uuid
        self.contracted_bh = self.bh_dict[selected_uuid]['calculator']
        # The borehole gets a shallow copy of our input, with its own borehole data as 'calc_data'
        self.contracted_bh.input_dict = copy.copy(self.input_dict)
        self.contracted_bh.input_dict['calc_data'] = copy.copy(self.bh_dict[selected_uuid])

    def set_layers(self):
        """Compute the contracted borehole and its soil layers, then take the D50 and critical shear from it.

        Sets ``self.bh_layers`` and updates ``self.d50``, ``self.critical_shear_stress`` and the matching
        'calc_data' input values. Does nothing when no contracted borehole is selected.
        """
        if self.contracted_bh == 'None' or self.contracted_bh is None:
            return

        # Run the borehole with a deep copy of our input whose 'calc_data' is the borehole's own data
        bh_input_dict = copy.deepcopy(self.input_dict)
        if 'calc_data' in self.contracted_bh.input_dict:
            bh_input_dict['calc_data'] = self.contracted_bh.input_dict['calc_data']
        else:
            bh_input_dict['calc_data'] = self.contracted_bh.input_dict

        if self.bcs_data is not None:
            # Share this calculator's bridge data with the borehole calculator
            self.contracted_bh.bcs_data = self.bcs_data
            self.contracted_bh.pier_data = self.pier_data
            self.contracted_bh.bh_cl = self.bh_cl
            self.contracted_bh.abutment_data = self.abutment_data

        _, _, _, _, _ = self.contracted_bh.compute_data(input_dict=bh_input_dict, plot_dict=self.plot_dict)

        self.bh_layers = self.contracted_bh.input_dict['calc_data']['Layers']
        for layer in self.bh_layers.values():
            layer['calculator'].compute_data_with_subdict(self.input_dict, layer)

        # TODO: Fix transferring the per-layer gradation plot series from the borehole's soil layer plot
        self.d50 = self.contracted_bh.get_d50(0)
        self.critical_shear_stress = self.contracted_bh.get_critical_shear_stress(0)
        self.input_dict['calc_data']['Contracted D50'] = self.d50
        self.input_dict['calc_data']['Critical shear stress (τc)'] = self.critical_shear_stress

    def set_d50_for_approach_gradation(self):
        """Set the approach D50 from the selected approach gradation when the 'Gradations' option is used.

        The gradation calculator is computed first if its D50 is not yet positive.
        """
        if self.approach_gradation is None:
            return
        _, approach_options = self.get_data('Approach gradation input options', input_dict=self.gradations.input_dict)
        if approach_options != 'Gradations':
            return
        if self.approach_gradation['calculator'].d50 <= 0.0:
            self.approach_gradation['calculator'].compute_data()
        self.upstream_d50 = self.approach_gradation['calculator'].d50
        self.input_dict['calc_data']['Approach D50'] = self.upstream_d50

    def set_angle_of_repose(self):
        """Copy the current layer's angle of repose into the 'Scour hole geometry' input, if applicable.

        Nothing is done when there are no gradations, when the contracted input is 'Single borehole', or when
        no borehole layers are available. ``self.layer_index`` is clamped to the valid layer range.
        """
        if not self.gradations:
            return
        _, input_options = self.get_data('Contracted gradation input options',
                                         input_dict=self.gradations.input_dict)
        if input_options == 'Single borehole':
            return

        if not self.bh_layers:
            return

        if self.layer_index is None or self.layer_index < 0:
            self.layer_index = 0
        if self.layer_index >= len(self.bh_layers):
            self.layer_index = len(self.bh_layers) - 1

        if 'Scour hole geometry' in self.input_dict['calc_data']:
            self.input_dict['calc_data']['Scour hole geometry']['Angle of repose'] = self.bh_layers[
                self.layer_index]['Angle of repose (Θ)']

    def compute_scour_hole_geometry(self, pier_width=0.0, is_pier_scour_hole=False, input_dict=None):
        """Compute the scour hole geometry for the current scour depth.

        Args:
            pier_width (float): Pier width passed to the scour hole geometry calculator.
            is_pier_scour_hole (bool): Whether the scour hole is a pier scour hole.
            input_dict (dict): Input dictionary to use; defaults to ``self.input_dict``.

        Returns:
            None. On success, 'Bottom width' and 'Top width' are stored in ``self.results`` and the geometry
            calculator's warnings are merged into ``self.warnings``.
        """
        if input_dict is None:
            input_dict = self.input_dict

        # NOTE(review): this writes the angle of repose into self.input_dict, even if a different
        # input_dict is passed in
        self.set_angle_of_repose()

        calc_data = input_dict['calc_data']
        if 'Scour hole geometry' not in calc_data:
            return
        geometry = calc_data['Scour hole geometry']

        # Update the scour depth in the input dictionary
        geometry['Scour depth'] = self.scour_depth

        geometry_calc = geometry['calculator']
        geometry_calc.pier_width = pier_width
        geometry_calc.is_pier_scour_hole = is_pier_scour_hole

        _, result, results, warnings, _ = geometry_calc.compute_data_with_subdict(input_dict, geometry,
                                                                                  self.plot_dict)
        if result:
            self.results['Bottom width'] = results['Bottom width']
            self.results['Top width'] = results['Top width']
            self.warnings.update(warnings)

    def compute_shear(self, wse, elevation, mannings_n, unit_q):
        """Compute the shear stress at a given elevation below the water surface.

        Evaluates ``gamma_w * depth**(-7/3) * (n * q / k)**2`` with ``depth = wse - elevation``.

        Args:
            wse (float): The water surface elevation.
            elevation (float): The elevation at which to compute the shear stress.
            mannings_n (float): The Manning's n value for the soil layer.
            unit_q (float): The unit discharge.

        Returns:
            float: The shear stress at the given elevation, or 0.0 when the elevation is at or above the WSE.
        """
        _, gamma_w = self.get_data('Unit weight of water (γw)', 62.4)
        _, k = self.get_data('Manning constant')
        depth = wse - elevation
        if depth <= 0.0:
            return 0.0
        return gamma_w * (depth ** (-7.0 / 3.0)) * ((mannings_n * unit_q / k) ** 2.0)

    def compute_alpha_beta_for_pier_shape(self):
        """Return the local shear decay alpha and beta values.

        The inputs 'Alpha value for local shear decay' and 'Beta value for local shear decay' are used when
        available. Otherwise the defaults 4.37 and 1.33 apply.

        Returns:
            tuple of float: The alpha and beta values.
        """
        shear_alpha = 4.37
        shear_beta = 1.33

        result, alpha = self.get_data('Alpha value for local shear decay')
        if result:
            shear_alpha = alpha
        result, beta = self.get_data('Beta value for local shear decay')
        if result:
            shear_beta = beta

        return shear_alpha, shear_beta

    def get_layer_index_from_elevation(self, elevation):
        """Get the index of the borehole layer that contains the elevation.

        Layers are scanned in index order until one satisfies ``top >= elevation > bottom`` (or a layer with
        a None top is hit). If none matches, the last layer scanned is returned.

        Args:
            elevation (float): The elevation to search for.

        Returns:
            tuple: (layer index, layer top elevation, layer bottom elevation, layer critical shear stress).
            All four are None when there is no borehole layer data (no gradations, 'Single D50' input, no
            contracted borehole, or no layers).
        """
        no_layer = (None, None, None, None)
        if self.gradations is None or self.contracted_bh is None or not self.bh_layers:
            return no_layer
        _, input_options = self.get_data('Contracted gradation input options', input_dict=self.gradations.input_dict)
        if input_options == 'Single D50':
            return no_layer

        layer_index = -1
        cur_layer_top = -sys.float_info.max
        cur_layer_bottom = sys.float_info.max
        cur_critical_shear = None

        while cur_layer_top is not None and not cur_layer_top >= elevation > cur_layer_bottom and \
                layer_index + 1 < len(self.bh_layers):
            layer_index += 1
            cur_layer = self.bh_layers[layer_index]
            cur_layer_top = cur_layer['Top elevation']
            cur_layer_bottom = cur_layer['Bottom elevation']
            cur_critical_shear = cur_layer['calculator'].tau

        return layer_index, cur_layer_top, cur_layer_bottom, cur_critical_shear

    def get_next_layer_details(self, layer_index, wse, unit_q):
        """Get the details of the layer below the given layer in the borehole.

        Args:
            layer_index (int): The index of the current layer.
            wse (float): The water surface elevation.
            unit_q (float): The unit discharge.

        Returns:
            tuple of float: The Manning's n, the shear stress at the top of the next layer, and the critical
            shear stress of the next layer; (None, None, None) when there is no next layer.
        """
        if self.bh_layers is not None and layer_index + 1 < len(self.bh_layers):
            next_layer = self.bh_layers[layer_index + 1]
            lower_layer_n = next_layer['calculator'].manning_n
            lower_layer_top = next_layer['Top elevation']
            lower_layer_shear = self.compute_shear(wse, lower_layer_top, lower_layer_n, unit_q)
            lower_layer_critical = next_layer['calculator'].tau
            return lower_layer_n, lower_layer_shear, lower_layer_critical
        return None, None, None

    def _copy_layer_shear_series(self):
        """Copy the borehole's 'Shear stress vs elevation' series into this calculator's 'Shear decay' plot."""
        if self.gradations is None or self.contracted_bh is None:
            return
        contracted_option = self.gradations.input_dict['calc_data']['Contracted gradation input options']
        if contracted_option == 'Single borehole':
            bh_plot_dict = self._get_single_borehole_plot_dict()
        elif contracted_option == 'Multiple boreholes':
            # NOTE(review): for multiple boreholes this reads the top-level 'Single borehole' plot dict; the
            # selected borehole's own plot_dict (assigned in set_closest_borehole) may be intended - confirm
            bh_plot_dict = self.gradations.plot_dict.get('Single borehole')
        else:
            bh_plot_dict = None
        if bh_plot_dict is None:
            # No borehole plot data to bring in (e.g. 'Single D50' input)
            return

        shear_elev_plot_dict = self.contracted_bh.compute_data(plot_dict=bh_plot_dict)[4]
        if 'Shear stress vs elevation' not in shear_elev_plot_dict:
            return

        shear_decay_plot = self.plot_dict['Shear decay']
        if 'series' not in shear_decay_plot:
            shear_decay_plot['series'] = {}
        layer_series = shear_elev_plot_dict['Shear stress vs elevation']['series']
        # Layer series are placed starting at index 3 (presumably 0-2 belong to this plot's own series - confirm)
        for series_index, series_key in enumerate(layer_series, start=3):
            shear_decay_plot['series'][series_index] = copy.deepcopy(layer_series[series_key])

    def set_layer_plot_data(self, decay_x, decay_y, marker_x, marker_y, y_max=None, use_csu=None):
        """Set the data for the layer plot.

        Does nothing when this calculator has no 'Shear decay' plot.

        Args:
            decay_x (list): The x values for the decay curve.
            decay_y (list): The y values for the decay curve.
            marker_x (list): The x values for the markers.
            marker_y (list): The y values for the markers.
            y_max (float): The maximum depth.
            use_csu (bool): Whether to use the CSU equation or the Hager equation for ymax.
        """
        if 'Shear decay' not in self.plot_dict:
            return

        # Bring the soil shear plot data into our current plot
        self._copy_layer_shear_series()

        sd_series, _ = self.get_plot_subdict_and_key_by_name('Shear decay', 'series', 'Shear decay')
        sd_points, _ = self.get_plot_subdict_and_key_by_name('Shear decay', 'points', 'Shear decay')
        ymax_lines, _ = self.get_plot_subdict_and_key_by_name('Y max', 'lines', 'Shear decay')

        if sd_series is not None:
            if 'x var' in sd_series:
                sd_series['x var'].set_val(decay_x)
            if 'y var' in sd_series:
                sd_series['y var'].set_val(decay_y)
        if sd_points is not None:
            if 'x var' in sd_points:
                sd_points['x var'].set_val(marker_x)
            if 'y var' in sd_points:
                sd_points['y var'].set_val(marker_y)

        if y_max is not None:
            name = 'Y max'
            if use_csu:
                name = 'Y max (CSU)'
            elif use_csu is False:
                name = 'Y max (Hager)'
            if ymax_lines is not None:
                ymax_lines['Label'] = name
                ymax_lines['Line intercepts'] = [y_max]
                ymax_lines['Line alignment'] = 'horizontal'
