"""PlotOptions Class."""
__copyright__ = "(C) Copyright Aquaveo 2020"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy

# 2. Third party modules
import matplotlib.hatch

# 3. Aquaveo modules

# 4. Local modules
from xms.FhwaVariable.core_data.calculator.calcdata import VariableGroup
from xms.FhwaVariable.core_data.calculator.calculator_list import CalcOrVarlist
from xms.FhwaVariable.core_data.calculator.plot.custom_hatches import BituminousHatch, BrickHatch, ClayHatch, \
    ConcreteHatch, EarthHatch, GrassHatch, HoneycombHatch, MetalHatch, MetamorphicHatch, PeatHatch, RockHatch, \
    SandHatch, WaterHatch, WoodHatch, ZigzagHatch
from xms.FhwaVariable.core_data.calculator.plot.data_series_options import DataSeriesOptions
from xms.FhwaVariable.core_data.calculator.plot.line_options import LineOptions
from xms.FhwaVariable.core_data.calculator.plot.series_options import SeriesOptions
from xms.FhwaVariable.core_data.variables.variable import Variable


class PlotOptions(VariableGroup):
    """Class that defines the data for plotting options.

    The options are stored in ``self.input`` as ``Variable`` objects: a list of reference lines ('Plot lines'),
    an optional list of data series ('Data series', only created when ``input_dict`` is supplied), legend
    settings, log-scale flags and axis limits.
    """
    # Register the custom hatch patterns with matplotlib once, when the class is defined. The membership test keeps
    # a module reload from registering the same hatch type twice.
    # NOTE(review): matplotlib.hatch._hatch_types is a private matplotlib attribute and may change between releases.
    for _hatch_type in (BituminousHatch, BrickHatch, ClayHatch, ConcreteHatch, EarthHatch, GrassHatch,
                        HoneycombHatch, MetalHatch, MetamorphicHatch, PeatHatch, RockHatch, SandHatch,
                        WaterHatch, WoodHatch, ZigzagHatch):
        if _hatch_type not in matplotlib.hatch._hatch_types:
            matplotlib.hatch._hatch_types.append(_hatch_type)
    del _hatch_type  # keep the loop variable from becoming a class attribute

    def __init__(self, plot_name='', app_data=None, model_name=None, project_uuid=None, input_dict=None,
                 show_series=True):
        """Initializes the plot options.

        Args:
            plot_name (str): name of the plot
            app_data (AppData): application data (for settings)
            model_name (str): name of the model
            project_uuid (uuid): project uuid
            input_dict (dict): dictionary of input variables, keyed by series group; when None, no 'Data series'
                input is created
            show_series (bool): whether to show the series
        """
        super().__init__(app_data=app_data, model_name=model_name, project_uuid=project_uuid)

        self.name = plot_name
        self.type = 'PlotOptions'

        self.show_series = show_series

        # Input
        self.linetypes = ['solid', 'dashed', 'dotted', 'dash-dot']

        # Other arguments (if any) can be added to args and kwargs
        args = ()  # Add positional arguments here if needed
        kwargs = {}  # Add other keyword arguments as needed

        self.input['Plot lines'] = Variable(
            'Plot lines', 'calc_list', CalcOrVarlist(
                LineOptions, app_data=app_data, model_name=model_name, project_uuid=project_uuid, show_define_btn=False,
                default_name='Plot line', default_plural_name='Plot lines', select_one=False, min_number_of_items=0,
                initial_num_items=0, show_number=False, show_duplicate=False, show_delete=False, args=args,
                kwargs=kwargs),
            complexity=1)

        if input_dict is not None:
            # NOTE(review): every iteration assigns the same 'Data series' key, so when input_dict holds more than
            # one group only the last group's list is kept. Confirm whether that is intended.
            for series_group in input_dict:
                args = ()  # Add positional arguments here if needed
                kwargs = {'input_list': input_dict[series_group]}  # Add other keyword arguments as needed
                name = f'{series_group} data series'
                self.input['Data series'] = Variable(
                    'Data series', 'calc_list', CalcOrVarlist(
                        DataSeriesOptions, app_data=app_data, model_name=model_name, project_uuid=project_uuid,
                        default_name=name, default_plural_name=name, select_one=False,
                        min_number_of_items=0, show_name=False, show_number=False, show_duplicate=False,
                        show_delete=False, initial_num_items=1, args=args, kwargs=kwargs),
                    complexity=1)

        legend_locations = ['best', 'upper right', 'upper left', 'lower left', 'lower right', 'right',
                            'center left', 'center right', 'lower center', 'upper center', 'center']
        self.input['Include legend'] = Variable('Include legend', 'bool', True, complexity=1)
        self.input['Legend location'] = Variable('Legend location', 'list', 0, legend_locations, complexity=1)

        self.input['Log-scaled x-axis'] = Variable('Log-scaled x-axis', 'bool', False, complexity=1)
        self.input['Log-scaled y-axis'] = Variable('Log-scaled y-axis', 'bool', False, complexity=1)

        self.input['Determine X-Axis limits from data'] = Variable('Determine X-Axis limits from data', 'bool', True,
                                                                   complexity=1)
        self.input['Minimum X-Axis limit'] = Variable('Minimum X-Axis limit', 'float', 0.0, complexity=1)
        self.input['Maximum X-Axis limit'] = Variable('Maximum X-Axis limit', 'float', 0.0, complexity=1)
        self.input['Determine Y-Axis limits from data'] = Variable('Determine Y-Axis limits from data', 'bool', True,
                                                                   complexity=1)
        # The variable names must match their keys (they were previously labelled 'X-Axis' by mistake).
        self.input['Minimum Y-Axis limit'] = Variable('Minimum Y-Axis limit', 'float', 0.0, complexity=1)
        self.input['Maximum Y-Axis limit'] = Variable('Maximum Y-Axis limit', 'float', 0.0, complexity=1)

        self.warnings = []
        self.results = {}

    def _get_data_series_for_update(self, related_index, name):
        """Fetch the data series item at related_index, appending one if the index is one past the end.

        The item's name is set to ``name + ' data'`` before it is returned.

        Args:
            related_index (int): index of the data series item
            name (str): base name used to label the data series item

        Returns:
            The data series item, or None if there is no 'Data series' input or the index is beyond the position
            where a new item could be appended.
        """
        if 'Data series' not in self.input:
            return None

        data_series = self.input['Data series'].value
        num_items = len(data_series.item_list)
        if related_index == num_items:
            data_series.add_item_to_list()
        elif related_index > num_items:
            return None

        data_item = data_series.item_list[related_index]
        data_item.name = name + ' data'
        return data_item

    def set_plot_point_options(self, related_index, index, name='', point_color=None, point_marker=None,
                               marker_size=None):
        """Set the options of a plot point within a data series.

        Does nothing if the data series cannot be found or appended.

        Args:
            related_index (int): index of the data series that holds the point
            index (int): index of the point within the data series
            name (str): name of the point
            point_color (color tuple): color of the point
            point_marker (str): marker type
            marker_size (float): size of the marker
        """
        data_item = self._get_data_series_for_update(related_index, name)
        if data_item is None:
            return

        data_item.set_plot_point_options(
            index, name=name, point_color=point_color, point_marker=point_marker, marker_size=marker_size)

    def set_plot_line_options(self, index, name, line_intercept=None, line_alignment='vertical', line_color=None,
                              linetype=None, line_width=None, text_color=None, text_alignment=None,
                              multiple_labels=None, labels=None):
        """Set the options of a plot line.

        A new line is appended when index is one past the end of the list; larger indices are ignored.

        Args:
            index (int): index of the plot line
            name (str): name of the plot line
            line_intercept (float): where the line will cross the axis
            line_alignment (string: 'vertical' or 'horizontal'): direction of line
            line_color (color tuple): line color
            linetype (str): line type ['solid', 'dashed', 'dotted', 'dash-dot']
            line_width (float): line width
            text_color (color tuple): text color of label
            text_alignment (string): alignment of label
            multiple_labels (bool): whether to use multiple labels
            labels (list): list of labels
        """
        plot_lines = self.input['Plot lines'].value
        num_lines = len(plot_lines.item_list)
        if index == num_lines:
            plot_lines.add_item_to_list()
        elif index > num_lines:
            return

        plot_lines.item_list[index].set_plot_line_options(
            name=name, line_intercept=line_intercept, line_alignment=line_alignment, line_color=line_color,
            linetype=linetype, line_width=line_width, text_color=text_color, text_alignment=text_alignment,
            multiple_labels=multiple_labels, labels=labels)

    def set_plot_series_options(self, related_index, index, x_axis, y_axis, name='', line_color=None, linetype=None,
                                line_width=None, fill_below_line=None, fill_color=None, pattern=None, density=None):
        """Set the options of a plot series within a data series.

        Does nothing if the data series cannot be found or appended.

        Args:
            related_index (int): index of the data series that holds the plot series
            index (int): index of the plot series within the data series
            x_axis (str): name of dataset to use for the x-axis data
            y_axis (str): name of dataset to use for the y-axis data
            name (str): name of the series
            line_color (color tuple): line color
            linetype (str): line type ['solid', 'dashed', 'dotted', 'dash-dot']
            line_width (float): line width
            fill_below_line (bool): fill below line
            fill_color (color tuple): fill color
            pattern (str): fill pattern
            density (int): density of the pattern
        """
        data_item = self._get_data_series_for_update(related_index, name)
        if data_item is None:
            return

        data_item.set_plot_series_options(
            series_index=index, x_axis=x_axis, y_axis=y_axis, name=name, line_color=line_color, linetype=linetype,
            line_width=line_width, fill_below_line=fill_below_line, fill_color=fill_color, pattern=pattern,
            density=density)

    def set_plot_log_options(self, x_axis_log, y_axis_log):
        """Set the plot log options.

        Args:
            x_axis_log (bool): log x-axis
            y_axis_log (bool): log y-axis
        """
        self.input['Log-scaled x-axis'].set_val(x_axis_log)
        self.input['Log-scaled y-axis'].set_val(y_axis_log)

    def get_can_compute(self):
        """Determine whether we have enough data to compute.

        Returns:
            tuple (bool, str): always (True, '')
        """
        return True, ''

    def compute_data(self):
        """Computes the data possible; stores results in self.

        Returns:
            bool: always False; this method performs no computation
        """
        return False

    def get_input_group(self, unknown=None):
        """Get the input group (for user-input).

        The axis limit variables are left out while the matching 'Determine ... limits from data' option is set,
        and 'Data series' is left out when show_series is False.

        Args:
            unknown: not used by this method

        Returns:
            input_vars (dict of variables): deep copy of the input variables to show
        """
        input_vars = copy.deepcopy(self.input)

        if input_vars['Determine X-Axis limits from data'].get_val():
            input_vars.pop('Minimum X-Axis limit')
            input_vars.pop('Maximum X-Axis limit')
        if input_vars['Determine Y-Axis limits from data'].get_val():
            input_vars.pop('Minimum Y-Axis limit')
            input_vars.pop('Maximum Y-Axis limit')

        if not self.show_series:
            # 'Data series' only exists when an input_dict was given to __init__, so tolerate its absence.
            input_vars.pop('Data series', None)

        return input_vars

    def get_item_by_name(self, item_name):
        """Gets an item by its name.

        The data series (and the items they contain) are searched first, then the plot lines.

        Args:
            item_name (str): name of the item to get

        Returns:
            item: the first match among the data series; otherwise the result of the plot lines lookup
        """
        if 'Data series' in self.input:
            for data_item in self.input['Data series'].value.item_list:
                if data_item.name == item_name:
                    return data_item
                nested_item = data_item.get_item_by_name(item_name)
                if nested_item is not None:
                    return nested_item

        return self.input['Plot lines'].value.get_item_by_name(item_name)

    def set_item_by_name(self, new_item):
        """Sets an item by its name.

        The item whose name matches new_item.name is replaced by a deep copy of new_item. The data series (and the
        items they contain) are searched first, then the plot lines.

        Args:
            new_item (?): item of data to set

        Returns:
            True if a data series item was replaced; otherwise the result of the plot lines lookup
        """
        if 'Data series' in self.input:
            data_series = self.input['Data series'].value
            for position, data_item in enumerate(data_series.item_list):
                if data_item.name == new_item.name:
                    data_series.item_list[position] = copy.deepcopy(new_item)
                    return True
                if data_item.set_item_by_name(new_item):
                    return True

        return self.input['Plot lines'].value.set_item_by_name(new_item)

    def get_item_by_indices(self, data_series_index, series_index):
        """Gets an item by its indices.

        Args:
            data_series_index (int): index of the data series
            series_index (int): index of the series

        Returns:
            item: the item, or None if there is no data series at data_series_index
        """
        if 'Data series' not in self.input:
            return None

        data_items = self.input['Data series'].value.item_list
        if data_series_index < len(data_items):
            return data_items[data_series_index].get_item_by_indices(series_index)
        return None

    def set_item_by_indices(self, data_series_index, series_index, new_item):
        """Sets an item by its indices.

        Args:
            data_series_index (int): index of the data series
            series_index (int): index of the series
            new_item (?): item of data to set

        Returns:
            The result from the data series' set_item_by_indices, or False if there is no data series at
            data_series_index
        """
        if 'Data series' not in self.input:
            return False

        data_items = self.input['Data series'].value.item_list
        if data_series_index < len(data_items):
            return data_items[data_series_index].set_item_by_indices(series_index, new_item)
        return False

    @staticmethod
    def convert_linetype_text_to_symbol(linetype):
        """Convert the line type text to a symbol.

        Args:
            linetype (str): line type text

        Returns:
            str: line type symbol
        """
        return SeriesOptions.convert_linetype_text_to_symbol(linetype)

    @staticmethod
    def convert_pattern_text_to_symbol(pattern, density):
        """Convert the pattern text to a symbol.

        Args:
            pattern (str): pattern text
            density (int): density of the pattern

        Returns:
            str: pattern symbol
        """
        return SeriesOptions.convert_pattern_text_to_symbol(pattern, density)

    def get_plot_options_dict(self):
        """Get the plot options dictionary.

        Only items whose 'Display' input is set are included. The 'series', 'points' and 'lines' entries are
        dictionaries keyed by a single running index that is shared across all three.

        Returns:
            dict: the 'series', 'points' and 'lines' dictionaries plus the legend, log-scale and axis limit values
        """
        plot_options = {'series': {}, 'lines': {}, 'points': {}}
        index = 0  # running key shared by the series, points and lines dictionaries

        if 'Data series' in self.input:
            for data_item in self.input['Data series'].value.item_list:
                for series_item in data_item.input['Plot series'].value.item_list:
                    if series_item.input['Display'].get_val():
                        plot_options['series'][index] = series_item.get_plot_options_dict()
                        index += 1
                for points_item in data_item.input['Plot points'].value.item_list:
                    if points_item.input['Display'].get_val():
                        plot_options['points'][index] = points_item.get_plot_options_dict()
                        index += 1

        if 'Plot lines' in self.input:
            for line_item in self.input['Plot lines'].value.item_list:
                if line_item.input['Display'].get_val():
                    plot_options['lines'][index] = line_item.get_plot_options_dict()
                    index += 1

        # The remaining options are copied straight across under the same keys.
        for key in ('Include legend', 'Legend location',
                    'Log-scaled x-axis', 'Log-scaled y-axis',
                    'Determine X-Axis limits from data', 'Minimum X-Axis limit', 'Maximum X-Axis limit',
                    'Determine Y-Axis limits from data', 'Minimum Y-Axis limit', 'Maximum Y-Axis limit'):
            plot_options[key] = self.input[key].get_val()

        return plot_options
