"""Data for populating a time series based on inputs."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math

# 2. Third party modules

# 3. Aquaveo modules
from xms.api._xmsapi.dmi import DatasetItem
from xms.api.tree import tree_util
from xms.api.tree.tree_node import TreeNode
from xms.constraint import read_grid_from_file
from xms.coverage.spatial import SpatialCoverage
from xms.data_objects.parameters import FilterLocation
from xms.extractor import UGrid2dDataExtractor
from xms.guipy.dialogs.dataset_selector import DatasetSelector
from xms.guipy.time_format import XmsTimeFormatter
from xms.snap.snap_exterior_arc import SnapExteriorArc

# 4. Local modules


class PopulateFlowData:
    """Data for populating a time series from either a flow computation or a hydrograph.

    The 'Flow' method integrates depth times normal velocity along the arc's snapped cross section for each time
    in the merged depth/velocity time list. The 'Hydrograph' method returns an 'XY Series' curve attached to a
    point of a spatial data coverage, with the curves ordered by their distance from the arc's center.
    """
    def __init__(self):
        """Initializes the data."""
        self._geometry_uuid = ''
        self._depth_uuid = ''
        self._velocity_uuid = ''
        self.pe_tree = None  # The project explorer tree, assigned by the caller
        self._scalar_pe_tree = None  # Copy of the geometry's tree trimmed to scalar datasets
        self._vector_pe_tree = None  # Copy of the geometry's tree trimmed to vector datasets
        self._query = None
        self._dset_dumps = {}  # Cache of datasets that have been clicked on, keyed by UUID
        self._co_grids = {}  # Cache of constrained grids keyed by geometry UUID (None when unreadable)
        self._spatial_data_covs = {}  # Cache of spatial data coverages keyed by UUID
        self._times = []  # Sorted union of the usable depth and velocity times
        self._depth_times = []
        self._velocity_times = []
        self._depth_start_idx = 0  # First depth timestep inside the overlapping time range
        self._vel_start_idx = 0  # First velocity timestep inside the overlapping time range
        self._depth_end_idx = -1  # Last depth timestep inside the overlapping range (negative index, inclusive)
        self._vel_end_idx = -1  # Last velocity timestep inside the overlapping range (negative index, inclusive)
        self._time_settings = None
        self._arc = None
        self._arc_center = None
        self._arc_id = 0
        self._x_plot = []
        self._y_plot = []
        self._arc_locations = []
        self._non_overlapping_times = False
        self._populate_method = 'Flow'
        self._hydrographs = []  # Tuples of (curve name, distance to arc center, series)
        self.current_hydrograph = 0
        self.population_methods = ['Flow', 'Hydrograph']
        self._spatial_data_uuid = None

    def set_query(self, query):
        """Set the query and context objects.

        Args:
            query (xmsdmi.dmi.Query): Query for communicating with XMS.
        """
        self._query = query

    def set_arc_id(self, arc_id):
        """Set the arc id for the cross-section.

        Changing the id invalidates the cached arc and arc center so they are looked up again.

        Args:
            arc_id (int): The id of the arc from SMS.
        """
        if arc_id != self._arc_id:
            self._arc = None
            self._arc_center = None
        self._arc_id = arc_id

    @property
    def geometry_uuid(self):
        """Property for the geometry UUID."""
        return self._geometry_uuid

    @geometry_uuid.setter
    def geometry_uuid(self, geometry_uuid):
        """Setter for the geometry UUID.

        Rebuilds the scalar/vector dataset trees and the cross section, and clears the dataset selections.

        Args:
            geometry_uuid (str): The UUID of the geometry.
        """
        if self._geometry_uuid == geometry_uuid:
            return
        self._geometry_uuid = geometry_uuid
        geometry_node = tree_util.find_tree_node_by_uuid(self.pe_tree, self._geometry_uuid)
        scalar_tree = TreeNode(other=geometry_node)  # Create a copy of the project explorer tree to filter.
        vector_tree = TreeNode(other=geometry_node)  # Create a copy of the project explorer tree to filter.
        tree_util.filter_project_explorer(scalar_tree, DatasetSelector.is_scalar_if_dset)
        tree_util.filter_project_explorer(vector_tree, DatasetSelector.is_vector_if_dset)
        self._scalar_pe_tree = tree_util.trim_tree_to_items_of_type(scalar_tree, [DatasetItem])
        # Trim the filtered copy, not the live project explorer node.
        self._vector_pe_tree = tree_util.trim_tree_to_items_of_type(vector_tree, [DatasetItem])
        self._build_cross_section()
        self._depth_uuid = ''
        self._velocity_uuid = ''
        self._times = []
        self._non_overlapping_times = False

    @property
    def spatial_data_uuid(self):
        """Property for the spatial data coverage UUID."""
        return self._spatial_data_uuid

    @spatial_data_uuid.setter
    def spatial_data_uuid(self, spatial_data_uuid):
        """Setter for the spatial data coverage UUID.

        Collects the 'XY Series' time series attached to the coverage's points and sorts them by distance from
        the arc center. The hydrograph list is left empty when the coverage or the arc center is unavailable.

        Args:
            spatial_data_uuid (str): The UUID of the spatial data coverage.
        """
        if self._spatial_data_uuid == spatial_data_uuid:
            return
        self._spatial_data_uuid = spatial_data_uuid
        self._hydrographs = []  # Never keep curves from a previously selected coverage.
        if self._spatial_data_uuid is None:
            return
        # get the points and store the names
        self._query_for_spatial_data(spatial_data_uuid)
        spatial_cov = self._spatial_data_covs.get(spatial_data_uuid)
        arc_center = self._get_arc_center()
        if spatial_cov is None or arc_center is None:
            return
        # get all of the hydrographs and sort them by their distance to the arc center point
        spatial_points = spatial_cov.spatial_data
        points = spatial_cov.coverage.get_points(FilterLocation.LOC_ALL)
        for point in points:
            point_id = point.id
            if point_id not in spatial_points:
                continue
            for plugin_data in spatial_points[point_id].values():
                if isinstance(plugin_data, SpatialCoverage.TimeSeries) and plugin_data.group_name == 'XY Series':
                    dist = self.get_distance(arc_center[0:2], (point.x, point.y))
                    self._hydrographs.append((plugin_data.curve_name, dist, plugin_data.series))
                    break  # Only the first 'XY Series' of each point is used.
        self._hydrographs.sort(key=lambda hydrograph: hydrograph[1])

    def _build_cross_section(self):
        """Builds the cross section to plot by snapping the arc to the current geometry.

        The plot data and arc locations are cleared first so that nothing from a previous geometry is kept when
        the arc or the geometry's grid cannot be found.
        """
        self._x_plot = []
        self._y_plot = []
        self._arc_locations = []
        self._query_for_geometry(self._geometry_uuid)
        co_grid = self._co_grids.get(self._geometry_uuid)
        arc = self._get_arc()
        if arc is None or co_grid is None:
            return
        snap_ex_arc = SnapExteriorArc()
        snap_ex_arc.set_grid(grid=co_grid, target_cells=False)
        result = snap_ex_arc.get_snapped_points(arc)
        self._arc_locations = result['location']
        self._x_plot, self._y_plot = self._cross_section_profile(self._arc_locations)

    @staticmethod
    def _cross_section_profile(locations):
        """Computes the normalized station and elevation of each cross section location.

        Args:
            locations (sequence): Indexable sequence of (x, y, z) locations along the arc.

        Returns:
            (tuple): Normalized stations in [0, 1] (all 0.0 when the section has no length) and the z-values.
        """
        num_locations = len(locations)
        segment_lengths = [0.0] + [
            PopulateFlowData._distance(locations[idx - 1], locations[idx]) for idx in range(1, num_locations)
        ]
        total_distance = sum(segment_lengths)
        x_plot = []
        y_plot = []
        cumulative = 0.0
        for location, segment_length in zip(locations, segment_lengths):
            cumulative += segment_length
            # Guard against a degenerate (zero-length) section to avoid dividing by zero.
            x_plot.append(cumulative / total_distance if total_distance > 0.0 else 0.0)
            y_plot.append(location[2])  # Get the Z-coordinate
        return x_plot, y_plot

    @staticmethod
    def _distance(pt1, pt2):
        """Calculates the 2D distance between points.

        Args:
            pt1 (tuple): The first point in x, y.
            pt2 (tuple): The second point in x, y.

        Returns:
            (float): The distance in the XY plane (any extra coordinates are ignored).
        """
        return math.hypot(pt2[0] - pt1[0], pt2[1] - pt1[1])

    @property
    def depth_uuid(self):
        """Property for the depth dataset UUID."""
        return self._depth_uuid

    @depth_uuid.setter
    def depth_uuid(self, depth_uuid):
        """Setter for the depth dataset UUID.

        Args:
            depth_uuid (str): The UUID of the depth dataset.
        """
        self._depth_uuid = depth_uuid
        self._get_timesteps()

    @property
    def velocity_uuid(self):
        """Property for the velocity dataset UUID."""
        return self._velocity_uuid

    @velocity_uuid.setter
    def velocity_uuid(self, velocity_uuid):
        """Setter for the velocity dataset UUID.

        Args:
            velocity_uuid (str): The UUID of the velocity dataset.
        """
        self._velocity_uuid = velocity_uuid
        self._get_timesteps()

    @property
    def scalar_pe_tree(self):
        """Property for the project explorer tree trimmed to scalar datasets of the geometry."""
        return self._scalar_pe_tree

    @property
    def vector_pe_tree(self):
        """Property for the project explorer tree trimmed to vector datasets of the geometry."""
        return self._vector_pe_tree

    @property
    def depth_name(self):
        """Property for the name and path of the depth dataset relative to the geometry."""
        if self._depth_uuid:
            path = tree_util.build_tree_path(self._scalar_pe_tree, self._depth_uuid)
            if path:
                return path
        return '(none selected)'

    @property
    def velocity_name(self):
        """Property for the name and path of the velocity dataset relative to the geometry."""
        if self._velocity_uuid:
            path = tree_util.build_tree_path(self._vector_pe_tree, self._velocity_uuid)
            if path:
                return path
        return '(none selected)'

    @property
    def geometry_name(self):
        """Property for the name and path of the geometry."""
        if self._geometry_uuid:
            path = tree_util.build_tree_path(self.pe_tree, self._geometry_uuid)
            if path:
                return path
        return '(none selected)'

    @property
    def spatial_data_name(self):
        """Property for the name and path of the spatial data coverage."""
        if self._spatial_data_uuid:
            path = tree_util.build_tree_path(self.pe_tree, self._spatial_data_uuid)
            if path:
                return path
        return '(none selected)'

    @property
    def populate_method(self):
        """Property for the population method."""
        return self._populate_method

    @populate_method.setter
    def populate_method(self, populate_method):
        """Setter for the population method. Values not in `population_methods` are ignored.

        Args:
            populate_method (str): The population method.
        """
        if populate_method not in self.population_methods:
            return
        self._populate_method = populate_method

    def _get_timesteps(self):
        """Get the timesteps of the depth and velocity datasets and merge their overlapping range.

        Sets `_times` to the sorted union of the depth and velocity times that fall inside the time range covered
        by both datasets (both endpoints included). Sets `_non_overlapping_times` when either dataset has no times
        or the ranges do not overlap.
        """
        self._non_overlapping_times = False
        self._times = []
        if not self._depth_uuid or not self._velocity_uuid:
            return
        self._query_for_time_settings()  # Get the XMS global time settings if we haven't already
        self._depth_times = self._times_from_dataset(self._get_dataset(self._depth_uuid))
        self._velocity_times = self._times_from_dataset(self._get_dataset(self._velocity_uuid))
        depth_times = self._depth_times
        vel_times = self._velocity_times
        # compare the datasets
        if (
            not depth_times or not vel_times or depth_times[-1] < vel_times[0] or vel_times[-1] < depth_times[0]
        ):
            self._non_overlapping_times = True
            return

        # Trim whichever dataset starts earlier / ends later to the common range.
        self._depth_start_idx = 0
        self._vel_start_idx = 0
        self._depth_end_idx = -1
        self._vel_end_idx = -1
        if depth_times[0] < vel_times[0]:
            self._depth_start_idx = self._first_index_at_or_after(depth_times, vel_times[0])
        elif vel_times[0] < depth_times[0]:
            self._vel_start_idx = self._first_index_at_or_after(vel_times, depth_times[0])
        if depth_times[-1] > vel_times[-1]:
            self._depth_end_idx = self._last_index_at_or_before(depth_times, vel_times[-1])
        elif vel_times[-1] > depth_times[-1]:
            self._vel_end_idx = self._last_index_at_or_before(vel_times, depth_times[-1])

        # End indices are negative and inclusive; convert them to exclusive slice stops.
        depth_stop = len(depth_times) + self._depth_end_idx + 1
        vel_stop = len(vel_times) + self._vel_end_idx + 1
        # merge the timelists so there is one list to interpolate from
        merged = set(depth_times[self._depth_start_idx:depth_stop])
        merged.update(vel_times[self._vel_start_idx:vel_stop])
        self._times = sorted(merged)

    @staticmethod
    def _first_index_at_or_after(times, target):
        """Find the first time that is the same or after the target time.

        Args:
            times (list): Sorted times.
            target: The time to compare against.

        Returns:
            (int): The index of the first time >= target, or len(times) if there is none.
        """
        for idx, time in enumerate(times):
            if time >= target:
                return idx
        return len(times)

    @staticmethod
    def _last_index_at_or_before(times, target):
        """Find the last time that is the same or before the target time.

        Args:
            times (list): Sorted times.
            target: The time to compare against.

        Returns:
            (int): The negative index of the last time <= target, or -len(times) - 1 if there is none.
        """
        for idx in range(-1, -len(times) - 1, -1):
            if times[idx] <= target:
                return idx
        return -len(times) - 1

    def _get_dataset(self, dset_uuid):
        """If we don't already have this dataset, uses the query to find it and cache it.

        Args:
            dset_uuid (str): The uuid of the selected item.

        Returns:
            The dataset with the given uuid.
        """
        if dset_uuid not in self._dset_dumps:
            self._dset_dumps[dset_uuid] = self._query_for_dataset(dset_uuid)
        return self._dset_dumps[dset_uuid]

    def _query_for_dataset(self, dset_uuid):
        """Uses the query and the dataset uuid to find and return the dataset.

        Args:
            dset_uuid (str): The uuid of the selected item.

        Returns:
            The dataset with the given uuid.
        """
        return self._query.item_with_uuid(dset_uuid)

    def _query_for_time_settings(self):
        """Uses the query to get the current XMS global time settings (only queried once)."""
        if self._time_settings is None:
            self._time_settings = XmsTimeFormatter(self._query.global_time_settings)

    def _times_from_dataset(self, dataset):
        """Gets the list of time step times from the dataset, relative to the XMS global zero time.

        Args:
            dataset (xms.data_objects.parameters.Dataset): The xms.data_objects dataset dump

        Returns:
            A list of time step times as datetime.timedelta.
        """
        if not dataset:
            return []
        offsets = [dataset.timestep_offset(ts_idx) for ts_idx in range(dataset.num_times)]
        if dataset.ref_time is None:  # offsets are already relative to the XMS global time
            return offsets
        zero_time = self._time_settings.zero_time
        return [dataset.ref_time + offset - zero_time for offset in offsets]

    def _query_for_geometry(self, geometry_uuid):
        """Uses the query and the geometry uuid to find the geometry and cache its constrained grid.

        Args:
            geometry_uuid (str): The uuid of the selected item.
        """
        if geometry_uuid in self._co_grids:
            return
        if not geometry_uuid:
            return
        mesh = self._query.item_with_uuid(geometry_uuid)
        co_grid = None
        if mesh:
            file = mesh.cogrid_file
            if file:
                co_grid = read_grid_from_file(file)
        self._co_grids[geometry_uuid] = co_grid

    def _query_for_spatial_data(self, spatial_data_uuid):
        """Uses the query and the spatial data coverage uuid to find and cache the coverage.

        Args:
            spatial_data_uuid (str): The uuid of the selected item.
        """
        if spatial_data_uuid in self._spatial_data_covs:
            return
        if not spatial_data_uuid:
            return
        cov = self._query.item_with_uuid(spatial_data_uuid, generic_coverage=True)
        if cov:
            self._spatial_data_covs[spatial_data_uuid] = cov

    def _get_arc(self):
        """Get the arc. If not found locally, query the parent coverage for it.

        Returns:
            The arc with the current arc id, or None if it was not found.
        """
        if self._arc is None:
            cov_item = self._query.parent_item()
            if cov_item:
                for arc in cov_item.arcs:
                    if arc.id == self._arc_id:
                        self._arc = arc
                        break
        return self._arc

    def _get_arc_center(self):
        """Get the arc center. If not found locally, calculate it.

        Returns:
            (tuple): The (x, y, z) point halfway along the arc, or None if the arc could not be found.
        """
        if self._arc_center is None:
            arc = self._get_arc()
            if arc is None:
                return None
            # change data_object Point to python iterables
            arc_pts = [(point.x, point.y, point.z) for point in arc.get_points(FilterLocation.LOC_ALL)]
            self._arc_center = self._polyline_midpoint(arc_pts)
        return self._arc_center

    def _polyline_midpoint(self, points):
        """Find the point halfway along a polyline.

        Args:
            points (list): The polyline's vertices as coordinate tuples.

        Returns:
            (tuple): The point at half of the polyline's length, or None if there are no points.
        """
        if not points:
            return None
        distances = [0.0] + [self.get_distance(points[idx - 1], points[idx]) for idx in range(1, len(points))]
        half_distance = math.fsum(distances) / 2.0
        cumulative = 0.0
        for idx, segment_length in enumerate(distances):
            previous_cumulative = cumulative
            cumulative += segment_length
            if cumulative == half_distance:
                return points[idx]
            if cumulative > half_distance:
                # The midpoint lies on the segment from points[idx - 1] to points[idx]; segment_length > 0 here.
                percent = (half_distance - previous_cumulative) / segment_length
                start = points[idx - 1]
                end = points[idx]
                return tuple(s + (e - s) * percent for s, e in zip(start, end))
        return None

    def add_series(self, axes, distance_header, elevation_header):
        """Adds the XY line series to the plot.

        Args:
            axes (AxesSubplot): The plot axes.
            distance_header (str): The x-axis label
            elevation_header (str): The y-axis label
        """
        plot_name = "Cross Section"
        self._add_cross_section_series(axes, plot_name, distance_header, elevation_header, True)

    def _add_cross_section_series(self, axes, name, station_header, elevation_header, show_axis_titles):
        """Adds an XY line series to the plot.

        Args:
            axes (AxesSubplot): The plot axes.
            name (str): The series name.
            station_header (str): The x-axis label
            elevation_header (str): The y-axis label
            show_axis_titles (bool): If True, axis titles displayed.
        """
        axes.clear()
        axes.set_title(name)
        axes.grid(True)
        axes.ticklabel_format(useOffset=False)

        # Add data to plot
        x_column, y_column = self.get_cross_section_curve_values()
        axes.plot(x_column, y_column, 'k', label="Cross section")

        # Axis titles
        if show_axis_titles:
            axes.set_xlabel(station_header)
            axes.set_ylabel(elevation_header)

    def get_cross_section_curve_values(self):
        """Returns a tuple of the cross section curve x and y values."""
        return self._x_plot, self._y_plot

    def get_messages(self):
        """Get the messages for the user.

        Returns:
            A list of messages to tell the user.
        """
        messages = []
        if self._populate_method == 'Flow':
            if self._non_overlapping_times:
                messages.append('Warning: The depth and velocity dataset times do not overlap.')
            if not self._geometry_uuid:
                messages.append('Select a geometry.')
            if not self._depth_uuid:
                messages.append('Select a scalar depth dataset with times.')
            if not self._velocity_uuid:
                messages.append('Select a vector velocity dataset with times.')
        elif self._populate_method == 'Hydrograph':
            if not self._spatial_data_uuid:
                messages.append('Select a spatial data coverage.')
        return messages

    def get_time_series(self):
        """Get the time series for the options.

        Returns:
            A tuple of lists of float. The first list is a list of times in seconds, the second is a list of values.
            Both lists are empty when nothing can be computed (e.g. no hydrograph at `current_hydrograph`).
        """
        if self._populate_method == 'Flow':
            return self._calculate_flow()
        if self._populate_method == 'Hydrograph':
            # A negative index (e.g. no combo box selection) must not silently wrap to the last curve.
            if not 0 <= self.current_hydrograph < len(self._hydrographs):
                return [], []
            series = self._hydrographs[self.current_hydrograph][2]
            return series[0], series[1]
        return [], []

    def _calculate_flow(self):
        """Calculate the flow through the arc's cross section for each merged time.

        For each time, the time-interpolated depth and velocity values are extracted at the snapped arc locations.
        Depth times normal velocity is then integrated along each segment with the trapezoidal rule. Partially dry
        segments are integrated over their wet length only. Segments with inactive values are skipped.

        Returns:
            A tuple of lists of float. The first list is a list of times in seconds, the second is a list of flows.
            Both lists are empty when the dataset times do not overlap.
        """
        if self._non_overlapping_times or not self._times:
            return [], []
        # Interpolate the depth and velocity dataset values to the timesteps.
        extractor = UGrid2dDataExtractor(self._co_grids[self._geometry_uuid].ugrid)
        extractor.extract_locations = self._arc_locations
        depth_dset = self._dset_dumps[self._depth_uuid]
        velocity_dset = self._dset_dumps[self._velocity_uuid]

        depth_values = self._get_depths_at_times(depth_dset)
        vel_x_values, vel_y_values = self._get_velocities_at_times(velocity_dset)
        segments = self._segment_geometry(self._arc_locations)

        flows = []
        for depth_at_time, vx_at_time, vy_at_time in zip(depth_values, vel_x_values, vel_y_values):
            extractor.set_grid_point_scalars(depth_at_time, [], depth_dset.location)
            depths = extractor.extract_data()
            extractor.set_grid_point_scalars(vx_at_time, [], velocity_dset.location)
            vx = extractor.extract_data()
            extractor.set_grid_point_scalars(vy_at_time, [], velocity_dset.location)
            vy = extractor.extract_data()
            flux = 0.0
            for idx in range(len(depths) - 1):
                segment = segments[idx]
                if segment is None:  # zero-length segment
                    continue
                # Make sure that all values are active for the calculation.
                # If something is inactive, skip the segment.
                end_values = (depths[idx], depths[idx + 1], vx[idx], vx[idx + 1], vy[idx], vy[idx + 1])
                if not all(self._is_active(value) for value in end_values):
                    continue
                seg_len, seg_nx, seg_ny = segment
                v_mag1 = seg_nx * vx[idx] + seg_ny * vy[idx]
                v_mag2 = seg_nx * vx[idx + 1] + seg_ny * vy[idx + 1]
                flux += self._segment_flux(depths[idx], depths[idx + 1], v_mag1, v_mag2, seg_len)
            flows.append(math.fabs(flux))
        # total_seconds() keeps the days component that timedelta.seconds would drop.
        return [flow_time.total_seconds() for flow_time in self._times], flows

    @staticmethod
    def _segment_geometry(locations):
        """Compute the length and unit normal of each segment between consecutive locations.

        Args:
            locations (sequence): Indexable sequence of (x, y, ...) locations.

        Returns:
            (list): One entry per segment: (length, normal x, normal y), or None for a zero-length segment.
        """
        segments = []
        for idx in range(1, len(locations)):
            dx = locations[idx][0] - locations[idx - 1][0]
            dy = locations[idx][1] - locations[idx - 1][1]
            seg_len = math.sqrt(dx**2 + dy**2)
            if seg_len == 0.0:
                segments.append(None)
            else:
                segments.append((seg_len, -dy / seg_len, dx / seg_len))
        return segments

    @staticmethod
    def _is_active(value):
        """Return True if an extracted value is usable (neither None nor NaN).

        Args:
            value (float): The extracted value.

        Returns:
            (bool): Whether the value is active.
        """
        return value is not None and not math.isnan(value)

    @staticmethod
    def _segment_flux(d1, d2, v_mag1, v_mag2, seg_len):
        """Integrate depth times normal velocity over one segment.

        Args:
            d1 (float): Depth at the segment start.
            d2 (float): Depth at the segment end.
            v_mag1 (float): Normal velocity at the segment start.
            v_mag2 (float): Normal velocity at the segment end.
            seg_len (float): Length of the segment.

        Returns:
            (float): The flux through the segment. Partially wet segments use only the wet fraction of the length;
            fully dry segments contribute 0.
        """
        if d1 >= 0.0 and d2 >= 0.0:
            return ((v_mag1 * d1 + v_mag2 * d2) * 0.5) * seg_len
        if d1 >= 0.0:
            percent_wet = d1 / (d1 - d2)
            return d1 * v_mag1 * percent_wet * seg_len * 0.5
        if d2 >= 0.0:
            percent_wet = d2 / (d2 - d1)
            return d2 * v_mag2 * percent_wet * seg_len * 0.5
        return 0.0

    def _walk_values_at_times(self, dset, dset_times, start_idx, blend):
        """Get a dataset's values at each merged time, interpolating between bracketing timesteps.

        Args:
            dset: The dataset; `dset.values[i]` holds the values of timestep i.
            dset_times (list): The dataset's sorted timestep times.
            start_idx (int): The first timestep inside the merged time range.
            blend (callable): blend(previous_values, next_values, percent) -> interpolated values.

        Returns:
            (list): One entry per time in `_times`. Exact timestep matches use the dataset's values directly.
        """
        results = []
        num_times = len(dset_times)
        next_idx = start_idx
        if start_idx > 0:
            previous_idx = start_idx - 1
            previous_values = dset.values[previous_idx]
        else:
            previous_idx = 0
            previous_values = []
        next_values = dset.values[next_idx] if next_idx < num_times else []
        for time in self._times:
            if next_idx < num_times and dset_times[next_idx] == time:
                results.append(next_values)
                # Advance the bracket so later interpolations use the correct time span.
                previous_idx, previous_values = next_idx, next_values
                next_idx += 1
                if next_idx < num_times:
                    next_values = dset.values[next_idx]
            else:
                total_time_diff = dset_times[next_idx] - dset_times[previous_idx]
                percent = (time - dset_times[previous_idx]) / total_time_diff
                results.append(blend(previous_values, next_values, percent))
        return results

    @staticmethod
    def _blend_scalars(previous_values, next_values, percent):
        """Linearly interpolate two lists of scalar values (percent=0 gives previous, 1 gives next)."""
        prev_percent = 1.0 - percent
        return [(p * prev_percent) + (d * percent) for p, d in zip(previous_values, next_values)]

    @staticmethod
    def _blend_vectors(previous_values, next_values, percent):
        """Linearly interpolate the x and y components of two lists of vectors."""
        prev_percent = 1.0 - percent
        return [
            ((p[0] * prev_percent) + (d[0] * percent), (p[1] * prev_percent) + (d[1] * percent))
            for p, d in zip(previous_values, next_values)
        ]

    def _get_depths_at_times(self, depth_dset):
        """Gets the depth values interpolated for each timestep.

        Args:
            depth_dset (xms.data_objects.parameters.Dataset): The depth dataset

        Returns:
            A list of time interpolated values.
        """
        return self._walk_values_at_times(depth_dset, self._depth_times, self._depth_start_idx, self._blend_scalars)

    def _get_velocities_at_times(self, velocity_dset):
        """Gets the velocity values interpolated for each timestep.

        Args:
            velocity_dset (xms.data_objects.parameters.Dataset): The velocity dataset

        Returns:
            A tuple of lists of time interpolated values (x components, y components).
        """
        vel_values = self._walk_values_at_times(
            velocity_dset, self._velocity_times, self._vel_start_idx, self._blend_vectors
        )
        vel_x_values = [[vel[0] for vel in values] for values in vel_values]
        vel_y_values = [[vel[1] for vel in values] for values in vel_values]
        return vel_x_values, vel_y_values

    def get_hydrograph_names(self):
        """Gets the names of the hydrographs from the spatial data coverage.

        Returns:
            (:obj:`list` of str): The names of the curves, nearest to the arc center first.
        """
        return [curve_row[0] for curve_row in self._hydrographs]

    @staticmethod
    def get_distance(p1, p2):
        """Get the distance between points.

        Args:
            p1 (iterable): The first point.
            p2 (iterable): The second point.

        Returns:
            (float): The distance between points (over as many coordinates as the shorter point has).
        """
        return math.sqrt(math.fsum((p1x - p2x)**2.0 for p1x, p2x in zip(p1, p2)))
