"""Dialog for viewing GenCade solution files."""

# 1. Standard Python modules
import datetime
import math
import os

# 2. Third party modules
import folium
from matplotlib.backends.backend_qt5agg import FigureCanvasQT, NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure
from PySide2.QtCore import QSize, QUrl
from PySide2.QtWebEngineWidgets import QWebEngineView
from PySide2.QtWidgets import (QAbstractItemView, QTableWidgetItem)
from shapely.geometry import LineString, Point as shPt
from shapely.ops import nearest_points

# 3. Aquaveo modules
from xms.api.dmi import Query, XmsEnvironment as XmEnv
from xms.api.tree import tree_util
try:
    from xms.data_objects.parameters import Arc, Coverage, FilterLocation, Point, Polygon
except ImportError:  # pragma no cover - optional import
    Arc = object
    Coverage = object
    FilterLocation = object
    Point = object
    Polygon = object
from xms.gdal.utilities import gdal_utils as gu
from xms.guipy.dialogs.xms_parent_dlg import XmsDlg
from xms.guipy.validators.qx_double_validator import QxDoubleValidator

# 4. Local modules
from xms.gencade.components.mapped_grid_component import MappedGridComponent
from xms.gencade.gui.view_solution_dialog_ui import Ui_ViewSolutionDialog

# Column indices in the primary table
(
    COL_TYPE,
    COL_ID,
    COL_LENGTH,
    COL_NUM_SEGMENTS,
    COL_NUM_ARCS,
    COL_AREA,
    COL_PERIMETER,
    COL_GEOM_LIST_IDX,
) = range(8)


class ViewSolutionDialog(XmsDlg):
    """Dialog for viewing GenCade solution files."""

    def __init__(self, sim_data, parent=None):
        """Collect the XMS context and grid geometry, then build the widgets.

        Args:
            sim_data (:obj:`SimData`): data for the dialog
            parent (:obj:`QWidget`): Parent window
        """
        super().__init__(parent, 'xms.gencade.gui.view_solution_dialog')
        self._help_url = 'https://cirpwiki.info/wiki/GenCade_2.0_View_Solution_Dialog'

        # Matplotlib figure, canvas, axes and navigation toolbar; created in _setup_plot()
        self.figure = self.canvas = self.ax = self.toolbar = None

        self.sim_data = sim_data
        self._solutions = {}  # dataset name -> {time step key: list of values}
        self._scale_factors = {}  # dataset name -> 'q' or 'y'
        self._proj_path = ''
        self.setWindowTitle('GenCade View Solution')

        # Display projection of the project and a transformation from it to WGS 84 (EPSG:4326)
        self._query = Query()
        display_proj = self._query.display_projection
        self._wkt = display_proj.well_known_text
        self._coord_sys = display_proj.coordinate_system.upper()
        self._horiz_units = display_proj.horizontal_units.upper()
        self._vert_units = display_proj.vertical_units.upper()
        wgs84_wkt = gu.wkt_from_epsg(4326)
        self._transform = gu.get_coordinate_transformation(self._wkt, wgs84_wkt)

        # Locate the simulation in the project tree and load its mapped grid
        parent_uuid = self._query.parent_item_uuid()
        self._sim_item = tree_util.find_tree_node_by_uuid(self._query.project_tree, parent_uuid)
        grid_node = tree_util.descendants_of_type(self._sim_item, only_first=True,
                                                  unique_name='MappedGridComponent')
        self._mapped_grid = MappedGridComponent(grid_node.main_file)
        self._grid_lines = []
        self._get_grid_drawing_pts()

        # Initial shoreline values; filled by _read_initial_shoreline()
        self._init_shore_vals = []

        # HTML file the map is saved to before being shown in the web view
        html_name = f'view_solution_{os.getpid()}.html'
        self._url_file = os.path.join(XmEnv.xms_environ_process_temp_directory(), html_name)

        self.ui = Ui_ViewSolutionDialog()
        self._setup_ui()

    def _reproject_points(self, pts):
        """Reproject the line locations from a projected system to geographic.

        Args:
            pts (:obj:`list`): List of the line locations [[(x,y),...],...]
        """
        transform_pts = []
        for point in pts:
            coord = self._transform.TransformPoint(point[0], point[1])
            transform_pts.append((coord[0], coord[1]))
        return transform_pts

    """
    UI setup
    """
    def _setup_ui(self):
        """Create the widgets, load the solution files and show the first dataset."""
        ui = self.ui
        ui.setupUi(self)

        # Both scale factor edits share one validator (lower bound 0.0) and start at 1.0
        scale_validator = QxDoubleValidator(parent=self, bottom=0.0)
        for scale_edit in (ui.edt_y_scale, ui.edt_q_scale):
            scale_edit.setValidator(scale_validator)
            scale_edit.setText('1.0')

        # Background map choices; 'ESRI Street Map' is selected initially
        ui.cbx_map.addItems(['None', 'ESRI Street Map', 'ESRI World Imagery'])
        ui.cbx_map.setCurrentIndex(1)

        # Dataset list and its contents
        self._setup_list()
        self._get_solution_files()

        # Map viewer and plot must exist before the first refresh below
        self._setup_viewer()
        self._setup_plot()

        # Read-only table with two columns
        table = ui.tbl_dataset
        table.setColumnCount(2)
        table.setHorizontalHeaderLabels(['Distance Along Grid', 'Solution Value'])
        table.resizeColumnsToContents()
        table.setEditTriggers(QAbstractItemView.NoEditTriggers)

        # Show the initially selected dataset, then start listening to the widgets
        self._on_change_sel_dataset()
        self._connect_slots()

    """
    Data model/view setup
    """
    def _setup_viewer(self):
        """Create the web view that displays the map and add it to the map layout."""
        viewer = QWebEngineView(self)
        self._viewer = viewer
        self.ui.vlay_map.addWidget(viewer)

    def _setup_plot(self):
        """Create the matplotlib figure, canvas, toolbar and axes and add them to the plot layout."""
        self.figure = fig = Figure()
        fig.set_layout_engine(layout='tight')
        self.canvas = canvas = FigureCanvasQT(fig)
        canvas.setMinimumWidth(300)  # So user can't resize it to nothing

        self.toolbar = toolbar = NavigationToolbar(canvas, self)
        # Drop the toolbar action labelled 'Subplots' (the "Configure subplots" button)
        for action in toolbar.actions():
            if action.text() == 'Subplots':
                toolbar.removeAction(action)

        plot_layout = self.ui.vlay_plot
        plot_layout.addWidget(toolbar)
        plot_layout.addWidget(canvas)
        self.ax = fig.add_subplot(111)

    def _setup_list(self):
        """Configure the selection behavior and palette of the dataset list."""
        datasets = self.ui.list_datasets

        # One whole row at a time
        datasets.setSelectionMode(QAbstractItemView.SingleSelection)
        datasets.setSelectionBehavior(QAbstractItemView.SelectRows)

        # Use the palette's highlight brush for the inactive colour group as well
        palette = datasets.palette()
        highlight_brush = palette.brush(palette.Highlight)
        palette.setBrush(palette.Inactive, palette.Highlight, highlight_brush)
        datasets.setPalette(palette)

    def _get_solution_files(self):
        """Fill the table with the solution filenames."""
        self._proj_path = os.path.splitext(self._query.xms_project_path)[0] + '_models\\GenCade\\' + self._sim_item.name
        proj_name = os.path.splitext(os.path.basename(self._query.xms_project_path))[0]
        file_list = os.listdir(self._proj_path)

        mql = proj_name + '.mql'
        if mql in file_list:
            self._read_dataset_file(mql, 'Mean Transport Left', 'Average Mean Transport Left', 'q', True)

        mqn = proj_name + '.mqn'
        if mqn in file_list:
            self._read_dataset_file(mqn, 'Mean Net Transport', 'Average Mean Transport', 'q', True)

        mqr = proj_name + '.mqr'
        if mqr in file_list:
            self._read_dataset_file(mqr, 'Mean Transport Right', 'Average Mean Transport Right', 'q', True)

        qtr = proj_name + '.qtr'
        if qtr in file_list:
            self._read_dataset_file(qtr, 'Transport Rate', '', 'y', False)

        off = proj_name + '.off'
        if off in file_list:
            self._read_dataset_file(off, 'Calculated Offshore Contour', '', 'y', False)

        shi = proj_name + '.shi'
        if shi in file_list:
            self._read_initial_shoreline(shi)

        slo = proj_name + '.slo'
        if slo in file_list:
            self._read_slo(slo)

        self.ui.list_datasets.addItems(self._solutions.keys())
        self.ui.list_datasets.setCurrentRow(0)
        return

    def _read_dataset_file(self, filename, dset_name, ave_dset_name, scale_factor, has_average):
        """Read the mql solution file.

        Args:
            filename (str): Filename to read in
            dset_name (str): Name of the dataset
            ave_dset_name (str): Name of the average dataset
            scale_factor (str): Scale factor use (q or y)
            has_average (bool): Has an average dataset at the end
        """
        data, ave_data, _ = self._read_file(filename, has_average)
        if len(data) > 0:
            self._solutions[dset_name] = data
            self._scale_factors[dset_name] = scale_factor
        if len(ave_data) > 0:
            self._solutions[ave_dset_name] = ave_data
            self._scale_factors[ave_dset_name] = scale_factor

    def _read_slo(self, filename):
        """Read the slo solution file.

        Args:
            filename (str): Filename to read in
        """
        data, _, dates = self._read_file(filename, False)
        if len(data.keys()) == 0:
            return

        self._solutions['Shoreline'] = data
        self._scale_factors['Shoreline'] = 'y'

        if len(self._init_shore_vals) == 0:
            return

        change_data = {}
        rate_data = {}
        start_date_str = self.sim_data.model.attrs['start_date']
        start = datetime.date(int(start_date_str[:4]), int(start_date_str[5:7]), int(start_date_str[8:10]))
        for i, ts in enumerate(data.keys()):
            change_data[ts] = [cur_val - self._init_shore_vals[j] for j, cur_val in enumerate(data[ts])]
            num_days = (dates[i] - start).days
            if num_days == 0:
                rate_data[ts] = [0] * len(data[ts])
            else:
                rate_data[ts] = [(cur_val - self._init_shore_vals[j]) / num_days for j, cur_val in enumerate(data[ts])]

        self._solutions['Shoreline Change'] = change_data
        self._scale_factors['Shoreline Change'] = 'y'

        self._solutions['Shoreline Rate of Change'] = rate_data
        self._scale_factors['Shoreline Rate of Change'] = 'y'

    def _read_file(self, filename, last_is_average=True):
        """Read the solution file.

        Args:
            filename (str): Filename to read in
            last_is_average (bool): True if the last set of data is the average
        """
        data = dict()
        ave_data = dict()
        dates = []
        file_with_path = self._proj_path + '\\' + filename
        with open(file_with_path, 'r') as f:
            lines = f.readlines()
        if not lines:
            return data, ave_data

        # get rid of the first line
        lines.pop(0)
        num_lines = len(lines)
        for i, line in enumerate(lines):
            line_vals = line.split()
            ts = line_vals.pop(0)
            ts_vals = [float(val) for val in line_vals]
            if i < num_lines - 1 and num_lines > 1 or last_is_average is False:
                dt = datetime.date(int(ts[:4]), int(ts[4:6]), int(ts[6:8]))
                dates.append(dt)
                data[dt.isoformat()] = ts_vals
            else:
                # this is the last one and it is the average - assign it a ts of 0
                ave_data[0] = ts_vals

        return data, ave_data, dates

    def _read_initial_shoreline(self, filename):
        """Read the solution file.

        Args:
            filename (str): Filename to read in
        """
        file_with_path = self._proj_path + '\\' + filename
        with open(file_with_path, 'r') as f:
            lines = f.readlines()
        if not lines:
            return

        # get rid of the first four lines
        lines = lines[4:]
        for line in lines:
            line_vals = line.split()
            self._init_shore_vals.extend([float(val) for val in line_vals])

    """
    Widget setup
    """
    def _connect_slots(self):
        """Connect Qt widget signal/slots."""
        ui = self.ui

        # Selecting another dataset rebuilds the time step list
        ui.list_datasets.itemSelectionChanged.connect(self._on_change_sel_dataset)

        # Everything else that affects the display simply refreshes the results
        refresh_signals = (
            ui.list_time_steps.itemSelectionChanged,
            ui.edt_y_scale.editingFinished,
            ui.edt_q_scale.editingFinished,
            ui.cbx_map.currentIndexChanged,
        )
        for signal in refresh_signals:
            signal.connect(self._update_results)

        # Dialog buttons
        ui.buttonBox.accepted.connect(self.accept)
        ui.buttonBox.helpRequested.connect(self.help_requested)

    """
    Slots
    """
    def _on_change_sel_dataset(self):
        """Show/hide rows in the levee results table based on check status."""
        self.ui.list_time_steps.blockSignals(True)
        self._sel_dataset = self.ui.list_datasets.currentItem().text()

        # update the timestep list
        ts_list = list(self._solutions[self._sel_dataset].keys())
        self.ui.list_time_steps.clear()
        if ts_list[0] != '0':
            self.ui.list_time_steps.addItems(ts_list)
            self.ui.list_time_steps.setCurrentRow(0)
        self.ui.list_time_steps.setEnabled(ts_list[0] != '0')

        # enable/disable scale factor
        is_q = self._scale_factors[self._sel_dataset] == 'q'
        self.ui.edt_q_scale.setEnabled(is_q)
        self.ui.edt_y_scale.setEnabled(not is_q)

        self._update_results()

        self.ui.list_time_steps.blockSignals(False)

    def _update_results(self):
        """Show/hide rows in the levee results table based on check status."""
        # get the selected dataset and time step
        if len(self._solutions[self._sel_dataset]) > 1:
            self._sel_ts = self.ui.list_time_steps.currentItem().text()
            self._cur_sol_vals = self._solutions[self._sel_dataset][self._sel_ts]
        else:
            self._cur_sol_vals = self._solutions[self._sel_dataset][0]

        self._update_map()
        self._update_plot()
        self._update_table()

    def _update_plot(self):
        """Show/hide rows in the levee results table based on check status."""
        self.ax.clear()
        if len(self._cur_sol_vals) == len(self._dist_along_grid):
            self.ax.plot(self._dist_along_grid, self._cur_sol_vals, label=self._sel_dataset)
        elif len(self._cur_sol_vals) == len(self._dist_along_grid_mid_pts):
            self.ax.plot(self._dist_along_grid_mid_pts, self._cur_sol_vals, label=self._sel_dataset)
        else:
            pass
        self.ax.set_xlabel('Distance along grid')
        self.canvas.draw()

    def _update_table(self):
        """Fill the table with the distance along the grid and the current solution values.

        The row count always follows the number of values. The cells are only filled when the number of
        values matches ``_dist_along_grid`` or, failing that, ``_dist_along_grid_mid_pts``.
        """
        values = self._cur_sol_vals
        table = self.ui.tbl_dataset
        table.setRowCount(len(values))

        if len(values) == len(self._dist_along_grid):
            distances = self._dist_along_grid
        elif len(values) == len(self._dist_along_grid_mid_pts):
            distances = self._dist_along_grid_mid_pts
        else:
            return

        for row, (distance, value) in enumerate(zip(distances, values)):
            table.setItem(row, 0, QTableWidgetItem(f'{distance:.3f}'))
            table.setItem(row, 1, QTableWidgetItem(f'{value:.3f}'))
    def _update_map(self):
        """Rebuild the map with the background tiles, the grid and the current dataset and display it.

        The map is saved to ``self._url_file`` and loaded into the web view.
        """
        # add the map
        folium_map = folium.Map(tiles='')
        esri_services = {'ESRI Street Map': 'World_Street_Map', 'ESRI World Imagery': 'World_Imagery'}
        map_name = self.ui.cbx_map.currentText()
        service = esri_services.get(map_name)
        if service is not None:  # any other choice ('None') leaves the map without background tiles
            url = 'http://server.arcgisonline.com/ArcGIS/rest/services/' + service + '/MapServer/tile/{z}/{y}/{x}'
            folium.raster_layers.TileLayer(tiles=url, name=map_name, attr='ESRI').add_to(folium_map)

        # add the grid and tick lines
        for grid_line in self._grid_lines:
            folium.PolyLine(locations=grid_line, color='black', weight=1, tooltip='1-D Grid').add_to(folium_map)

        # draw a triangle at the end
        # self._grid_angle comes from math.atan2() and is in radians, while the marker rotation is given
        # in degrees, so convert before applying the 90 degree offset.
        # NOTE(review): confirm the marker's rotation direction/zero convention against a rendered map.
        rotation = math.degrees(self._grid_angle) - 90.0
        folium.RegularPolygonMarker(location=self._grid_lines[0][-1], number_of_sides=3, rotation=rotation,
                                    color='black', line_color='black', fill_color='black').add_to(folium_map)

        draw_pts, dataset_extents = self._get_dataset_drawing_pts(self._cur_sol_vals)
        folium.PolyLine(locations=draw_pts, color='brown', weight=2).add_to(folium_map)

        # Zoom to the combined extents of the grid and the dataset
        (grid_min, grid_max), (dset_min, dset_max) = self._grid_line_extents, dataset_extents
        extents = [[min(grid_min[0], dset_min[0]), min(grid_min[1], dset_min[1])],
                   [max(grid_max[0], dset_max[0]), max(grid_max[1], dset_max[1])]]

        folium_map.fit_bounds(extents)
        folium_map.save(self._url_file)
        self._viewer.setUrl(QUrl.fromLocalFile(self._url_file))
        self._viewer.show()

    def _get_grid_drawing_pts(self):
        """Compute the grid geometry used for drawing.

        Sets the grid angle, the distances along the grid (at the cell edges and at the cell mid points),
        the grid locations in the original coordinates and the grid extents, and appends the reprojected
        grid line followed by its tick lines to ``self._grid_lines``.
        """
        info_attrs = self._mapped_grid.data.info.attrs
        pts = [(info_attrs['x0'], info_attrs['y0']), (info_attrs['xend'], info_attrs['yend'])]
        (x_start, y_start), (x_end, y_end) = pts
        self._grid_angle = math.atan2(x_end - x_start, -(y_end - y_start))

        grid_ls = LineString(pts)
        grid_length_source = grid_ls  # the length is read from the line each time it is needed

        # Accumulate the cell sizes into distances at the cell edges and at the cell mid points
        edge_dists = [0]
        mid_dists = []
        for dx in self._mapped_grid.data.locations.dx.values:
            previous = edge_dists[-1]
            mid_dists.append(previous + (dx / 2.0))
            edge_dists.append(previous + dx)
        self._dist_along_grid = edge_dists
        self._dist_along_grid_mid_pts = mid_dists
        self._all_orig_grid_locs = [shapely_pt_to_loc(grid_ls.interpolate(dist)) for dist in edge_dists]
        self._grid_lines.append(self._reproject_points(self._all_orig_grid_locs))

        def _offset_pair(fraction):
            """Return the grid line offset to both sides by a fraction of the grid length."""
            distance = fraction * grid_length_source.length
            return grid_ls.parallel_offset(distance), grid_ls.parallel_offset(-distance)

        def _tick(right_line, left_line, location):
            """Return the reprojected line joining the points of both offset lines nearest to location."""
            anchor = shPt(location)
            ends = [nearest_points(side, anchor)[0].coords[0] for side in (right_line, left_line)]
            return self._reproject_points(ends)

        # tick marks: the longest ones at both ends of the grid
        end_right, end_left = _offset_pair(0.010)
        for end_pt in pts:
            self._grid_lines.append(_tick(end_right, end_left, end_pt))

        # interior ticks: every 10th is the longest, every 5th medium, all others the shortest
        tick_sides = [_offset_pair(fraction) for fraction in (0.008, 0.005, 0.002)]
        for idx, loc in enumerate(self._all_orig_grid_locs[1:-1], 1):
            if idx % 10 == 0:
                level = 0
            elif idx % 5 == 0:
                level = 1
            else:
                level = 2
            self._grid_lines.append(_tick(*tick_sides[level], loc))

        # Extents of the reprojected grid line as [[min, min], [max, max]]
        low_0, low_1, high_0, high_1 = LineString(self._grid_lines[0]).bounds
        self._grid_line_extents = [
            [low_0, low_1],
            [high_0, high_1],
        ]

    def _get_dataset_drawing_pts(self, vals):
        """Get drawing points for the dataset.

        Every value is drawn as an offset from the grid, in the direction of ``self._grid_angle`` and
        scaled by the scale factor of the selected dataset. Datasets with one value per grid cell (one
        value fewer than grid locations) are offset from the cell mid points, matching the plot and the
        table; all others are offset from the grid locations.

        Arguments:
            vals(list): Float values

        Returns:
            pts(list): Location of points to draw.
            extents(list): Extents of dataset
        """
        is_q = self._scale_factors[self._sel_dataset] == 'q'
        scale_edit = self.ui.edt_q_scale if is_q else self.ui.edt_y_scale
        try:
            factor = float(scale_edit.text())
        except ValueError:
            # The edit can hold empty/incomplete text when a refresh is triggered by another widget;
            # fall back to the initial scale factor instead of failing
            factor = 1.0

        grid_locs = self._all_orig_grid_locs
        if len(vals) == len(grid_locs) - 1:
            # Cell centered values: use the mid point of each pair of neighbouring grid locations
            base_locs = [((a[0] + b[0]) / 2.0, (a[1] + b[1]) / 2.0) for a, b in zip(grid_locs, grid_locs[1:])]
        else:
            base_locs = grid_locs

        x_offset = factor * math.cos(self._grid_angle)
        y_offset = factor * math.sin(self._grid_angle)
        # zip() stops at the shorter input, so surplus values without a grid location are not drawn
        dataset_locs = [[loc[0] + (val * x_offset), loc[1] + (val * y_offset)] for loc, val in zip(base_locs, vals)]

        latlon_vals = self._reproject_points(dataset_locs)
        bounds = LineString(latlon_vals).bounds
        latlon_extents = [
            [bounds[0], bounds[1]],  # minimum of the first and second reprojected coordinate
            [bounds[2], bounds[3]],  # maximum of the first and second reprojected coordinate
        ]

        return latlon_vals, latlon_extents
    """
    Qt overloads
    """
    def sizeHint(self):  # noqa: N802
        """Report the preferred size of the dialog (Qt override).

        Returns:
            (:obj:`QSize`): Size to use for the initial dialog size.
        """
        width, height = 800, 500
        return QSize(width, height)

    def accept(self):
        """Close the embedded web view, then accept the dialog.

        No data is saved here; the dialog only displays solution results.
        """
        # Close the web view before the base class accepts (and thereby closes) the dialog
        self._viewer.close()
        super().accept()


def meters_to_decimal_degrees(length_meters, latitude):
    """Express a length given in meters in decimal degrees at the given latitude.

    Args:
        length_meters (:obj:`float`): length in meters
        latitude (:obj:`float`): latitude in decimal degrees

    Returns:
        (:obj:`float`): length in decimal degrees
    """
    meters_per_degree = _factor_from_latitude(latitude)
    return length_meters / meters_per_degree


def feet_to_decimal_degrees(length_feet, latitude):
    """Express a length given in feet in decimal degrees at the given latitude.

    The length is converted to meters first and then handed to meters_to_decimal_degrees().

    Args:
        length_feet (:obj:`float`): length in feet
        latitude (:obj:`float`): latitude in decimal degrees

    Returns:
        (:obj:`float`): length in decimal degrees
    """
    return meters_to_decimal_degrees(length_feet / FEET_PER_METER, latitude)


# Number of feet in one meter; feet_to_decimal_degrees() divides by it to convert feet to meters
FEET_PER_METER = 3.28083989501


def _factor_from_latitude(latitude):
    """Compute the meters per decimal degree conversion factor at a latitude.

    The factor is 111.32 km (in meters) scaled by the cosine of the latitude.

    Args:
        latitude (:obj:`float`): the latitude in decimal degrees

    Returns:
        (:obj:`float`): conversion factor
    """
    radians_per_degree = math.pi / 180
    meters_per_degree_at_equator = 111.32 * 1000
    return meters_per_degree_at_equator * math.cos(latitude * radians_per_degree)


def shapely_pt_to_loc(shp):
    """Convert a shapely point to an x, y, z location list.

    Args:
        shp (shapely_pt): shapely point

    Returns:
        (list): x, y, z location; z is always 0.0
    """
    x_coord, y_coord = shp.x, shp.y
    return [x_coord, y_coord, 0.0]
