"""Class for managing data for profile and cross-section plots."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math
import os

# 2. Third party modules
import numpy as np
from shapely.geometry import LineString, Point as ShPoint

# 3. Aquaveo modules
from xms.constraint.ugrid_activity import active_points_from_cells
from xms.extractor.ugrid_2d_polyline_data_extractor import UGrid2dPolylineDataExtractor
from xms.grid.ugrid.ugrid_utils import read_ugrid_from_ascii_file

# 4. Local modules


class ExtractedSimData:
    """Class for managing data for profile and cross-section plots.

    Values are extracted along a polyline from each simulation's grid and stored in ``plot_data`` keyed by simulation
    name. Each stored curve is a ``(distances, values)`` tuple where the distances are measured along the polyline.
    The outline of the structure (``structure_polygon``) and its piers (``structure_piers``) are stored as lists of
    ``(distance, elevation)`` tuples.
    """
    def __init__(self, plot_sim_data, structure_comp, is_cross_section=False):
        """Initialize the class.

        Args:
            plot_sim_data (:obj:`PlotSimData`): plot simulation data
            structure_comp (:obj:`StructureComponent`): structure component (may be None)
            is_cross_section (bool): True if this is a cross-section plot
        """
        self._plot_sim_data = plot_sim_data
        self._comp = structure_comp
        self._is_cross_section = is_cross_section
        self._comp_main_file = '' if not self._comp else self._comp.main_file
        self._is_culvert = False
        if self._comp and self._comp.data.data_dict['structure_type'] == 'Culvert':
            self._is_culvert = True
        self._polyline = []
        # default to the last time step of every simulation
        self.selected_time = {sim: plot_sim_data.sim_data[sim]['times'][-1] for sim in plot_sim_data.sim_names}
        self.plot_data = {}
        self.structure_polygon = []
        self.structure_piers = []
        # state for the simulation currently being extracted (set by extract_data_from_sim)
        self._co_grid = None
        self._ug = None
        self._sim = ''
        self._ls = None
        # line strings (distance, elevation) for the top and bottom of the structure along the polyline
        self._structure_top_ls = None
        self._structure_bot_ls = None
        self._structure_max_elev = None

    @staticmethod
    def _finite_min(values):
        """Get the smallest value that is not NaN.

        The builtin ``min`` gives order dependent results when NaN is present so NaN values are removed first.

        Args:
            values (list): floating point values that may contain NaN

        Returns:
            (float): the smallest non-NaN value or None if there is no such value
        """
        finite = [v for v in values if not math.isnan(v)]
        return min(finite) if finite else None

    def set_extraction_polyline(self, polyline):
        """Set the polyline, clear all previously extracted data, and recompute the structure polygon.

        Args:
            polyline (list): polyline
        """
        self._polyline = polyline
        self.plot_data = {}
        self.structure_polygon = []
        self.structure_piers = []
        self._get_structure_polygon()

    def _calc_structure_piers(self):
        """Calculate the structure piers from the gaps (NaN values) in the extracted elevations.

        Does nothing for culverts, when piers have already been calculated, or when the structure top line is not
        available. Each pier is a list of 4 (distance, elevation) points: the base is at the minimum extracted
        elevation and the top is halfway between the top and the bottom of the structure.
        """
        if self._is_culvert:
            return
        if len(self.structure_piers) > 0:
            return
        if self._structure_top_ls is None:
            return
        # piers only exist if there are missing elevations
        dists, elevs = self.plot_data[self._sim]['elevations'][:2]
        if not np.any(np.isnan(elevs)):
            return
        # a pier spans from the sample before to the sample after a missing elevation; the first and the last
        # samples are skipped because they do not have a neighbor on both sides
        pier_locs = [
            (dists[idx - 1], dists[idx + 1]) for idx in range(1, len(dists) - 1) if math.isnan(elevs[idx])
        ]
        if not pier_locs:
            return

        # NaN must be excluded here: min() would return NaN when the first elevation is NaN
        min_elev = self._finite_min(elevs)
        if min_elev is None:
            return  # every elevation is missing so there is no base elevation for the piers
        top_ls = self._structure_top_ls
        bot_ls = self._structure_bot_ls
        for start, end in pier_locs:
            start_y = 0.5 * (top_ls.interpolate(start).y + bot_ls.interpolate(start).y)
            end_y = 0.5 * (top_ls.interpolate(end).y + bot_ls.interpolate(end).y)
            self.structure_piers.append([(start, min_elev), (end, min_elev), (end, end_y), (start, start_y)])

    def _extract_elevations(self):
        """Extract the grid point elevations along the polyline for the current simulation.

        The point extractor for the grid is created on first use and cached in the plot simulation data.
        """
        sim = self._sim
        grid = self._co_grid
        ls = self._ls
        ug = self._ug
        elevs = [p[2] for p in ug.locations]
        pt_extractors = self._plot_sim_data.grid_pt_data_extractor
        if grid.uuid not in pt_extractors:
            pt_extractors[grid.uuid] = UGrid2dPolylineDataExtractor(ugrid=ug, scalar_location='points')
        extractor = pt_extractors[grid.uuid]
        activity = [1] * len(elevs)  # all points are treated as active for the elevations
        extractor.set_grid_scalars(elevs, activity, 'points')
        extractor.set_polyline(self._polyline)
        locs = extractor.extract_locations
        scalars = list(extractor.extract_data())
        dists = [ls.project(ShPoint(p)) for p in locs]
        # NOTE(review): this replaces any '<dataset>_times' entries stored by _extract_dataset for this simulation,
        # so those entries do not survive a second call to extract_data_from_sim - confirm this is intended
        self.plot_data[sim] = {
            'elevations': (dists, scalars),
        }
        self._calc_structure_piers()
        self._update_culvert_structure_polygon()

    def _extract_dataset(self, dataset):
        """Extract dataset values along the polyline for the current simulation at the selected time.

        Results are stored in ``plot_data`` under ``dataset`` and, per time, under ``'<dataset>_times'``.

        Args:
            dataset (str): key of the dataset in the simulation data (e.g. 'wse')
        """
        sim = self._sim
        grid = self._co_grid
        ug = self._ug
        ls = self._ls
        ds = self._plot_sim_data.sim_data[sim].get(dataset, None)
        if ds is None:
            return
        # the point extractor already has the polyline; it is set in _extract_elevations
        extractor = self._plot_sim_data.grid_pt_data_extractor[grid.uuid]
        if ds.location == 'cells':
            cell_extractors = self._plot_sim_data.grid_cell_data_extractor
            if grid.uuid not in cell_extractors:
                cell_extractors[grid.uuid] = UGrid2dPolylineDataExtractor(ugrid=ug, scalar_location='cells')
            extractor = cell_extractors[grid.uuid]
            extractor.set_polyline(self._polyline)
        ds_data = self.plot_data[sim].get(f'{dataset}_times', {})
        time_str = self.selected_time[sim]
        if time_str not in ds_data:
            idx = self._plot_sim_data.sim_data[sim]['times'].index(time_str)
            vals = ds.values[idx]
            if ds.activity is None:
                # no activity array: everything that is not the null value is active
                activity = vals != ds.null_value
            else:
                activity = ds.activity[idx]
                if len(activity) != len(vals):
                    # activity size does not match the values so convert cell activity to point activity
                    activity = active_points_from_cells(ug, activity)
            extractor.set_grid_scalars(vals, activity, ds.location)
            locs = extractor.extract_locations
            scalars = list(extractor.extract_data())
            dists = [ls.project(ShPoint(p)) for p in locs]
            ds_data[time_str] = (dists, scalars)
            self.plot_data[sim][f'{dataset}_times'] = ds_data
        self.plot_data[sim][dataset] = ds_data[time_str]

    def extract_data_from_sim(self, sim):
        """Extract elevations, wse, and velocity magnitude data from the simulation.

        Nothing is extracted if the simulation is unknown, it has no grid, or no polyline has been set.

        Args:
            sim (str): simulation name
        """
        self._sim = ''
        self._co_grid = self._ug = self._ls = None
        if sim not in self._plot_sim_data.sim_data:
            return
        grid = self._plot_sim_data.sim_data[sim]['grid']
        if grid is None or not self._polyline:
            return
        self._co_grid = grid
        self._ug = grid.ugrid
        self._sim = sim
        self._ls = LineString(self._polyline)
        self._extract_elevations()
        self._extract_dataset('wse')
        self._extract_dataset('velocity_mag')

    def _get_structure_polygon(self):
        """Get the polygon for the structure where it is intersected by the profile line.

        The polygon follows the top of the structure and then returns along the bottom. The polygon is left empty
        if there is no polyline or if the top or the bottom has fewer than 2 valid samples along the polyline.
        """
        self.structure_polygon = []
        self._structure_top_ls = self._structure_bot_ls = None
        self._structure_max_elev = None
        if not self._polyline:
            return
        comp_dir = os.path.dirname(self._comp_main_file)
        top_file = os.path.join(comp_dir, 'top2d.xmugrid')
        top_dists, top_scalars = self._extract_structure_elevations(top_file)
        bot_file = os.path.join(comp_dir, 'bottom2d.xmugrid')
        bot_dists, bot_scalars = self._extract_structure_elevations(bot_file)
        # a line string needs at least 2 points; a single sample can not be used to build the structure outline
        if len(top_dists) < 2 or len(bot_dists) < 2:
            return
        # make a line string for the top and bottom of the structure
        top_xy = list(zip(top_dists, top_scalars))
        bot_xy = list(zip(bot_dists, bot_scalars))
        self._structure_max_elev = max(top_scalars)
        self._structure_top_ls = LineString(top_xy)
        self._structure_bot_ls = LineString(bot_xy)

        # walk along the top and then back along the bottom to outline the structure
        self.structure_polygon = top_xy + bot_xy[::-1]

    def _extract_structure_elevations(self, ug_file):
        """Extract the elevations from the UGrid along the polyline.

        Args:
            ug_file (str): path to the UGrid file

        Returns:
            (tuple): list of distances along the polyline and list of elevations (NaN values are excluded). Both
            lists are empty if the UGrid could not be read.
        """
        rval = [], []
        try:
            ug = read_ugrid_from_ascii_file(ug_file)
        except Exception:
            # best effort: when the file can not be read there is simply no structure geometry to plot
            return rval
        if ug is None:
            return rval

        ls = LineString(self._polyline)
        elevs = [p[2] for p in ug.locations]
        extractor = UGrid2dPolylineDataExtractor(ugrid=ug, scalar_location='points')
        extractor.set_polyline(self._polyline)
        activity = [1] * len(elevs)  # all points are treated as active
        extractor.set_grid_scalars(elevs, activity, 'points')
        dists = [ls.project(ShPoint(p)) for p in extractor.extract_locations]
        scalars = list(extractor.extract_data())
        # keep only the samples that have an elevation
        for dist, scalar in zip(dists, scalars):
            if not math.isnan(scalar):
                rval[0].append(dist)
                rval[1].append(scalar)
        return rval

    def _update_culvert_structure_polygon(self):
        """Make sure the min elevation of the culvert structure polygon is the same as the min ground elevation.

        Only applies to cross-section plots of culverts. Polygon points at the minimum structure elevation are moved
        down to the minimum extracted elevation when that elevation is lower.
        """
        if not self._is_culvert or not self._is_cross_section or not self.structure_polygon:
            return
        elevs = self.plot_data[self._sim]['elevations'][1]
        # NaN must be excluded: min() can return NaN and then the comparison below is always False
        min_elev = self._finite_min(elevs)
        if min_elev is None:
            return
        min_struct_elev = min(p[1] for p in self.structure_polygon)
        if min_elev < min_struct_elev:
            self.structure_polygon = [
                (p[0], min_elev) if p[1] == min_struct_elev else p for p in self.structure_polygon
            ]