"""Dialog for viewing the observation target data."""

__copyright__ = "(C) Copyright Aquaveo 2020"
__license__ = "All rights reserved"

# 1. Standard Python modules
import webbrowser

# 2. Third party modules
from matplotlib.backends.backend_qt5agg import FigureCanvas, NavigationToolbar2QT as NavigationToolbar
import matplotlib.dates as mdates
from matplotlib.figure import Figure
import numpy as np
from PySide2.QtWidgets import QHeaderView

# 3. Aquaveo modules
from xms.guipy.dialogs import xms_parent_dlg
from xms.guipy.dialogs.xms_parent_dlg import XmsDlg
from xms.guipy.settings import SettingsManager
from xms.guipy.widgets.widget_builder import style_table_view

# 4. Local modules
from xms.coverage.components.generic_coverage_component import FEATURE_TYPE_TEXT
from xms.coverage.data.obs_target_data import ObsTargetData
from xms.coverage.gui.obs_target_dialog_ui import Ui_ObsTargetDialog
from xms.coverage.gui.obs_target_model import ObsTargetModel


# Entity type text (as defined by XMS) for each ObsTargetData.OBS_CATEGORY_* value
FEATURE_IDX_TO_TEXT = {
    ObsTargetData.OBS_CATEGORY_POINT: 'POINT',
    ObsTargetData.OBS_CATEGORY_ARC: 'ARC',
    ObsTargetData.OBS_CATEGORY_ARC_GROUP: 'ARC_GROUP',
    ObsTargetData.OBS_CATEGORY_POLY: 'POLYGON',
}

# Combobox indices that have no counterpart in the ObsTargetData.OBS_CATEGORY_* enum.
# 'Feature' combobox: the 'All' feature type option, one past OBS_CATEGORY_POLY.
CBX_IDX_ALL_FEATURES = ObsTargetData.OBS_CATEGORY_POLY + 1
# 'Show' combobox options
CBX_IDX_SELECTIONS_SHOW_SELECTED = 0  # Only the selected features
CBX_IDX_SELECTIONS_SHOW_ALL = 1  # Every feature


class ObsTargetDialog(XmsDlg):
    """A dialog to view observation target data in a filterable table and a plot."""

    def __init__(self, parent, target_data, selected_ids, time_format):
        """Initializes the dialog.

        Args:
            parent (QWidget): Parent Qt dialog
            target_data (xr.Dataset): The target data to view
            selected_ids (dict): The ids of the selected features. Key is the feature type xmscomponents string
                ('POINT', 'ARC', 'POLYGON', 'ARC_GROUP'), value is a list of the selected features' component ids.
            time_format (str): The absolute datetime format to use (should use strftime specifiers)
        """
        super().__init__(parent, 'xms.coverage.gui.obs_target_dialog')
        self._help_url = 'https://www.xmswiki.com/wiki/SMS:Display_Options'
        self._time_format = time_format
        self._selected_ids = selected_ids
        self._all_data = target_data  # Unfiltered source Dataset
        self._current_selection = {}
        self._data = None  # Filtered data
        self._model = None  # ObsTargetModel of the filtered data, rebuilt on every filter change
        self._figure = None  # matplotlib.figure Figure
        self._canvas = None  # matplotlib.backends.backend_qt5agg FigureCanvas
        self._ax = None  # matplotlib Axes

        # Setup the dialog
        self._in_setup = False  # True while widgets are initialized; suppresses filter/plot refreshes
        self._ui = Ui_ObsTargetDialog()
        self._ui.setupUi(self)
        xms_parent_dlg.add_process_id_to_window_title(self)

        self._setup_ui()

    def _setup_ui(self):
        """Setup dialog widgets."""
        self._in_setup = True

        self._setup_splitter()
        self._setup_filter_widgets()
        self._setup_ui_table_view()

        # Plot
        self._ui.rbt_plot_computed_vs_observed.setChecked(True)
        self._ui.rbt_plot_computed_vs_observed.toggled.connect(self._on_filter_changed)
        self._ui.rbt_plot_residual_vs_observed.toggled.connect(self._on_filter_changed)
        self._ui.rbt_plot_time_series.toggled.connect(self._on_filter_changed)
        self._ui.chk_show_observed.setChecked(True)
        self._ui.chk_show_observed.setEnabled(False)
        self._ui.chk_show_observed.stateChanged.connect(self._on_filter_changed)
        self._ui.chk_show_legend.setChecked(False)
        self._ui.chk_show_legend.stateChanged.connect(self._on_show_legend)
        self.setup_plot()

        # Connect the Ok and Help... buttons at the bottom of the dialog
        self._setup_ui_bottom_button_box()
        self._in_setup = False

        # Apply the initial filter after we are done setting everything up
        self._on_filter_changed(-1)

    def showEvent(self, event):  # noqa: N802
        """Restore last position and geometry when showing dialog."""
        super().showEvent(event)
        self._restore_splitter_geometry()

    def _setup_splitter(self):
        """Set the default splitter sizes and style; children may not be collapsed."""
        self._ui.splitter.setSizes([400, 400])
        self._ui.splitter.setChildrenCollapsible(False)
        self._ui.splitter.setStyleSheet(
            'QSplitter::handle:horizontal { background-color: lightgrey; }'
            'QSplitter::handle:vertical { background-color: lightgrey; }'
        )

    def _save_splitter_geometry(self):
        """Save the current position of the splitter."""
        settings = SettingsManager()
        settings.save_setting('xms.coverage', f'{self._dlg_name}.splitter', self._ui.splitter.sizes())

    def _restore_splitter_geometry(self):
        """Restore the position of the splitter, if one was previously saved."""
        settings = SettingsManager()
        splitter = settings.get_setting('xms.coverage', f'{self._dlg_name}.splitter')
        if not splitter:
            return
        splitter_sizes = [int(size) for size in splitter]
        self._ui.splitter.setSizes(splitter_sizes)

    def _setup_filter_widgets(self):
        """Connect the filter widgets' signals and set their initial values."""
        self._ui.cbx_feature_type.currentIndexChanged.connect(self._on_filter_changed)
        self._ui.cbx_selections.currentIndexChanged.connect(self._on_filter_changed)
        self._ui.chk_show_null_values.stateChanged.connect(self._on_filter_changed)
        self._set_initial_filter()

    def _set_initial_filter(self):
        """Pick the initial 'Feature' and 'Show' filter values based on the current selection."""
        # Loop through all the selected features. If there is only one feature type selected, set the feature type
        # filter to that. If there is more than one feature type selected, set the feature type filter to 'All'.
        selection = self._get_selected_ids(CBX_IDX_ALL_FEATURES)  # Request all selected features
        cbx_idx = -1
        for feature_idx, feature_ids in selection.items():
            if not feature_ids:
                continue  # No selected feature of this type

            if cbx_idx == -1:
                cbx_idx = feature_idx  # This is the first type we have found with selected features, may not be others
            else:
                cbx_idx = CBX_IDX_ALL_FEATURES  # We have other feature types selected, set filter to 'All'

        if cbx_idx < 0:  # No selection (probably coverage-level dialog), disable selection filter combobox
            self._ui.cbx_selections.setCurrentIndex(CBX_IDX_SELECTIONS_SHOW_ALL)
            self._ui.cbx_selections.setEnabled(False)
            # Set the 'Feature' filter to all features by default
            self._ui.cbx_feature_type.setCurrentIndex(CBX_IDX_ALL_FEATURES)
        else:  # At least one feature selected, set the 'Show' filter to only the selected features by default
            self._ui.cbx_selections.setCurrentIndex(CBX_IDX_SELECTIONS_SHOW_SELECTED)
            self._ui.cbx_feature_type.setCurrentIndex(cbx_idx)

        self._ui.chk_show_null_values.setChecked(True)

    def _on_filter_changed(self, _index):
        """Re-apply the filters and refresh the table and plot (ignored while the dialog is being set up).

        Args:
            _index: Unused signal argument (combobox index, checkbox state, or radio button toggle state)
        """
        if self._in_setup:  # Don't apply filters until we are ready
            return
        # 'Show observed' only applies to the time series plot
        self._ui.chk_show_observed.setEnabled(self._ui.rbt_plot_time_series.isChecked())
        self._filter()
        self._update_tableview()
        self.add_series()

    def _on_show_legend(self):
        """Called when "Show legend" checkbox is clicked to hide/show the legend."""
        self.add_series()

    def _filter(self):
        """Use the filter widget values to build self._data from the full source Dataset."""
        feature_index = self._ui.cbx_feature_type.currentIndex()
        selections_index = self._ui.cbx_selections.currentIndex()
        show_nulls = self._ui.chk_show_null_values.isChecked()

        # With no filters applied this is simply the entire source Dataset
        data = self._all_data
        if feature_index != CBX_IDX_ALL_FEATURES:  # Filter by feature type
            data = data.where(data.feature_type == feature_index, drop=True)
        if not show_nulls:
            data = data.dropna(dim='comp_id')
        if selections_index != CBX_IDX_SELECTIONS_SHOW_ALL:  # Filter to selected features
            data = self._filter_to_selection(data, feature_index)
        self._data = data

    def _filter_to_selection(self, data, feature_index):
        """Keep only the selected features of the requested feature type(s).

        Args:
            data (xr.Dataset): The Dataset to filter
            feature_index (int): One of the ObsTargetData.OBS_CATEGORY_* enum values, or CBX_IDX_ALL_FEATURES

        Returns:
            xr.Dataset: The filtered Dataset. If none of the requested feature types have selected features, `data`
            is returned unfiltered.
        """
        mask = None
        for category, selected_ids in self._get_selected_ids(feature_index).items():
            if not selected_ids:
                continue
            category_mask = (data.feature_type == category) & data.comp_id.isin(selected_ids)
            # OR the per-type masks together. Applying them one after another would leave only the features
            # matching every type at once, i.e. nothing when several feature types are selected.
            mask = category_mask if mask is None else (mask | category_mask)
        if mask is None:
            return data
        return data.where(mask, drop=True)

    def _update_tableview(self):
        """Rebuild the read-only table model from the filtered data."""
        old_model = self._model
        df = self._data.to_dataframe()
        self._model = ObsTargetModel(data_frame=df, time_format=self._time_format, parent=self)
        self._ui.table_view.setModel(self._model)
        if old_model is not None:
            # Models are parented to the dialog, so without this every filter change would keep one alive
            old_model.deleteLater()
        # Make the table read-only
        for col in range(self._model.columnCount()):
            self._model.read_only_columns.add(col)
        self._ui.table_view.setColumnHidden(ObsTargetModel.COL_IDX_FEATURE_TYPE, True)
        self._ui.table_view.setColumnHidden(ObsTargetModel.COL_IDX_CATEGORY, True)
        self._ui.table_view.resizeColumnsToContents()

    def _get_selected_ids(self, feature_index):
        """Get the selected ids of one feature type or all of them.

        Args:
            feature_index (int): One of the ObsTargetData.OBS_CATEGORY_* enum values, or CBX_IDX_ALL_FEATURES to get
                all of the selected features

        Returns:
            dict: Key is the ObsTargetData.OBS_CATEGORY_* value of the feature type, value is a list of the selected
            features' component ids
        """
        if feature_index == CBX_IDX_ALL_FEATURES:  # Show all features
            feature_types = FEATURE_TYPE_TEXT
        else:
            feature_types = [FEATURE_IDX_TO_TEXT[feature_index]]
        feature_ids = {}
        # NOTE(review): for 'All', the key is the position in FEATURE_TYPE_TEXT, which is assumed to line up with the
        # OBS_CATEGORY_* values - confirm against generic_coverage_component.
        for idx, feature_type in enumerate(feature_types):
            feature_idx = idx if feature_index == CBX_IDX_ALL_FEATURES else feature_index
            feature_ids[feature_idx] = self._selected_ids.get(feature_type, [])
        return feature_ids

    def _setup_ui_table_view(self):
        """Sets up the table view."""
        style_table_view(self._ui.table_view)
        self._ui.table_view.horizontalHeader().setSectionResizeMode(QHeaderView.Interactive)
        self._ui.table_view.horizontalHeader().setStretchLastSection(True)

    def _setup_ui_bottom_button_box(self):
        """Connect the Ok and Help buttons at the bottom of the dialog."""
        self._ui.btn_box.accepted.connect(self.accept)
        self._ui.btn_box.helpRequested.connect(self.help_requested)

    def _add_navigation_toolbar(self):
        """Add the built-in matplotlib navigation toolbar but hide stuff we don't want."""
        toolbar = NavigationToolbar(self._canvas, self)
        # Remove the "Configure subplots" button. We aren't really sure what this does.
        toolbar.toolitems = [toolitem for toolitem in toolbar.toolitems if toolitem[0] != 'Subplots']
        # Relies on the toolbar's private `_actions` dict
        subplots_action = toolbar._actions.pop('configure_subplots')
        toolbar.removeAction(subplots_action)
        self._ui.grp_plot.layout().addWidget(toolbar)

    def help_requested(self):  # pragma: no cover
        """Called when the Help button is clicked."""
        webbrowser.open(self._help_url)

    def add_series(self):
        """Draw the plot type chosen by the radio buttons from the current model data."""
        if self._in_setup:
            return
        if self._ax is None:
            self._ax = self._figure.add_subplot(111)
        if self._ui.rbt_plot_computed_vs_observed.isChecked():
            self._add_computed_vs_observed()
        elif self._ui.rbt_plot_residual_vs_observed.isChecked():
            self._add_residual_vs_observed()
        else:
            self._add_time_series()
        self._canvas.draw()

    def setup_plot(self):
        """Sets up the plot."""
        self._figure = Figure()
        self._figure.set_tight_layout(True)  # Frames the plots
        self._canvas = FigureCanvas(self._figure)
        self._canvas.setMinimumWidth(100)  # So user can't resize it to nothing
        self._ui.hlay_grp_box.addWidget(self._canvas)
        self._add_navigation_toolbar()
        self.add_series()

    def _start_plot(self, title):
        """Clear the axes, then set the title and turn on the grid."""
        self._ax.clear()
        self._ax.set_title(title)
        self._ax.grid(True)

    def _finish_plot(self, x_label, y_label):
        """Add the legend (when "Show legend" is checked) and the axis titles to the current plot."""
        if self._ui.chk_show_legend.isChecked():
            legend = self._ax.legend(loc='best')
            legend.set_draggable(state=True)
        self._ax.set_xlabel(x_label)
        self._ax.set_ylabel(y_label)

    def _feature_frames(self):
        """Yield the model's data frame split into groups of rows sharing the same index value.

        Each group is plotted as its own series.
        """
        data_frame = self._model.data_frame
        for index in data_frame.index.unique().tolist():
            yield data_frame[data_frame.index == index]

    @staticmethod
    def _finite_range(*columns):
        """Return [min, max] of the finite values in the given columns, or None if there are none."""
        values = np.concatenate([np.asarray(column, dtype=float).ravel() for column in columns])
        values = values[np.isfinite(values)]
        if values.size == 0:
            return None
        return [values.min(), values.max()]

    def _add_computed_vs_observed(self):
        """Adds a computed vs. observed plot with a y = x reference line."""
        self._start_plot("Computed vs. Observed")
        if self._model.rowCount() == 0:
            return

        # Add data to plot
        for df in self._feature_frames():
            observed = df.iloc[:, ObsTargetModel.COL_IDX_OBSERVED].values
            computed = df.iloc[:, ObsTargetModel.COL_IDX_COMPUTED].values
            self._ax.plot(observed, computed, label=df.iloc[0, ObsTargetModel.COL_IDX_NAME], marker='o', linestyle='')

        # Add diagonal line spanning every finite observed/computed value (skipped if there are none)
        data_frame = self._model.data_frame
        bounds = self._finite_range(
            data_frame.iloc[:, ObsTargetModel.COL_IDX_OBSERVED].values,
            data_frame.iloc[:, ObsTargetModel.COL_IDX_COMPUTED].values,
        )
        if bounds is not None:
            self._ax.plot(bounds, bounds)

        self._finish_plot(
            ObsTargetModel.COL_TEXT[ObsTargetModel.COL_IDX_OBSERVED],
            ObsTargetModel.COL_TEXT[ObsTargetModel.COL_IDX_COMPUTED],
        )

    def _add_residual_vs_observed(self):
        """Adds a residual vs. observed plot with a zero-residual reference line."""
        self._start_plot("Residual vs. Observed")
        if self._model.rowCount() == 0:
            return

        # Add data to plot
        for df in self._feature_frames():
            observed = df.iloc[:, ObsTargetModel.COL_IDX_OBSERVED].values
            residual = df.iloc[:, ObsTargetModel.COL_IDX_RESIDUAL].values
            self._ax.plot(observed, residual, label=df.iloc[0, ObsTargetModel.COL_IDX_NAME], marker='o', linestyle='')

        # Add horizontal residual == 0 line across the finite observed values (skipped if there are none)
        bounds = self._finite_range(self._model.data_frame.iloc[:, ObsTargetModel.COL_IDX_OBSERVED].values)
        if bounds is not None:
            self._ax.plot(bounds, [0.0, 0.0])

        self._finish_plot(
            ObsTargetModel.COL_TEXT[ObsTargetModel.COL_IDX_OBSERVED],
            ObsTargetModel.COL_TEXT[ObsTargetModel.COL_IDX_RESIDUAL],
        )

    def _add_time_series(self):
        """Adds a time series plot of computed (and optionally observed) values."""
        self._start_plot("Time Series")
        if self._model.rowCount() == 0:
            return

        # Show computed
        for df in self._feature_frames():
            dates = mdates.date2num(df.iloc[:, ObsTargetModel.COL_IDX_TIME].values)
            computed = df.iloc[:, ObsTargetModel.COL_IDX_COMPUTED].values
            name = df.iloc[0, ObsTargetModel.COL_IDX_NAME]
            self._ax.plot(dates, computed, label=f'{name} computed', marker='o', linestyle='')

        # Show observed
        if self._ui.chk_show_observed.isChecked():
            for df in self._feature_frames():
                observed_df = df.dropna(subset=['observed'])
                if observed_df.empty:
                    continue  # No observed values in this group; iloc[0] below would raise IndexError
                dates = mdates.date2num(observed_df.iloc[:, ObsTargetModel.COL_IDX_TIME].values)
                observed = observed_df.iloc[:, ObsTargetModel.COL_IDX_OBSERVED].values
                name = observed_df.iloc[0, ObsTargetModel.COL_IDX_NAME]
                self._ax.plot(dates, observed, label=f'{name} observed', marker='o', linestyle='')

        # X-axis format
        # NOTE(review): the original author observed this format appearing to be overridden - confirm.
        formatter = mdates.DateFormatter(self._time_format)
        self._ax.xaxis.set_major_formatter(formatter)
        self._figure.autofmt_xdate()

        self._finish_plot(ObsTargetModel.COL_TEXT[ObsTargetModel.COL_IDX_TIME], 'Value')

    def accept(self):
        """Called when the OK button is clicked."""
        self._save_splitter_geometry()
        super().accept()

    def reject(self):
        """Called when the Cancel button is clicked (or the dialog is closed)."""
        self._save_splitter_geometry()
        super().reject()
