"""BoreHoleData Class."""
__copyright__ = "(C) Copyright Aquaveo 2020"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy

# 2. Third party modules

# 3. Aquaveo modules
from xms.FhwaVariable.core_data.calculator.calculator import Calculator
# from xms.FhwaVariable.core_data.calculator.plot.plot_options import PlotOptions
from xms.FhwaVariable.core_data.variables.variable import Variable

# 4. Local modules
# from xms.HydraulicToolboxCalc.gradations.gradation_layer_data import GradationLayerData
from xms.HydraulicToolboxCalc.util.intersection import find_intersection_and_closest_index
from xms.HydraulicToolboxCalc.util.list_utils import remove_nans


class BoreHoleCalc(Calculator):
    """Calculator for a borehole made of stacked soil (gradation) layers.

    ``_compute_data`` runs each layer's own calculator, orders the layers from the highest to the
    lowest top elevation, assigns each layer's bottom elevation, records per-layer results, and
    fills the 'Soil layers', 'Shear stress vs elevation' and
    '<name> layers plotted along cross-section' plot data.

    NOTE(review): ``bcs_data``, ``pier_data``, ``bh_cl``, ``x_data``, ``y_data``, ``plot_dict``,
    ``plot_names``, ``plots`` and ``current_layer`` are read but never created in this class;
    presumably they are initialized by ``Calculator`` or by the owner of this calculator.
    """

    def _compute_data(self):
        """Compute the borehole layers, their results and plot data; store results in self.

        Returns:
            bool: always True.
        """
        _, null_data = self.get_data('Null data')
        has_cross_section = self._has_bridge_cross_section()

        if has_cross_section:
            self._store_cross_section_summary(null_data)
        self._update_plot_stations()

        cs_stations, cs_elevations = [], []
        if has_cross_section:
            cs_stations, cs_elevations, _ = remove_nans(
                self.bcs_data['x var'].get_val(), self.bcs_data['y var'].get_val(), null_data)

        calc_data = self.input_dict['calc_data']
        if len(calc_data['Layers']) == 1:
            # A lone layer defaults to spanning the whole cross-section (None when there is none).
            if len(cs_stations) == 0 or len(cs_elevations) == 0:
                top_elevation, bottom_elevation = None, None
            else:
                top_elevation, bottom_elevation = max(cs_elevations), min(cs_elevations)
            calc_data['Layers'][0]['Top elevation'] = top_elevation
            calc_data['Bottom elevation'] = bottom_elevation

        # Captured before the layer calculators run, because their plot dicts are merged into
        # self.plot_dict while they run.
        plot_options = self.plot_dict.get('Soil layers')

        self._compute_and_sort_layers(plot_options)
        self._record_layer_results(plot_options)

        if plot_options:
            shear_plot_dict = self.plot_dict.get('Shear stress vs elevation')
            if shear_plot_dict is not None:
                self._update_shear_stress_plot(shear_plot_dict)
            self._update_cross_section_plot(cs_stations, cs_elevations, null_data)
        return True

    def _has_bridge_cross_section(self):
        """Return True when bridge cross-section data with both x and y variables is available."""
        return bool(self.bcs_data and self.bcs_data['x var'] and self.bcs_data['y var'])

    def _store_cross_section_summary(self, null_data):
        """Store the non-null cross-section points and record its highest/lowest elevations.

        Args:
            null_data: placeholder value marking missing entries; filtered out of both lists.
        """
        stations = self.bcs_data['x var'].get_val()
        elevations = self.bcs_data['y var'].get_val()
        # Stations and elevations are filtered independently, so the lists may differ in length.
        self.x_data = [station for station in stations if station != null_data]
        self.y_data = [elevation for elevation in elevations if elevation != null_data]
        if not self.y_data:
            # Nothing valid to summarize; max()/min() would raise on an empty list.
            return

        if 'Cross-section information' not in self.results:
            self.results['Cross-section information'] = {}
        self.results['Cross-section information']['Highest elevation'] = max(self.y_data)
        self.results['Cross-section information']['Lowest elevation'] = min(self.y_data)

    def _update_plot_stations(self):
        """Set ``self.stations`` to the closed x-outline used for each soil-layer plot box.

        The box spans the cross-section's station range when x/y data are present and of equal
        length (also updating ``max_elevation``/``min_elevation``); otherwise it spans 0 to 10.
        A zero-width range is widened to 10 units.
        """
        if (self.y_data is not None and self.x_data is not None
                and len(self.y_data) > 0 and len(self.x_data) > 0
                and len(self.y_data) == len(self.x_data)):
            self.max_elevation = max(self.y_data)
            self.min_elevation = min(self.y_data)
            left_station, right_station = min(self.x_data), max(self.x_data)
        else:
            left_station, right_station = 0.0, 10.0

        if left_station == right_station:
            left_station = right_station - 10.0

        self.stations = [left_station, right_station, right_station, left_station, left_station]

    def _compute_and_sort_layers(self, plot_options):
        """Run every layer's calculator and store the layers in ``self.layers``.

        ``self.layers`` maps a unique top elevation to ``(layer dict, input index)`` and is ordered
        from the highest to the lowest top elevation. A duplicate top elevation is nudged up in
        0.1 increments until unique, and a warning is recorded.

        Args:
            plot_options (dict | None): the 'Soil layers' plot options; when truthy, each layer's
                series entry is named and recolored when its gradation classification changed.
        """
        self.layers = {}
        for input_index, layer in enumerate(self.input_dict['calc_data']['Layers']):
            can_compute, _, layer_results, layer_warnings, layer_plot_dict = \
                layer['calculator'].compute_data_with_subdict(self.input_dict, layer, self.plot_dict)
            layer['Results'] = layer_results
            layer['can_compute'] = can_compute
            self.warnings.update(layer_warnings)
            self.plot_dict.update(layer_plot_dict)

            if plot_options:
                series = plot_options['series']
                if input_index not in series:
                    series[input_index] = {}
                layer_options = series[input_index]
                color, fill_color, pattern = self.assign_color_based_on_gradation(layer)
                layer_options['Name'] = layer['Name']
                if color is not None:
                    layer_options['Line color'] = color
                    layer_options['Fill color'] = fill_color
                    layer_options['Fill pattern'] = pattern

            top_elevation = layer['Top elevation']
            if top_elevation in self.layers:
                new_elevation = top_elevation + 0.1
                while new_elevation in self.layers:
                    new_elevation += 0.1
                self.warnings[f"Dup {top_elevation} and {layer['Name']} to {new_elevation}"] = \
                    f"Duplicate top elevation found: {top_elevation}; elevation for {layer['Name']} modified to " \
                    f"{new_elevation}"
                top_elevation = new_elevation
            self.layers[top_elevation] = (layer, input_index)

        self.layers = dict(sorted(self.layers.items(), key=lambda item: item[0], reverse=True))

    def _record_layer_results(self, plot_options):
        """Assign bottom elevations and fill ``self.results`` and the soil-layer plot series.

        Each layer's bottom is the top of the next lower layer; the lowest layer's bottom is the
        borehole's 'Bottom elevation'. Results are stored under 'Layer <n>: <name>' keys, where n
        counts from the highest layer. Layers without a top elevation get an empty entry.

        Args:
            plot_options (dict | None): the 'Soil layers' plot options, or None when absent.
        """
        key_list = list(self.layers.keys())
        last_index = len(key_list) - 1
        for sorted_index, key in enumerate(key_list):
            layer, input_index = self.layers[key]
            calculator = layer['calculator']
            new_key = f'Layer {sorted_index + 1}: {layer["Name"]}'
            self.results[new_key] = {}

            top_elev = layer['Top elevation']
            if top_elev is None:
                continue
            self.results[new_key]['Top elevation'] = top_elev

            if sorted_index < last_index:
                bottom_elev = self.layers[key_list[sorted_index + 1]][0]['Top elevation']
                # NOTE(review): only layers above the lowest one get calculator.bottom_elev set;
                # confirm whether the lowest layer should get it too.
                calculator.bottom_elev = bottom_elev
            else:
                _, bottom_elev = self.get_data('Bottom elevation')

            self.results[new_key]['Bottom elevation'] = bottom_elev
            layer['Bottom elevation'] = bottom_elev
            calculator.results['Top elevation'] = top_elev
            calculator.results['Bottom elevation'] = bottom_elev
            layer['Results']['Elevations'] = {
                'Elevations computed': True,
                'Top elevation': top_elev,
                'Bottom elevation': bottom_elev,
            }

            plot_bottom = bottom_elev
            if plot_bottom is None or top_elev <= plot_bottom:
                # Give the plotted box a nominal 10-unit thickness when the bottom is unusable.
                plot_bottom = top_elev - 10.0
            self.elevations = [plot_bottom, plot_bottom, top_elev, top_elev, plot_bottom]

            if not layer['can_compute']:
                continue

            if plot_options:
                layer_options = plot_options['series'][input_index]
                self.results[new_key].update(calculator.results)
                layer_options['Name'] = layer['Name']
                self._ensure_series_variable(layer_options, 'x var', 'x_var')
                self._ensure_series_variable(layer_options, 'y var', 'y_var')
                layer_options['x var'].set_val(self.stations)
                layer_options['x name'] = 'Station'
                layer_options['y var'].set_val(self.elevations)
                layer_options['y name'] = 'Elevations'

            if layer['Soil type'] == 'Cohesive':
                self.results[new_key]['Critical shear stress (τc)'] = layer['Critical shear stress (τc)']
            else:
                self.results[new_key]['D50'] = calculator.results['Gradation']['D50']
            self.results[new_key]['Classification'] = calculator.results['Gradation']['Classification']

    def _ensure_series_variable(self, series, var_key, var_name):
        """Give a plot series a ``var_key`` variable if it does not have one yet.

        Copies the matching bridge cross-section variable when available; otherwise creates an
        empty float-list Variable named ``var_name``.
        """
        if var_key in series:
            return
        if self.bcs_data and self.bcs_data[var_key]:
            series[var_key] = copy.deepcopy(self.bcs_data[var_key])
        else:
            series[var_key] = Variable(var_name, 'float_list', 0.0, [])

    def _update_shear_stress_plot(self, shear_plot_dict):
        """Fill the 'Shear stress vs elevation' series from the soil-layer series.

        Each layer's series copies the soil-layer entries it shares (except 'x var' and 'X axis').
        Its x values become shear stresses: the box spans 0 to the layer calculator's ``tau``.

        Args:
            shear_plot_dict (dict): the 'Shear stress vs elevation' plot options.
        """
        shear_series = shear_plot_dict['series']
        for layer, cur_index in self.layers.values():
            if cur_index not in shear_series and 0 in shear_series:
                # Seed a missing series from the first one.
                shear_series[cur_index] = copy.deepcopy(shear_series[0])
            layer_dict = shear_series[cur_index]
            shear = layer['calculator'].tau

            soil_series = self.plot_dict['Soil layers']['series']
            for layer_key in layer_dict:
                if layer_key not in soil_series[cur_index] or layer_key in ['x var', 'X axis']:
                    continue
                layer_dict[layer_key] = soil_series[cur_index][layer_key]

            x_var = layer_dict.get('x var')
            if x_var and len(x_var.value_options) > 0:
                values = x_var.value_options
                low_number = min(values)
                high_number = max(values)
                # Map the box's left edge to zero stress and its right edge to the layer's shear.
                new_values = [0.0 if v == low_number else shear if v == high_number else v for v in values]
                x_var.set_val(new_values)
            else:
                if x_var is None:
                    x_var = Variable('x_var', 'float_list', 0.0, [])
                    layer_dict['x var'] = x_var
                x_var.set_val([0.0, shear, shear, 0.0, 0.0])

    def _update_cross_section_plot(self, cs_stations, cs_elevations, null_data):
        """Build the '<name> layers plotted along cross-section' plot.

        Starts from a copy of the 'Soil layers' plot and refits each layer's box to the
        cross-section via ``get_soil_layer_fit_to_cross_section``. It then appends the bridge
        cross-section series, the pier series and the ``bh_cl`` (borehole cl) series.

        Args:
            cs_stations (list): cross-section stations with null entries removed.
            cs_elevations (list): cross-section elevations with null entries removed.
            null_data: placeholder value marking missing entries.
        """
        cs_name = f"{self.input_dict['calc_data']['Name']} layers plotted along cross-section"
        # Register the plot name once; repeated computes must not list it again.
        if cs_name not in self.plot_names:
            self.plot_names.append(cs_name)
        self.plot_dict[cs_name] = copy.deepcopy(self.plot_dict['Soil layers'])
        cs_plot_dict = self.plot_dict[cs_name]
        if self.bcs_data is None or self.bcs_data['x var'] is None or self.bcs_data['y var'] is None:
            return

        clean_cs_stations, clean_cs_elevations, _ = remove_nans(cs_stations, cs_elevations, null_data)
        for _, input_index in self.layers.values():
            layer_dict = cs_plot_dict['series'][input_index]
            if 'x var' in layer_dict and layer_dict['x var']:
                values = layer_dict['y var'].value_options
                new_x, new_y = self.get_soil_layer_fit_to_cross_section(
                    max(values), min(values), clean_cs_stations, clean_cs_elevations)
                layer_dict['x var'].set_val(new_x)
                layer_dict['y var'].set_val(new_y)

        # Append the bridge cross-section, pier and borehole cl series after the layer series.
        num_series = len(cs_plot_dict['series'])
        cs_plot_dict['series'][num_series] = copy.deepcopy(self.bcs_data)
        if self.pier_data is not None:
            cs_plot_dict['series'][num_series + 1] = copy.deepcopy(self.pier_data)
        if self.bh_cl is not None:
            cs_plot_dict['series'][num_series + 2] = copy.deepcopy(self.bh_cl)

        self.plots[cs_name] = {'Plot name': cs_name, 'Legend': "best"}

    def get_soil_layer_fit_to_cross_section(self, top_elevation: float, bottom_elevation: float,
                                            cs_stations: list[float], cs_elevations: list[float]
                                            ) -> tuple[list[float], list[float]]:
        """Clip a soil layer's elevation band to the ground beneath a cross-section.

        The layer occupies the band between ``bottom_elevation`` and ``top_elevation`` wherever
        that band lies below the cross-section's ground line. The input lists are not modified.

        Args:
            top_elevation (float): The top elevation of the soil layer.
            bottom_elevation (float): The bottom elevation of the soil layer.
            cs_stations (list[float]): Station values (x) of the cross-section.
            cs_elevations (list[float]): Elevation values (y) of the cross-section.

        Returns:
            tuple: station (x) and elevation (y) lists outlining the layer. When the band is split
            into several pieces, consecutive polygons are separated by a ('Null data', 'Null data')
            point. Both lists are empty when the inputs are mismatched or shorter than two points,
            or when the layer lies entirely above the cross-section.
        """
        if len(cs_stations) != len(cs_elevations) or len(cs_stations) < 2:
            return [], []

        # Case 1: Soil layer is completely above the cross-section
        if bottom_elevation >= max(cs_elevations):
            return [], []

        # Case 3: Soil layer is completely below the cross-section
        if top_elevation <= min(cs_elevations):
            min_station = min(cs_stations)
            max_station = max(cs_stations)
            return [min_station, max_station, max_station, min_station, min_station], \
                [bottom_elevation, bottom_elevation, top_elevation, top_elevation, bottom_elevation]

        x_values = []
        y_values = []

        min_station = min(cs_stations)
        max_station = max(cs_stations)

        # The tracing below only works left-to-right. Reverse copies so the caller's lists
        # (which may be shared cross-section data) are never mutated.
        if cs_stations[0] > cs_stations[-1]:
            cs_stations = cs_stations[::-1]
            cs_elevations = cs_elevations[::-1]

        # Case 2: Soil layer crosses the cross-section
        cs_above_top = False
        cs_below_bottom = False
        if top_elevation > cs_elevations[0]:
            x_values.append(cs_stations[0])
            y_values.append(cs_elevations[0])
        else:
            x_values.append(min_station)
            y_values.append(top_elevation)
            cs_above_top = True

        top_elev_x = [min_station, max_station]
        top_elev_y = [top_elevation, top_elevation]

        bottom_elev_x = [min_station, max_station]
        bottom_elev_y = [bottom_elevation, bottom_elevation]

        # Start point of the polygon being traced; cur_poly_x is None after a polygon is closed.
        cur_poly_x = x_values[0]
        cur_poly_y = y_values[0]
        _, null_data = self.get_data('Null data')

        # Walk the cross-section points, tracking whether the ground is above the layer top or
        # below the layer bottom, and emit intersection points whenever that state changes.
        for i in range(1, len(cs_stations)):
            x1, y1 = cs_stations[i], cs_elevations[i]

            if cs_above_top:
                if y1 > top_elevation:
                    continue
                cs_above_top = False
                # Ground drops into the layer: add the crossing with the layer top.
                result, int_x, int_y, _, _ = find_intersection_and_closest_index(
                    cs_stations, cs_elevations, top_elev_x, top_elev_y, i - 1, 0)
                if result:
                    x_values.append(int_x)
                    y_values.append(int_y)

                if y1 < bottom_elevation:
                    cs_below_bottom = True
                    # Ground passes straight through the layer: add the crossing with the bottom.
                    result, int_x, int_y, _, _ = find_intersection_and_closest_index(
                        cs_stations, cs_elevations, bottom_elev_x, bottom_elev_y, i - 1, 0)
                    if result:
                        x_values.append(int_x)
                        y_values.append(int_y)
                    # Close the polygon along the layer bottom back to its start point.
                    x_values.append(cur_poly_x)
                    y_values.append(bottom_elevation)
                    x_values.append(cur_poly_x)
                    y_values.append(cur_poly_y)
                    x_values.append(null_data)
                    y_values.append(null_data)
                    cur_poly_x = None
                else:
                    x_values.append(x1)
                    y_values.append(y1)

            elif cs_below_bottom:
                # NOTE: a guard for y1 still below the bottom was never hit by coverage tests.
                cs_below_bottom = False
                # Ground rises back into the layer: a new polygon starts at the bottom crossing.
                result, int_x, int_y, _, _ = find_intersection_and_closest_index(
                    cs_stations, cs_elevations, bottom_elev_x, bottom_elev_y, i - 1, 0)
                if result:
                    if cur_poly_x is None:
                        cur_poly_x = int_x
                        cur_poly_y = int_y
                    x_values.append(int_x)
                    y_values.append(int_y)
                if y1 > top_elevation:
                    cs_above_top = True
                    # Ground passes straight up through the layer: add the crossing with the top.
                    result, int_x, int_y, _, _ = find_intersection_and_closest_index(
                        cs_stations, cs_elevations, top_elev_x, top_elev_y, i - 1, 0)
                    if result:
                        x_values.append(int_x)
                        y_values.append(int_y)
                else:
                    x_values.append(x1)
                    y_values.append(y1)

            else:
                if y1 < bottom_elevation:
                    cs_below_bottom = True
                    # Ground drops below the layer: add the crossing with the bottom.
                    result, int_x, int_y, _, _ = find_intersection_and_closest_index(
                        cs_stations, cs_elevations, bottom_elev_x, bottom_elev_y, i - 1, 0)
                    if result:
                        x_values.append(int_x)
                        y_values.append(int_y)
                    # Close the polygon along the layer bottom (via the start station's bottom
                    # corner) back to its start point.
                    x_values.append(cur_poly_x)
                    y_values.append(bottom_elevation)
                    if bottom_elevation != cur_poly_y:
                        x_values.append(cur_poly_x)
                        y_values.append(cur_poly_y)
                    x_values.append(null_data)
                    y_values.append(null_data)
                    cur_poly_x = None
                elif y1 > top_elevation:
                    cs_above_top = True
                    # Ground rises above the layer: add the crossing with the top.
                    result, int_x, int_y, _, _ = find_intersection_and_closest_index(
                        cs_stations, cs_elevations, top_elev_x, top_elev_y, i - 1, 0)
                    if result:
                        x_values.append(int_x)
                        y_values.append(int_y)
                else:
                    x_values.append(x1)
                    y_values.append(y1)

        if cs_above_top:
            x_values.append(max_station)
            y_values.append(top_elevation)

        if not cs_below_bottom:
            # Close the final polygon along the layer bottom back to its start point.
            x_values.append(max_station)
            y_values.append(bottom_elevation)
            x_values.append(cur_poly_x)
            y_values.append(bottom_elevation)
            if bottom_elevation != cur_poly_y:
                x_values.append(cur_poly_x)
                y_values.append(cur_poly_y)
        # When cs_below_bottom is set, the last polygon has already been closed.

        return x_values, y_values

    def assign_color_based_on_gradation(self, layer):
        """Choose plot colors and a fill pattern for a layer from its gradation classification.

        Colors are only returned when the classification differs from the one recorded on the
        previous call (stored in ``layer['Results']['Gradation']['Prior classification']``).
        Presumably this keeps user-edited colors from being overwritten on every recompute; confirm.

        Args:
            layer (dict): the gradation layer's input dict; reads ``layer['Results']['Gradation']``.

        Returns:
            tuple: (line color, fill color, fill pattern). All three are None when there is no
            classification or it is unchanged. They are also None when it matches none of
            clay/silt, sand, gravel or riprap.
        """
        if not ('Results' in layer and 'Gradation' in layer['Results']
                and 'Classification' in layer['Results']['Gradation']):
            return None, None, None

        gradation = layer['Results']['Gradation']
        classification = gradation['Classification']
        if not classification:
            return None, None, None

        if 'Prior classification' in gradation and gradation['Prior classification'] == classification:
            return None, None, None

        color = None
        fill_color = None
        pattern = None

        lowered = classification.lower()
        if 'clay' in lowered or 'silt' in lowered:
            _, color = self.get_data('Silty soil plot color')
            _, fill_color = self.get_data('Silty soil plot fill color')
            pattern = 'earth'
        elif 'sand' in lowered:
            _, color = self.get_data('Sandy soil plot color')
            _, fill_color = self.get_data('Sandy soil plot fill color')
            pattern = 'sand'
        elif 'gravel' in lowered:
            _, color = self.get_data('Gravel plot color')
            _, fill_color = self.get_data('Gravel plot fill color')
            pattern = 'small dots'
        elif 'riprap' in lowered:
            _, color = self.get_data('Riprap plot color')
            _, fill_color = self.get_data('Riprap plot fill color')
            pattern = 'rock'

        gradation['Prior classification'] = classification
        return color, fill_color, pattern

    def get_current_layer_dict(self, index=None):
        """Return the input dict of a layer.

        Args:
            index (int, optional): layer index; defaults to ``self.current_layer``.

        Returns:
            dict | None: the layer dict, or None when ``index`` is past the last layer.
        """
        if index is None:
            index = self.current_layer
        layers = self.input_dict['calc_data']['Layers']
        if index < len(layers):
            return layers[index]
        return None

    def _get_layer_calculator_value(self, index, attr_name):
        """Return ``attr_name`` from a layer's calculator, computing it first if needed.

        The calculator is (re)computed when it has no ``percent_pass`` yet or the requested value
        is not positive.

        Args:
            index (int | None): layer index; defaults to ``self.current_layer``.
            attr_name (str): calculator attribute to read (e.g. 'tau', 'd50', 'd84').

        Returns:
            The attribute value, or None when ``index`` is past the last layer.
        """
        if index is None:
            index = self.current_layer
        layers = self.input_dict['calc_data']['Layers']
        if index < len(layers):
            calculator = layers[index]['calculator']
            if calculator.percent_pass is None or getattr(calculator, attr_name) <= 0.0:
                calculator.compute_data()
            return getattr(calculator, attr_name)
        return None

    def get_critical_shear_stress(self, index=None):
        """Return the critical shear stress of a gradation layer.

        Args:
            index (int, optional): layer index; defaults to ``self.current_layer``.

        Returns:
            float | None: the layer calculator's ``tau``, or None when ``index`` is out of range.
        """
        return self._get_layer_calculator_value(index, 'tau')

    def get_d50(self, index=None):
        """Return the D50 of a gradation layer.

        Args:
            index (int, optional): layer index; defaults to ``self.current_layer``.

        Returns:
            float | None: the layer calculator's ``d50``, or None when ``index`` is out of range.
        """
        return self._get_layer_calculator_value(index, 'd50')

    def get_d84(self, index=None):
        """Return the D84 of a gradation layer.

        Args:
            index (int, optional): layer index; defaults to ``self.current_layer``.

        Returns:
            float | None: the layer calculator's ``d84``, or None when ``index`` is out of range.
        """
        return self._get_layer_calculator_value(index, 'd84')
