"""Dialog for viewing the results of a check levee tool."""

# 1. Standard Python modules
import logging
import pickle

# 2. Third party modules
from matplotlib.backends.backend_qt5agg import FigureCanvas
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure
from matplotlib.ticker import AutoLocator
from pandas.plotting import register_matplotlib_converters
from PySide2.QtCore import QSize, Qt
from PySide2.QtGui import QColor, QFont, QTextCursor
from PySide2.QtWidgets import (
    QAbstractItemView, QApplication, QDialogButtonBox, QHeaderView, QPushButton, QSplitter, QTextEdit
)

# 3. Aquaveo modules
from xms.core.filesystem import filesystem as xfs
from xms.data_objects.parameters import Coverage
from xms.guipy.dialogs.file_selector_dialogs import get_save_filename
from xms.guipy.dialogs.message_box import message_with_ok
from xms.guipy.dialogs.xms_parent_dlg import XmsDlg
from xms.guipy.models.qx_pandas_table_model import QxPandasTableModel
from xms.guipy.settings import SettingsManager
from xms.guipy.widgets import widget_builder

# 4. Local modules
from xms.adcirc.gui.check_levee_tool_results_dialog_ui import Ui_CheckLeveeToolResultsDialog
from xms.adcirc.gui.line_map_viewer_dialog import LineMapViewerDialog
from xms.adcirc.mapping.mapping_util import check_levee_bc_file, check_levee_results_df_file
from xms.adcirc.tools import levee_check_tool_consts as const


def load_tool_output():
    """Load the levee check results written by an ADCIRC check levee tool.

    Reads a pickled dict (keys ``'df'`` and ``'global_log'``) from the file returned by
    ``check_levee_results_df_file()`` and deletes that file once it has been read. Then loads the BC coverage
    geometry dump from ``check_levee_bc_file()`` (the Coverage object cleans that file up itself).

    Loading is best-effort: any failure is logged and whatever was loaded so far is returned, so callers must check
    the DataFrame and coverage for None.

    Returns:
        (:obj:`tuple(pandas.DataFrame,Coverage,str)`): The results DataFrame (None on error), the BC coverage
        geometry dump (None on error), and the formatted log output of global warnings and errors.
    """
    df = None
    cov = None
    global_log = ''
    try:
        # Read the pickled plot DataFrame and global log output.
        # NOTE: pickle is only acceptable here because the file is written locally by the check levee tool; never
        # point this at untrusted data.
        filename = check_levee_results_df_file()
        with open(filename, 'rb') as f:
            data = pickle.load(f)
        xfs.removefile(filename)
        df = data['df']
        global_log = data['global_log']
        # Read the coverage geometry dump
        cov = Coverage(check_levee_bc_file())  # Coverage will cleanup the file itself
    except Exception:
        # Deliberately best-effort (the results dialog reports a failed load when df/cov is None), but record the
        # cause instead of silently discarding it. NOTE(review): assumes this logger is routed to the XMS log
        # window that the dialog's error message points users to - confirm.
        logging.getLogger(__name__).exception('Unable to load check levee tool output.')
    return df, cov, global_log


def format_log_output(log_output, text_edit):
    """Append log output to a text edit, colouring and bolding each message according to its leading markers.

    Mirrors the formatters defined in ProcessFeedbackDlg. The output is split on ``const.LOG_MESSAGE_DELIM``.
    Each message may start with a level marker (warning = orange, error = red, success = green; no marker =
    black), optionally followed by a bold marker. The markers are stripped from the displayed text. After
    appending, the text cursor is moved to the start of the document.

    Args:
        log_output (:obj:`str`): The logging output, messages separated by ``const.LOG_MESSAGE_DELIM``
        text_edit (:obj:`QTextEdit`): The Qt text edit widget to append the text to.
    """
    # Level marker -> text colour, checked in this order. Messages with no level marker are drawn in black.
    level_colors = (
        (const.LOG_LEVEL_WARNING, QColor(255, 127, 39)),  # Orange if warning during check
        (const.LOG_LEVEL_ERROR, QColor(255, 0, 0)),  # Red if error during check
        (const.LOG_LEVEL_SUCCESS, QColor(0, 200, 0)),  # Green if checks passed
    )
    for message in log_output.split(const.LOG_MESSAGE_DELIM):
        color = QColor(0, 0, 0)
        for marker, marker_color in level_colors:
            if message.startswith(marker):
                # Strip only the leading marker; the same text appearing later in the message is real content.
                message = message[len(marker):]
                color = marker_color
                break
        text_edit.setTextColor(color)
        # To change both the colour and weight of the text, the level marker must come before the bold marker.
        if message.startswith(const.LOG_MESSAGE_BOLD):
            message = message[len(const.LOG_MESSAGE_BOLD):]
            text_edit.setFontWeight(QFont.Bold)
        else:  # Restore normal font weight if bold marker not present
            text_edit.setFontWeight(QFont.Normal)
        text_edit.append(message)
    # Move cursor to the beginning of the text
    cursor = text_edit.textCursor()
    cursor.movePosition(QTextCursor.Start, QTextCursor.MoveAnchor, 1)
    text_edit.setTextCursor(cursor)


class QxPandasTableModelCheapHeaders(QxPandasTableModel):
    """QxPandasTableModel whose display-role header labels come from a cached list, avoiding expensive headerData.

    Horizontal headers are the DataFrame's column names captured at construction. Vertical headers are 1-based
    row numbers.
    """
    def __init__(self, data_frame, parent):
        """Initializes the class.

        Args:
            data_frame (:obj:`pandas.DataFrame`): The pandas DataFrame
            parent (:obj:`QWidget`): The Qt parent
        """
        super().__init__(data_frame, parent)
        # Snapshot the column labels once so headerData never has to query the DataFrame.
        self._column_names = data_frame.columns.values.tolist()

    def headerData(self, section, orientation, role=Qt.DisplayRole):  # noqa: N802
        """Returns the data for the given role and section in the header.

        Args:
            section (:obj:`int`): The section.
            orientation (:obj:`Qt.Orientation`): The orientation.
            role (:obj:`int`): The role.

        Returns:
            The cached column name (horizontal) or 1-based row number (vertical) for the display role, otherwise
            whatever the base model returns.
        """
        if role != Qt.DisplayRole:
            return super().headerData(section, orientation, role)  # Only the display role is short-circuited
        if orientation == Qt.Horizontal:
            return self._column_names[section]
        # Vertical header: replace the default 0-based sequential pandas Index with a 1-based row number.
        return section + 1


class CheckLeveeToolResultsDialog(XmsDlg):
    """Dialog for viewing the results of a check levee tool.

    Shows a filterable table of the checked levee pairs, a plot of the selected levee's result curves, the
    selected levee's log output, the global log output, and a companion map viewer dialog.
    """
    def __init__(self, parent):
        """Constructor.

        If the tool output cannot be loaded, the UI is not built and ``self.model`` stays None. ``exec()`` then
        reports the failure instead of showing the dialog.

        Args:
            parent (:obj:`QObject`): The parent object
        """
        super().__init__(parent, 'xms.adcirc.gui.check_levee_tool_results_dialog')
        self.ui = None
        self.splitter = None
        self.model = None
        self.figure = None  # matplotlib.figure Figure
        self.canvas = None  # matplotlib.backends.backend_qt5agg FigureCanvas
        self.ax = None  # matplotlib Axes
        self.last_color = None  # Colour of the last plotted non-'Original' series, reused for the next 'Original' one
        self.tool_type = None  # Read from DataFrame
        self.series_name = None  # Dependent on tool_type
        self.y_label = 'Elevation'  # Will append units based on current display projection vertical units
        self.global_log = ''
        self._filter_widgets = None  # {status filter checkbox: check status enum value}
        self.in_setup = False  # While True, slots ignore signals emitted during UI construction
        # Stuff for the map viewer
        self._lines = []
        self._map_viewer = None

        # NOTE(review): this was previously attributed to an mdates.date2num call that does not exist in this file;
        # confirm whether the pandas converters are still needed before removing.
        register_matplotlib_converters()
        df, do_cov, self.global_log = load_tool_output()
        if df is None or do_cov is None:
            # Leave the UI unbuilt and self.model as None so exec() can report the failure. Touching self.ui here
            # would raise AttributeError because it is still None.
            return
        self._init_from_tool_output(df, do_cov)
        self.in_setup = True
        self.ui = Ui_CheckLeveeToolResultsDialog()
        self._setup_ui(df)
        # Default filter: show adjusted levees (with and without issues). _setup_ui has already cleared in_setup,
        # so any resulting check state change runs filter_rows.
        self.ui.tog_adjusted.setChecked(True)
        self.ui.tog_adjusted_with_issue.setChecked(True)

    def _init_from_tool_output(self, df, do_cov):
        """Initialize from DataFrame written by tool once we know we have read something valid.

        Args:
            df (:obj:`pandas.DataFrame`): The tool output DataFrame
            do_cov (:obj:`Coverage`): The data_objects BC coverage geometry
        """
        self.tool_type = df.attrs['tool_type']
        self.y_label += ' (m)' if 'METER' in df.attrs['vert_units'] else ' (ft)'
        self._map_viewer = LineMapViewerDialog(parent=self, do_cov=do_cov, results_df=df)
        self._map_viewer.show()

    # ------------------------------------------------------------------------------------------------------------
    # UI setup
    # ------------------------------------------------------------------------------------------------------------

    def _setup_ui(self, df):
        """Sets up the UI and selects the first levee, if there is one.

        Args:
            df (:obj:`pandas.DataFrame`): Pandas DataFrame
        """
        self.ui.setupUi(self)
        self._filter_widgets = {  # Create lookup map for filters
            self.ui.tog_adjusted: const.CHECK_STATUS_ADJUSTED,
            self.ui.tog_adjusted_with_issue: const.CHECK_STATUS_ADJUSTED_ISSUE,
            self.ui.tog_unadjusted: const.CHECK_STATUS_UNADJUSTED,
            self.ui.tog_unadjusted_with_issue: const.CHECK_STATUS_UNADJUSTED_ISSUE,
        }
        self._setup_model(df)
        self._setup_table()
        self._setup_plot()
        self._add_splitter()
        self._add_buttons()
        self._connect_slots()
        format_log_output(self.global_log, self.ui.txt_global_log)  # Load and format global warning/errors
        self.in_setup = False
        if self.model.rowCount() > 0:  # Select the first levee (runs on_levee_changed now that setup is done)
            self.ui.tbl_xy_data.selectRow(0)

    # ------------------------------------------------------------------------------------------------------------
    # Data model/view setup
    # ------------------------------------------------------------------------------------------------------------

    def _setup_model(self, df):
        """Sets up the model.

        Args:
            df (:obj:`pandas.DataFrame`): The results DataFrame written by the check levee tool
        """
        self.model = QxPandasTableModelCheapHeaders(df, self)
        self.model.set_read_only_columns(set(range(const.RESULTS_START_HIDE_COL_IDX)))
        self.model.set_show_nan_as_blank(True)
        # Since our table is read-only setting the check status column as a combobox gives us a label that maps from
        # the integer enum value we store.
        self.model.set_combobox_column(const.RESULTS_STATUS_COL_IDX, list(const.CHECK_STATUS_TEXT.values()))

    def _setup_table(self):
        """Sets up the results table: hidden data columns, interactive column sizing, single-row selection."""
        self.ui.tbl_xy_data.size_to_contents = True
        widget_builder.style_table_view(self.ui.tbl_xy_data)
        self.ui.tbl_xy_data.setModel(self.model)
        # Hide all the curve columns as well as the logging output column and status string column
        for i in range(const.RESULTS_START_HIDE_COL_IDX, self.model.columnCount()):
            self.ui.tbl_xy_data.setColumnHidden(i, True)
        # Resize columns and set to interactive resize mode
        horizontal_header = self.ui.tbl_xy_data.horizontalHeader()
        horizontal_header.setSectionResizeMode(QHeaderView.Interactive)
        self.ui.tbl_xy_data.resizeColumnsToContents()
        horizontal_header.setStretchLastSection(True)
        # Set selection behavior to single rows
        self.ui.tbl_xy_data.setSelectionMode(QAbstractItemView.SingleSelection)
        self.ui.tbl_xy_data.setSelectionBehavior(QAbstractItemView.SelectRows)

    # ------------------------------------------------------------------------------------------------------------
    # Plot setup
    # ------------------------------------------------------------------------------------------------------------

    def _setup_plot(self):
        """Sets up the plot."""
        self.series_name = const.TOOL_SERIES_NAMES[self.tool_type]
        self.figure = Figure()
        self.figure.set_tight_layout(True)  # Frames the plots
        self.canvas = FigureCanvas(self.figure)
        self.canvas.setMinimumWidth(100)  # So user can't resize it to nothing
        self.ui.vlay_plots.addWidget(self.canvas)
        # The matplotlib navigation toolbar is currently disabled; call self._add_navigation_toolbar() to enable it.
        self._add_series()

    def _add_navigation_toolbar(self):
        """Add the built-in matplotlib navigation toolbar but hide stuff we don't want."""
        toolbar = NavigationToolbar(self.canvas, self)
        # Remove the "Configure subplots" button. We aren't really sure what this does.
        toolbar.toolitems = [toolitem for toolitem in toolbar.toolitems if toolitem[0] != 'Subplots']
        # NOTE: relies on the private NavigationToolbar2QT._actions dict; may break across matplotlib versions.
        subplots_action = toolbar._actions.pop('configure_subplots')
        toolbar.removeAction(subplots_action)
        self.ui.vlay_plots.addWidget(toolbar)

    def _add_series(self):
        """Redraw the plot with the result curves of the currently selected levee (results table row)."""
        # Clear/initialize plot
        if self.ax is None:
            self.ax = self.figure.add_subplot(111)
        self.ax.clear()
        self.ax.set_title(self.series_name)
        self.ax.grid(True)
        # Find the selected levee (row in results table) to plot.
        selection = self.ui.tbl_xy_data.selectionModel().selectedIndexes()
        row = selection[0].row() if selection else -1
        # The first curve column holds the x values; every later column is a y series plotted against it.
        if row > -1 and self.model.rowCount() > 0:
            x_column = self.model.data_frame.iloc[row, const.RESULTS_START_CURVE_COL_IDX]
            if len(x_column) > 0:  # Only plot curves if the levee has data for them
                for i in range(const.RESULTS_START_CURVE_COL_IDX + 1, self.model.columnCount()):
                    self._add_line_series(row, i, x_column)
                self.ax.legend(loc='best')
        self.canvas.draw()

    def _add_line_series(self, row, data_column, x_column):
        """Adds an XY line series to the plot.

        Args:
            row (:obj:`int`): 0-based index of the selected row (levee) to plot
            data_column (:obj:`int`): Index of the column to plot against the X column
            x_column (:obj:`numpy.ndarray`): The x-axis values
        """
        # Extract data for the series
        y_column = self.model.data_frame.iloc[row, data_column]
        series_label = self.model.data_frame.columns[data_column]
        # Plot the data
        line = self.ax.plot(x_column, y_column, label=series_label)
        # If we are drawing an original line vs. an adjusted, draw the original in the same color and dashed. Adjusted
        # lines should appear in the DataFrame in the column immediately before the original column and the original
        # column name should contain the text 'Original'
        if 'Original' in series_label:
            line[0].set_color(self.last_color)
            line[0].set_linestyle('--')
        else:
            self.last_color = line[0].get_color()
            if series_label == const.RESULTS_CHECK_ELEVATION_COL_NAME:
                # If this is a check elevation line, draw it thicker because it will probably overlap
                line[0].set_linewidth(5.0)
        self.ax.xaxis.set_major_locator(AutoLocator())
        # Set axis titles
        x_label = self.model.data_frame.columns[const.RESULTS_START_CURVE_COL_IDX]
        self.ax.set_xlabel(x_label)
        self.ax.set_ylabel(self.y_label)

    # ------------------------------------------------------------------------------------------------------------
    # Widget setup
    # ------------------------------------------------------------------------------------------------------------

    def _add_splitter(self):
        """Adds a QSplitter between the left widget and the results group so their sizes can be adjusted."""
        # The only way this seems to work right is to parent it to
        # self and then insert it into the layout.
        self.splitter = QSplitter(self)
        self.splitter.setOrientation(Qt.Horizontal)
        self.splitter.addWidget(self.ui.wid_left)
        self.splitter.addWidget(self.ui.grp_results)
        # Just use a fixed starting width of 300 for the table for now
        self.splitter.setSizes([300, 600])
        self.splitter.setChildrenCollapsible(False)
        self.splitter.setStyleSheet(
            'QSplitter::handle:horizontal { background-color: lightgrey; }'
            'QSplitter::handle:vertical { background-color: lightgrey; }'
        )
        pos = self.ui.hlay_main.indexOf(self.ui.grp_results)
        self.ui.hlay_main.insertWidget(pos, self.splitter)

    def _connect_slots(self):
        """Connect Qt widget signal/slots."""
        self.ui.tbl_xy_data.selectionModel().selectionChanged.connect(self.on_levee_changed)
        for filter_widget in self._filter_widgets:  # Every status filter toggle re-filters the table
            filter_widget.stateChanged.connect(self.filter_rows)

    def _add_buttons(self):
        """Adds the 'Map Viewer', 'Export Log...', and 'Ok' buttons to the standard QDialogButtonBox.

        Notes:
            'Map Viewer' is made the default button as a workaround: Qt gives focus to the 'Ok' button by default,
            so pressing enter in the spreadsheet would otherwise often inadvertently accept the dialog.
        """
        map_viewer_button = QPushButton('Map Viewer')
        map_viewer_button.clicked.connect(self._map_viewer.show)
        map_viewer_button.setDefault(True)
        export_log_button = QPushButton('Export Log...')
        export_log_button.clicked.connect(self.export_log)
        self.ui.button_box.addButton(map_viewer_button, QDialogButtonBox.ActionRole)
        self.ui.button_box.addButton(export_log_button, QDialogButtonBox.ActionRole)
        self.ui.button_box.addButton(QDialogButtonBox.StandardButton.Ok)

    # ------------------------------------------------------------------------------------------------------------
    # Levee row filtering methods
    # ------------------------------------------------------------------------------------------------------------

    def _show_statuses(self):
        """Get the levee check statuses whose filter checkboxes are checked.

        Returns:
            (:obj:`list[int]`): The levee check status enum values that are not hidden.

            If no filters being applied, returns an empty list.
        """
        return [status_enum for widget, status_enum in self._filter_widgets.items() if widget.isChecked()]

    def _levee_arc_ids(self, row):
        """Get the feature arc IDs of the levee pair in a results table row.

        Args:
            row (:obj:`int`): 0-based row index in the results DataFrame

        Returns:
            (:obj:`list`): The 'Arc 1 ID' and 'Arc 2 ID' values of the row
        """
        df = self.model.data_frame
        return [df['Arc 1 ID'].iat[row], df['Arc 2 ID'].iat[row]]

    def _show_all_levees(self):
        """Show all rows in the levee results table (all levee pairs that were checked) and select the first one."""
        num_rows = self.model.rowCount()
        for i in range(num_rows):
            self.ui.tbl_xy_data.setRowHidden(i, False)
        self._map_viewer.hidden_rows = set()  # Let the map viewer know we are not filtering any levees.
        if num_rows:  # Select the first levee pair (row); redraws the map viewer only if the selection changes.
            self.ui.tbl_xy_data.selectRow(0)

    def _filter_levees_by_check_status(self, filter_mask):
        """Show/hide rows based on the check status of the levee pair.

        Args:
            filter_mask (:obj:`pandas.Series`): Boolean mask where visible rows are flagged as False, and hidden are
                flagged as True.
        """
        # Get the row indices in a parallel array with the filter mask flags and convert to Python lists for iteration.
        row_indices = filter_mask.index.values.tolist()
        filtered_rows = filter_mask.values.tolist()
        first_visible = None  # Keep track of the first visible row so we can select it
        hidden_rows = set()  # Keep track of hidden row indices so we can pass them to the map viewer
        # Show/hide rows in the results table based on the filter mask flags.
        for row_index, is_filtered in zip(row_indices, filtered_rows):
            self.ui.tbl_xy_data.setRowHidden(row_index, is_filtered)
            if is_filtered:  # If row is filtered, don't draw it in the map viewer
                hidden_rows.add(row_index)
            elif first_visible is None:  # If this is the first visible row, mark it as the one to select
                first_visible = row_index
        # Tell the map viewer which rows are currently hidden and update the selection in the results table. A
        # selection change triggers a redraw of the map viewer plot.
        self._map_viewer.hidden_rows = hidden_rows
        if first_visible is not None:  # If there is at least one visible row, select the first one in the table
            self.ui.tbl_xy_data.selectRow(first_visible)
        else:  # If there are no visible rows, clear the table selection
            self.ui.tbl_xy_data.clearSelection()

    # ------------------------------------------------------------------------------------------------------------
    # Data I/O
    # ------------------------------------------------------------------------------------------------------------

    def _restore_splitter_geometry(self):
        """Restore the position of the splitter, if it exists and a saved position is available."""
        if self.splitter is None:
            return  # UI was never built (tool output failed to load)
        settings = SettingsManager()
        splitter = settings.get_setting('xmsadcirc', f'{self._dlg_name}.splitter')
        if not splitter:
            return
        splitter_sizes = [int(size) for size in splitter]
        self.splitter.setSizes(splitter_sizes)

    def _save_splitter_geometry(self):
        """Save the current position of the splitter, if it exists."""
        if self.splitter is None:
            return  # UI was never built (tool output failed to load)
        settings = SettingsManager()
        settings.save_setting('xmsadcirc', f'{self._dlg_name}.splitter', self.splitter.sizes())

    def _write_log_file(self, filename):
        """Write the global log output followed by each levee's log output to a text file.

        Args:
            filename (:obj:`str`): Path of the file to write
        """
        # UTF-8 so non-ASCII log text cannot fail on platforms whose default encoding is narrower.
        with open(filename, 'w', encoding='utf-8') as f:
            # Export the global warnings/errors
            f.write('#######################################\nGLOBAL OUTPUT\n#######################################\n')
            f.write(self.ui.txt_global_log.toPlainText())
            # Export levee-specific output
            f.write('\n#######################################\nLEVEE OUTPUT\n#######################################')
            text_edit = QTextEdit()  # Temporary widget; format_log_output strips the level/bold markers into it
            for idx, row in self.model.data_frame.iterrows():
                f.write(
                    f'\n\n***************************************\nLevee row: {idx + 1}\nFeature arc IDs: '
                    f'{row["Arc 1 ID"]}, {row["Arc 2 ID"]}\n***************************************\n'
                )
                text_edit.clear()
                format_log_output(row['log_output'], text_edit)
                f.write(text_edit.toPlainText())

    # ------------------------------------------------------------------------------------------------------------
    # Slots
    # ------------------------------------------------------------------------------------------------------------

    def on_levee_changed(self, selected, deselected):
        """Update the plot, log output, status labels, and map viewer when the user changes the selected row.

        Args:
            selected (:obj:`QItemSelection`): The newly selected range. Should only ever be size 0 or 1
            deselected (:obj:`QItemSelection`): The previously selected range. Should only ever be size 0 or 1
        """
        if self.in_setup:
            return  # Avoid recursion or excessive, unnecessary calls
        QApplication.setOverrideCursor(Qt.WaitCursor)
        try:  # Always restore the cursor, even if updating a widget fails
            if not selected or not selected[0].isValid():
                # Clear the curve plots and log output tab and frame map viewer to global extents if selection is cleared.
                self.ui.txt_log.clear()
                self.ax.clear()
                self.canvas.draw()
                self._map_viewer.draw_feature_lines([])
                self.ui.lbl_status.setText('')
                self.ui.lbl_issues.setText('')
                return
            selected_row = selected[0].indexes()[0].row()
            previous_row = deselected[0].indexes()[0].row() if deselected else -1
            if selected_row == previous_row:
                return  # Same levee; leave the current plot/log/labels in place
            self._add_series()  # Update the levee results curves being plotted
            self._map_viewer.draw_feature_lines(self._levee_arc_ids(selected_row))  # Zoom to the selected levee
            df = self.model.data_frame
            self.ui.txt_log.clear()
            format_log_output(df['log_output'].iat[selected_row], self.ui.txt_log)
            # The status string holds the status text and the issues text separated by LOG_MESSAGE_DELIM; tolerate
            # a missing issues part rather than raising IndexError.
            status_parts = df['status_string'].iat[selected_row].split(const.LOG_MESSAGE_DELIM)
            self.ui.lbl_status.setText(status_parts[0])
            self.ui.lbl_issues.setText(status_parts[1] if len(status_parts) > 1 else '')
        finally:
            QApplication.restoreOverrideCursor()

    def filter_rows(self, _):
        """Show/hide rows in the levee results table based on check status.

        If no status filter is checked, all rows are shown. The map viewer is redrawn either through the selection
        change or, when the selection does not change, explicitly here.

        Args:
            '_' (:obj:`Qt.CheckState`): Unused. Just need signature to match for Qt signal/slot
        """
        if self.in_setup:
            return  # Avoid recursion or excessive, unnecessary calls

        QApplication.setOverrideCursor(Qt.WaitCursor)
        try:  # Always restore the cursor, even if filtering fails
            selection_model = self.ui.tbl_xy_data.selectionModel()
            previous_selection = selection_model.selectedRows()
            show_statuses = self._show_statuses()
            if not show_statuses:  # All filter checkboxes unchecked: show every row and draw every arc
                self._show_all_levees()
            else:  # One or more filters are being applied, only show rows (levees) with checked status types
                filter_mask = ~self.model.data_frame.Status.isin(show_statuses)  # Invert the mask so hidden is True
                self._filter_levees_by_check_status(filter_mask)
            # Force a redraw of the map view if the selection didn't change. In either branch we may have changed the
            # visible arcs without emitting the selection changed signal.
            current_selection = selection_model.selectedRows()
            if previous_selection == current_selection:
                selected_levee = self._levee_arc_ids(current_selection[0].row()) if current_selection else []
                self._map_viewer.draw_feature_lines(selected_levee)
        finally:
            QApplication.restoreOverrideCursor()

    def export_log(self):
        """Export the global log and every levee's log output to a user-selected text file."""
        selected_filter = 'Text file (*.txt)'
        filename = get_save_filename(self, selected_filter, f'{selected_filter};;All types (*.*)')
        if not filename:
            return  # User cancelled
        error = None
        QApplication.setOverrideCursor(Qt.WaitCursor)
        try:
            self._write_log_file(filename)
        except OSError as e:  # e.g. permission denied or invalid path
            error = e
        finally:
            QApplication.restoreOverrideCursor()
        if error is not None:
            message_with_ok(
                parent=self, message=f'Unable to export log: {error}', app_name='SMS', icon='Critical',
                win_icon=self.windowIcon()
            )

    # ------------------------------------------------------------------------------------------------------------
    # Qt overloads
    # ------------------------------------------------------------------------------------------------------------

    def sizeHint(self):  # noqa: N802
        """Overridden method to help the dialog have a good minimum size.

        Returns:
            (:obj:`QSize`): Size to use for the initial dialog size.
        """
        return QSize(800, 500)

    def showEvent(self, event):  # noqa: N802
        """Restore the last splitter position when showing the dialog.

        Args:
            event (:obj:`QShowEvent`): The Qt show event
        """
        super().showEvent(event)
        self._restore_splitter_geometry()

    def accept(self):
        """Close the map viewer and save the splitter position before accepting the dialog."""
        if self._map_viewer is not None:
            self._map_viewer.accept()
        self._save_splitter_geometry()
        super().accept()

    def reject(self):
        """Close the map viewer and save the splitter position before closing the dialog."""
        if self._map_viewer is not None:
            self._map_viewer.accept()
        self._save_splitter_geometry()
        super().reject()

    def exec(self):
        """Show the dialog modally, unless the tool output failed to load or contains no levees.

        Returns:
            False if the tool output could not be loaded (after showing an error message), True if the results
            DataFrame is empty (dialog not shown), otherwise the result of ``exec_()``.
        """
        if self.model is None:
            msg = 'Unable to load results from tool output. Check log window for errors.'
            parent = self.parent()
            # Fall back to this dialog's own icon when there is no parent to borrow one from.
            win_icon = parent.windowIcon() if parent is not None else self.windowIcon()
            message_with_ok(parent=parent, message=msg, app_name='SMS', icon='Critical', win_icon=win_icon)
            return False
        if self.model.data_frame.empty:
            return True
        return super().exec_()
