"""Weir Class."""
__copyright__ = "(C) Copyright Aquaveo 2020"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy
import sys

# 2. Third party modules

# 3. Aquaveo modules
from xms.FhwaVariable.core_data.calculator.calcdata import CalcData
from xms.FhwaVariable.core_data.calculator.table_data import TableData
from xms.FhwaVariable.core_data.variables.user_array import UserArray
from xms.FhwaVariable.core_data.variables.variable import Variable

# 4. Local modules
from xms.HydraulicToolboxCalc.hydraulics.weir.weir_calc import WeirCalc


class WeirData(CalcData):
    """A class that defines a weir and performs weir computations.

    Holds the weir inputs, intermediate values and results used with the weir calculator (WeirCalc), and selects
    which inputs/results are shown for the current user selections.
    """

    # Displayed inputs that are NOT appended to the 'Calculate' options. 'Head' and 'Flows' are already base
    # options; the remaining names are not offered as values to solve for.
    _NON_CALCULABLE_INPUTS = frozenset([
        'Calculate', 'Head', 'Flows', 'WSE', 'Depths',
        'Weir type', 'Select angle or coef', 'Weir surface', 'Thalweg invert elevation',
        'Irregular weir crest geometry', 'Velocity', 'Weir orientation', 'Froude',
    ])

    def __init__(self, stand_alone_calc=True, allowed_shapes=None, app_data=None, model_name=None,
                 project_uuid=None):
        """Initializes the weir calculator.

        Args:
            stand_alone_calc (bool): when True, get_input_group offers the 'Calculate' selector and the flow/head
                inputs; otherwise only the 'Head' selector is offered from that group.
            allowed_shapes (list of str): shape families to offer; any of 'irregular', 'rectangular',
                'trapezoidal', 'triangular'. Defaults to all four.
            app_data: passed through to CalcData and to the table inputs.
            model_name: passed through to CalcData and to the table inputs.
            project_uuid: passed through to CalcData and to the table inputs.
        """
        super().__init__(app_data=app_data, model_name=model_name, project_uuid=project_uuid)

        self.name = 'Weir Calculator'
        self.type = 'WeirCalc'
        self.class_name = 'Weir Calculator'

        self.calculator = WeirCalc()

        if allowed_shapes is None:
            allowed_shapes = ['irregular', 'rectangular', 'trapezoidal', 'triangular']

        self.calc_support_dual_input = False

        max_value = sys.float_info.max

        self.stand_alone_calc = stand_alone_calc

        # Plot colors: taken from the theme when one is available, otherwise fixed RGB fallbacks (the embankment
        # fallback reuses the soil fallback). Without this guard a missing theme would raise a TypeError.
        soil_color = (102, 51, 0)
        soil_fill_color = (215, 211, 199)
        embankment_color = soil_color
        embankment_fill_color = soil_fill_color
        if self.theme is not None:
            embankment_color = self.theme['Plot embankment color']
            embankment_fill_color = self.theme['Plot embankment fill color']
            soil_color = self.theme['Plot soil color']
            soil_fill_color = self.theme['Plot soil fill color']

        # Input
        self.input['Calculate'] = Variable('Calculate', "list", 1, ['Head', 'Flows'])
        self.input['Head'] = Variable('Head', "list", 0, ['Depth', 'Elevation'], complexity=2)

        self.input['Flows'] = Variable('Flow(s)', 'UserArray', UserArray(
            2, ['flow'], 'cfs', us_units=self.us_flow, si_units=self.si_flow, select_name='Specify flow(s) as:',
            name_append='flow'))
        self.input['WSE'] = Variable('Water surface elevation(s)', 'UserArray', UserArray(
            2, ['length'], 'ft', us_units=self.us_mid_length, si_units=self.si_mid_length,
            select_name='Specify elevation(s) as:', name_append='Elevation'), limits=(-max_value, max_value))
        self.input['Depths'] = Variable('Depth(s)', 'UserArray', UserArray(
            2, ['length'], 'ft', us_units=self.us_mid_length, si_units=self.si_mid_length,
            select_name='Specify depth(s) as:', name_append='Depth'))

        self.input['Weir orientation'] = Variable(
            'Weir orientation', "list", 0, ['Perpendicular (intercepts all flow)', 'Parallel (flow can bypass weir)'],
            complexity=2)

        self.input['Velocity'] = Variable(
            'Approach velocity', 'UserArray', UserArray(
                2, ['velocity'], 'ft/s', us_units=self.us_velocity, si_units=self.si_velocity,
                select_name='Specify velocity as:', name_append='Velocity'),
            complexity=1, note='Velocity head is included in head calculations; provide 0.0 for pool or '
            'conservative flow conditions')

        self.input['Froude'] = Variable(
            'Froude number', 'UserArray', UserArray(
                precision=4, unit_type=[], native_unit='', us_units=[['']], si_units=[[]],
                select_name='Specify Froude number as:', name_append='Froude'), complexity=2,
            note='Froude number based on parallel flow conditions; assumes weir length is short so drawdown is '
            'negligible')

        # Shape names grouped by family; only the families listed in allowed_shapes are offered, in this order.
        self.irregular_shapes = ['irregular']
        self.rectangular_shapes = [
            'rectangular', 'rectangular (sharp-crested)', 'rectangular (broad-crested)'
        ]
        self.trapezoidal_shapes = ['Cipolletti', 'trapezoidal']
        self.triangular_shapes = [
            'V-notch - 90 degrees', 'V-notch - 60 degrees', 'V-notch - 45 degrees', 'V-notch - 22.5 degrees',
            'V-notch - User-defined'
        ]
        self.using_shapes = []
        for family, shapes in (('irregular', self.irregular_shapes),
                               ('rectangular', self.rectangular_shapes),
                               ('trapezoidal', self.trapezoidal_shapes),
                               ('triangular', self.triangular_shapes)):
            if family in allowed_shapes:
                self.using_shapes.extend(shapes)

        # Default to the plain 'rectangular' shape when it is offered, otherwise the first offered shape. Looking
        # the index up directly keeps it valid when 'irregular' is allowed but 'rectangular' is not.
        if 'rectangular' in self.using_shapes:
            default_shape = self.using_shapes.index('rectangular')
        else:
            default_shape = 0

        self.input['Weir type'] = Variable('Weir type', "list", default_shape, self.using_shapes)

        self.input['Weir length'] = Variable(
            'Weir length', 'float', 0, [], precision=2, unit_type=['length'], native_unit='ft',
            us_units=self.us_mid_length, si_units=self.si_mid_length)
        self.input['Weir width'] = Variable(
            'Weir width', 'float', 0, [], precision=2, unit_type=['length'], native_unit='ft',
            us_units=self.us_mid_length, si_units=self.si_mid_length)
        self.input['Side slope'] = Variable(
            'Side slope', 'float', 0, [], precision=2, unit_type=['slope'], native_unit='H:1V', us_units=[['H:1V']],
            si_units=[['H:1V']])
        self.input['Select angle or coef'] = Variable(
            'Select angle or coefficient', "list", 0, ['v-notch angle', 'Weir coefficient'])

        self.input['Notch angle'] = Variable(
            'Angle of v-notch', 'float', 0, [], (0.0, 89.999), precision=2, unit_type=['angle'], native_unit='degrees',
            us_units=[['degrees']], si_units=[['degrees']])
        # TODO: Support more units for angle (at least radians)

        self.input['Weir height'] = Variable(
            'Weir height (channel invert to weir crest)', 'float', 0, [], precision=2, unit_type=['length'],
            native_unit='ft', us_units=self.us_mid_length, si_units=self.si_mid_length)

        # Irregular weir crest geometry: a station/elevation table with at least two rows.
        crest_geom = {
            'Crest stations': Variable(
                'Crest stations', 'float_list', 0, [0.0, 0.0, 0.0], limits=(-max_value, max_value), precision=2,
                unit_type=['length'], native_unit='ft', us_units=self.us_mid_length, si_units=self.si_mid_length),
            'Crest elevations': Variable(
                'Crest elevations', 'float_list', 0, [0.0, 0.0, 0.0], precision=2, limits=(-max_value, max_value),
                unit_type=['length'], native_unit='ft', us_units=self.us_mid_length, si_units=self.si_mid_length)
        }

        name = 'Irregular weir crest geometry'
        # NOTE(review): plot_names=['Bridge cross-section'] looks carried over from a bridge calculator; confirm
        # this is the intended plot name for the weir crest table.
        self.input[name] = Variable(
            name, 'table', TableData(self.theme, name=name, plot_names=['Bridge cross-section'],
                                     input=copy.deepcopy(crest_geom), min_items=2,
                                     app_data=app_data, model_name=model_name, project_uuid=project_uuid))
        self.input[name].get_val().set_plot_series_options(
            name, related_index=0, index=0, name=name, x_axis='Station', y_axis='Elevation',
            line_color=embankment_color, linetype='solid', line_width=1.5, fill_below_line=False,
            fill_color=embankment_fill_color, pattern='sand')

        # Tailwater is entered as a depth above the crest or as an elevation, following the 'Head' selection.
        self.input['Tailwater'] = Variable(
            'Tailwater depth (above crest)', 'UserArray', UserArray(
                2, ['length'], 'ft', us_units=self.us_mid_length, si_units=self.si_mid_length,
                select_name='Specify depth(s) as:', name_append='Depth'), limits=(-max_value, max_value))
        self.input['Tailwater elevations'] = Variable(
            'Tailwater surface elevation(s)', 'UserArray', UserArray(
                2, ['length'], 'ft', us_units=self.us_mid_length, si_units=self.si_mid_length,
                select_name='specify elevation(s) as:', name_append='Elevation'), limits=(-max_value, max_value))

        self.input['Weir surface'] = Variable(
            'Weir surface', "list", 0, ['paved', 'gravel', 'user-defined'])

        self.input['Weir coefficient'] = Variable(
            'Weir coefficient', 'float', 3.10, [], limits=[0.0, 10.0], precision=4, unit_type=['coefficient'],
            native_unit='coefficient', us_units=[['coefficient']], si_units=[['coefficient']])

        self.input['Thalweg invert elevation'] = Variable(
            'Channel thalweg invert elevation', 'float', 0, [], limits=(-max_value, max_value), complexity=2,
            precision=2, unit_type=['length'], native_unit='ft', us_units=self.us_mid_length,
            si_units=self.si_mid_length)

        self.unknown = None

        # Intermediate
        self.compute_finalize_functions.extend([self.assign_unknown_results])
        # NOTE(review): 'unk_flow' is listed for copying but is not initialized in this class; presumably it is
        # provided elsewhere (base class or calculator) — confirm.
        self.intermediate_to_copy.extend([
            'coefficient', 'computed_coefficient', 'submerged_coefficient', 'prior_weir_type', 'unk_flow',
            'irregular_shapes', 'rectangular_shapes', 'trapezoidal_shapes', 'triangular_shapes', 'using_shapes',
            'weir_coefficients', 'applied_coefficients', 'perpendicular_coefficient', 'perpendicular_coefficients',
        ])

        self.coefficient = 0.0
        self.computed_coefficient = 0.0
        self.submerged_coefficient = 0.0
        self.prior_weir_type = ''

        self.weir_coefficients = []
        self.applied_coefficients = []

        self.perpendicular_coefficient = 0.0
        self.perpendicular_coefficients = []

        self.warnings = []

        # Results
        self.results = {}

        self.results['Flows'] = Variable(
            'Flow(s)', 'float_list', 0.0, [], unit_type=['flow'], native_unit='cfs', us_units=self.us_flow,
            si_units=self.si_flow)
        self.results['WSE'] = Variable(
            'Water surface elevation(s)', 'float_list', 0.0, [], precision=2, unit_type=['length'],
            native_unit='ft', us_units=self.us_mid_length, si_units=self.si_mid_length)
        self.results['Depths'] = Variable(
            'Depth(s)', 'float_list', 0.0, [], precision=2, unit_type=['length'], native_unit='ft',
            us_units=self.us_mid_length, si_units=self.si_mid_length)

        self.results['WSE stations'] = Variable(
            'Water surface station(s)', 'float_list', 0.0, [], precision=2, unit_type=['length'], native_unit='ft',
            us_units=self.us_mid_length, si_units=self.si_mid_length)

        self.results['Weir coefficient'] = Variable(
            'Weir coefficient', 'float_list', 0.0, [], limits=[0.0, 10.0], precision=4, unit_type=['coefficient'],
            native_unit='coefficient', us_units=[['coefficient']], si_units=[['coefficient']])

        self.results['Weir coefficients'] = Variable(
            'Computed weir coefficient', 'list_of_lists', 0.0, [], limits=[0.0, 10.0], precision=4,
            unit_type=['coefficient'], native_unit='coefficient', us_units=[['coefficient']],
            si_units=[['coefficient']])

        self.results['Applied coefficients'] = Variable(
            'Applied coefficients (includes submergence)', 'list_of_lists', 0.0, [], limits=[0.0, 10.0], precision=4,
            unit_type=['coefficient'], native_unit='coefficient', us_units=[['coefficient']],
            si_units=[['coefficient']])

        # Irregular weir input cross-section (station/elevation table, at least three rows)
        self.station_var = Variable('Station', 'float_list', 0, [0.0], precision=2,
                                    limits=(-max_value, max_value), unit_type=['length'],
                                    native_unit='ft', us_units=self.us_mid_length, si_units=self.si_mid_length)
        self.elevation_var = Variable('Elevation', 'float_list', 0, [0.0], precision=2,
                                      limits=(-max_value, max_value), unit_type=['length'],
                                      native_unit='ft', us_units=self.us_mid_length, si_units=self.si_mid_length)

        cross_section = {
            'Station': self.station_var,
            'Elevation': self.elevation_var
        }
        name = 'Cross-section'
        self.input[name] = Variable(
            name, 'table', TableData(self.theme, name=name, plot_names=['Cross-section'],
                                     input=copy.deepcopy(cross_section), min_items=3,
                                     app_data=app_data, model_name=model_name, project_uuid=project_uuid))
        self.input[name].get_val().set_plot_series_options(
            name, related_index=0, index=0, name=name, x_axis='Station', y_axis='Elevation',
            line_color=soil_color, linetype='solid', line_width=1.5, fill_below_line=False,
            fill_color=soil_fill_color, pattern='sand')

        # plot
        self.plots['profile'] = {}
        self.plots['profile']['Plot name'] = "Weir Profile"
        self.plots['profile']['Legend'] = 'best'

    def get_input_group(self, unknown=None):
        """Returns the input group for the user interface.

        Args:
            unknown (string): not used; the unknown is always read from the 'Calculate' input. Kept for interface
                compatibility.

        Returns:
            input_vars (dict of variables): the inputs to show for the current selections, in display order. The
                variable currently selected in 'Calculate' is excluded.
        """
        input_vars = {}

        unknown = self.input['Calculate'].get_val()
        head = self.input['Head'].get_val()

        if self.stand_alone_calc:
            # Rebuild 'Calculate' so its options start from the base choices; the shape-dependent options are
            # appended below and the previous selection is restored if it is still offered.
            self.input['Calculate'] = Variable('Calculate', "list", 1, ['Head', 'Flows'])
            input_vars['Calculate'] = self.input['Calculate']
            input_vars['Head'] = self.input['Head']

            if unknown != 'Flows':
                input_vars['Flows'] = self.input['Flows']
            if unknown != 'Head':
                if head == 'Elevation':
                    input_vars['WSE'] = self.input['WSE']
                else:
                    input_vars['Depths'] = self.input['Depths']
        else:
            input_vars['Head'] = self.input['Head']

        input_vars['Weir orientation'] = self.input['Weir orientation']
        if self.input['Weir orientation'].get_val() == 'Parallel (flow can bypass weir)':
            input_vars['Froude'] = self.input['Froude']
        else:
            input_vars['Velocity'] = self.input['Velocity']

        input_vars['Weir type'] = self.input['Weir type']
        weir_type = self.input['Weir type'].get_val()

        if weir_type in self.rectangular_shapes:
            input_vars['Weir length'] = self.input['Weir length']
            input_vars['Weir coefficient'] = self.input['Weir coefficient']
        elif weir_type in self.trapezoidal_shapes:
            input_vars['Weir length'] = self.input['Weir length']
            if weir_type == 'trapezoidal':
                input_vars['Side slope'] = self.input['Side slope']
        elif weir_type in self.triangular_shapes:
            if weir_type == 'V-notch - User-defined':
                input_vars['Select angle or coef'] = self.input['Select angle or coef']
                if self.input['Select angle or coef'].get_val() == 'v-notch angle':
                    input_vars['Notch angle'] = self.input['Notch angle']
                else:
                    input_vars['Weir coefficient'] = self.input['Weir coefficient']
        elif weir_type in self.irregular_shapes:
            input_vars['Weir width'] = self.input['Weir width']
            input_vars['Irregular weir crest geometry'] = self.input['Irregular weir crest geometry']
            input_vars['Weir surface'] = self.input['Weir surface']

        if weir_type == 'rectangular (sharp-crested)':
            # Sharp-crested weirs ask for the weir height rather than a weir coefficient.
            input_vars.pop('Weir coefficient', None)
            input_vars['Weir height'] = self.input['Weir height']

        # The tailwater input follows the same depth/elevation choice as the head.
        if head == 'Elevation':
            input_vars['Tailwater elevations'] = self.input['Tailwater elevations']
        else:
            input_vars['Tailwater'] = self.input['Tailwater']

        if weir_type not in self.irregular_shapes:
            input_vars['Thalweg invert elevation'] = self.input['Thalweg invert elevation']

        if self.stand_alone_calc:
            # Offer each displayed value input as a possible unknown (subclasses do not get this).
            calculate_list = self.input['Calculate'].get_list()
            for item in input_vars:
                if item not in self._NON_CALCULABLE_INPUTS:
                    calculate_list.append(item)

            self.input['Calculate'].value_options = calculate_list
            if unknown in calculate_list:
                self.input['Calculate'].set_val(unknown)

            unknown = self.input['Calculate'].get_val()

        # The value being solved for is not an input.
        # NOTE(review): when 'Head' is the unknown this also removes the 'Head' (Depth/Elevation) selector, which
        # still decides the tailwater input and the results shown — confirm this is intended.
        if unknown in input_vars:
            input_vars.pop(unknown)

        return input_vars

    def get_results_group(self, unknown=None):
        """Returns a dictionary of the result variables to show for the current selections.

        Args:
            unknown (string): the variable that is unknown (included in the result dictionary when a result exists
                for it). Defaults to the current 'Calculate' selection.

        Returns:
              result_vars (dictionary of variables): the result variables; empty when the data cannot be computed.
        """
        result_vars = {}

        if not self.can_compute:
            return result_vars

        if unknown is None:
            unknown = self.input['Calculate'].get_val()

        if unknown in self.results:
            result_vars[unknown] = self.results[unknown]

        # The head is reported in the form opposite to the 'Head' selection: water surface elevations when
        # 'Depth' is selected, depths otherwise.
        if self.input['Head'].get_val() == 'Depth':
            result_vars['WSE'] = self.results['WSE']
        else:
            result_vars['Depths'] = self.results['Depths']

        if self.input['Weir type'].get_val() in self.irregular_shapes:
            result_vars['Weir coefficients'] = self.results['Weir coefficients']
            result_vars['Applied coefficients'] = self.results['Applied coefficients']
        else:
            result_vars['Weir coefficient'] = self.results['Weir coefficient']

        return result_vars

    def check_warnings(self):
        """Checks for warnings that are given during computations or a check if we can compute (get_can_compute).

        Returns:
            list of str: The warnings found (if any)
        """
        return self.warnings

    def assign_unknown_results(self):
        """Assigns the results for unknown variables.

        Registered as a compute-finalize function. Discards results whose keys are not in the calculator's
        ``original_results`` (e.g. left over from a previous run with a different unknown). If the current
        'Calculate' selection is not among the remaining results but the calculator produced values for it, a
        'float_list' result is built from a shallow copy of the matching input variable.
        """
        standard_keys = self.calculator.original_results
        self.results = {key: var for key, var in self.results.items() if key in standard_keys}

        self.unknown = self.input['Calculate'].get_val()
        computed = self.calculator.results
        if self.unknown not in self.results and self.unknown in computed:
            values = computed[self.unknown]
            # Shallow copy keeps the input's label/unit settings; only the attributes below are rebound.
            result_var = copy.copy(self.input[self.unknown])
            result_var.type = 'float_list'
            result_var.value_options = values
            result_var.value = len(values)
            self.results[self.unknown] = result_var
