"""Module for performing SRH-2D solution post-processing."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy
import math
import os
import sys
import uuid

# 2. Third party modules
import h5py
import numpy as np

# 3. Aquaveo modules
from xms.core.filesystem import filesystem
from xms.datasets.dataset_reader import DatasetReader
from xms.datasets.dataset_writer import DatasetWriter
from xms.grid.geometry import geometry as gm

# 4. Local modules
from xms.srh.file_io.geom_reader import GeomReader

# constants for newly created datasets
# Indices into the per-dataset bookkeeping lists built in InterpolationWriter.interpolate_and_write. Each list is laid
# out as [point DatasetWriter, cell DatasetReader, timestep values, IDW timestep values].
PT_DSET = 0  # target point-based DatasetWriter
CELL_DSET = 1  # source cell-based DatasetReader
TS_VALUES = 2  # TimestepValues container for the timestep currently being processed
TS_VALUES_IDW = 3  # value most recently interpolated for the point currently being processed


class TimestepValues:
    """Container for managing scalar point dataset computations for a given timestep.

    Typical use for each target point: call ``reset``, call ``append_weighted_value`` once per contributing cell,
    call ``interpolate_weighted_neighbors`` with the sum of the weights, then store the result with ``set_value``.
    """
    def __init__(self, cell_value_data, null_value, num_pts):
        """Initialize the container.

        Args:
            cell_value_data (:obj:`numpy.ndarray[float]`): List of the source cell dataset values that will be
                interpolated to the target points.
            null_value (:obj:`float`): Null value that the datasets will be initialized to. Indicates activity - will be
                inactive if the value is never set
            num_pts (:obj:`int`): Number of target points

        """
        self.cell_value_data = cell_value_data  # source cell data for a given timestep
        # Target value for each point at a given timestep. Always allocated as floating point so that an integer
        # null value cannot produce an integer array that would truncate interpolated values on assignment.
        self.point_value_array = np.full(num_pts, null_value, dtype=float)
        self.active_sum = 0.0  # running sum of value * weight for the point currently being processed

    def reset(self):
        """Set the running weighted sum back to its initial state. Call before processing each new point."""
        self.active_sum = 0.0

    def append_weighted_value(self, weight, cell_index):
        """Add a dataset value from a source cell to the running weighted sum of the current point.

        Args:
            weight (:obj:`float`): The weight to apply to this cell index's value.
            cell_index (:obj:`int`): Index in the cell array of this cell. (Zero-based id)

        Returns:
            (:obj:`float`): The cells value (without the weight applied)
        """
        value = self.cell_value_data[cell_index]
        self.active_sum += value * weight
        return value

    def interpolate_weighted_neighbors(self, weighted_sum):
        """Interpolate the weighted values of all the cells that have been appended for the current point.

        The running sum is NOT cleared by this method; call ``reset`` before processing the next point.

        Args:
            weighted_sum (:obj:`float`): The sum of all the weights for all the values that have been appended for the
                current point. Must be non-zero.

        Returns:
            (:obj:`float`): The sum of all the weighted values that have been added for this point divided
            by weighted_sum.
        """
        return self.active_sum / weighted_sum

    def set_value(self, point_index, value):
        """Assign a point's dataset value for this timestep.

        Args:
            point_index (:obj:`int`): Index in the point array of this point. (Zero-based id)
            value (:obj:`float`): The new dataset value to assign to this point
        """
        self.point_value_array[point_index] = value

    def set_timestep_dataset(self, point_value_dset, ts_time):
        """Append the point values of this timestep to a point dataset.

        Args:
            point_value_dset: Dataset writer for the target point dataset. Its ``append_timestep`` method is called
                with the timestep time and the point value array.
            ts_time (:obj:`float`): The timestep offset time
        """
        point_value_dset.append_timestep(time=ts_time, data=self.point_value_array)


class VectorTimestepValues:
    """Container for managing 2D vector point dataset computations (and their magnitudes) for a given timestep."""
    def __init__(self, cell_value_data, null_value, num_pts):
        """Initialize the container.

        Args:
            cell_value_data (:obj:`numpy.ndarray[list[float]]`): List of the source cell dataset values
                that will be interpolated to the target points - [[Vx, Vy]].
            null_value (:obj:`float`): Null value that the datasets will be initialized to. Indicates activity - will be
                inactive if the value is never set
            num_pts (:obj:`int`): Number of target points
        """
        # weighted component values of the neighboring cells appended for the point currently being processed
        self.active_neighbor_cell_xs = []
        self.active_neighbor_cell_ys = []
        self.cell_value_data = cell_value_data  # source cell data for a given timestep
        # Target value for each point at a given timestep. Allocated with an explicit shape and float dtype so the
        # arrays are (num_pts, 2) and (num_pts,) even when there are no points or the null value is an integer.
        self.point_value_array = np.full((num_pts, 2), null_value, dtype=float)
        self.point_value_mag_array = np.full(num_pts, null_value, dtype=float)

    def reset(self):
        """Discard any weighted neighbor values that have been appended but not yet interpolated."""
        self.active_neighbor_cell_xs = []
        self.active_neighbor_cell_ys = []

    def append_weighted_value(self, weight, cell_index):
        """Append the weighted components of a source cell's vector for the current point.

        Args:
            weight (:obj:`float`): The weight to apply to this cell index's value.
            cell_index (:obj:`int`): Index in the cell array of this cell. (Zero-based id)

        Returns:
            (:obj:`list[float]`): The cells value (without the weight applied) - [Vx, Vy]
        """
        value = self.cell_value_data[cell_index]
        self.active_neighbor_cell_xs.append(value[0] * weight)
        self.active_neighbor_cell_ys.append(value[1] * weight)
        return value

    def interpolate_weighted_neighbors(self, weighted_sum):
        """Interpolate the weighted values of all the cells that have been appended for the current point.

        The active neighbor cell values are cleared after calling this method.

        Args:
            weighted_sum (:obj:`float`): The sum of all the weights for all the values that have been appended for
                the current point. Must be non-zero.
        Returns:
            (:obj:`list[float]`): The sum of all the weighted values that have been added for this point divided
                by weighted_sum - [Vx, Vy].
        """
        idw_value_x = sum(self.active_neighbor_cell_xs) / weighted_sum
        idw_value_y = sum(self.active_neighbor_cell_ys) / weighted_sum
        self.reset()
        return [idw_value_x, idw_value_y]

    def set_value(self, point_index, value):
        """Assign a point's vector value for this timestep and store its magnitude.

        Args:
            point_index (:obj:`int`): Index in the point array of this point. (Zero-based id)
            value (:obj:`list[float]`): The new dataset value to assign to this point - [Vx, Vy].
        """
        idw_value_mag = math.sqrt(value[0]**2 + value[1]**2)
        self.point_value_array[point_index] = value
        self.point_value_mag_array[point_index] = idw_value_mag

    def set_timestep_dataset(self, point_value_dset, ts_time):
        """Append the vector point values of this timestep to a point dataset.

        Args:
            point_value_dset: Dataset writer for the target vector point dataset. Its ``append_timestep`` method is
                called with the timestep time and the vector point value array.
            ts_time (:obj:`float`): The timestep offset time
        """
        point_value_dset.append_timestep(time=ts_time, data=self.point_value_array)

    def set_timestep_mag_dataset(self, point_value_mag_dset, ts_time):
        """Append the vector magnitude point values of this timestep to a point dataset.

        Args:
            point_value_mag_dset: Dataset writer for the target magnitude point dataset. Its ``append_timestep``
                method is called with the timestep time and the magnitude point value array.
            ts_time (:obj:`float`): The timestep offset time
        """
        point_value_mag_dset.append_timestep(time=ts_time, data=self.point_value_mag_array)


class InterpolationWriter:
    """Interpolates cell based datasets to nodes of the cells. The new datasets are written to an XMDF H5 file.

    Interpolation is area weighted: the value at a node is the cell-area weighted average of the values of its
    active (non-null) neighboring cells.

    ``time_dset`` must be assigned before calling ``interpolate_and_write``. ``max_time_dset`` must also be assigned
    when the solution contains ``Max_``/``time_of_Max_`` datasets.
    """

    # An inverse-distance (IDW) version of this class also took the point and cell center locations:
    # def __init__(self, a_arr_pt_points, a_arr_pt_zs, a_arr_cell_points, a_arr_cell_zs, a_pt_to_poly,
    #             a_border_flags, cell_areas):  # IDW version
    def __init__(self, num_pts, pt_zs, cell_zs, pt_to_poly, border_flags, cell_areas, has_pressure_bc, mesh_uuid):
        """Initialize the interpolation writer.

        Args:
            num_pts (:obj:`int`): Number of target mesh nodes
            pt_zs (:obj:`numpy.ndarray[float]`): List of the target mesh node z elevations
            cell_zs (:obj:`numpy.ndarray[float]`): List of the source cell center z elevations
            pt_to_poly (:obj:`dict`): Map where key = zero based point id and value = list of zero based polygon ids of
                polygons the key point belongs to
            border_flags (:obj:`numpy.ndarray[bool]`): List of flags in node id order. If the flag equals 1 (True),
                the corresponding node is on the mesh boundary.
            cell_areas (:obj:`numpy.ndarray[float]`): List of the source cell areas.
            has_pressure_bc (:obj:`bool`): True if the boundary condition coverage contains pressure zone arcs.
            mesh_uuid (:obj:`str`): UUID of the mesh in the simulation
        """
        self.datasets_grp = None
        self.mesh_uuid = mesh_uuid
        self.time_dset = None  # timestep times, indexed by timestep. Must be set before interpolating.
        self.max_time_dset = None  # times used when writing the Max_/time_of_Max_ datasets
        self.null_val = -999.0
        self.cell_zs = cell_zs
        self.point_zs = pt_zs
        self.num_pts = num_pts
        self.pt_to_poly = pt_to_poly
        self.border_flags = border_flags  # Node id order. 0=interior, 1=mesh boundary
        self.has_pressure_bc = has_pressure_bc

        # Sediment transport solution flags.
        self.has_bed_elev = False  # True if optional bed elevation solution dataset exists
        self.has_conc_t = False  # True if optional sediment concentration solution dataset exists
        self.has_d50 = False  # True if optional bed material D50 size solution dataset exists
        self.has_ero_dep = False  # True if optional erosion and deposition solution dataset exists

        self.cell_areas = cell_areas  # for area-weighted interpolation

        self.testing = False  # Hard-coded dataset UUIDs if testing

    def interpolate_and_write(self, pt_data_file, cell_dsets):
        """Interpolate a cell based dataset to the nodes of the cells and write the new dataset to an XMDF H5 file.

        self.time_dset must be set before calling this method (and self.max_time_dset if the solution contains
        Max_/time_of_Max_ datasets).

        Args:
            pt_data_file (:obj:`str`): Path to the point-based dataset file that will be written
            cell_dsets (:obj:`dict`): Map where key = dataset name and value = DatasetReader

        Raises:
            KeyError: If a cell dataset required for the interpolation is missing from cell_dsets. Velocity and
                Water_Depth are always required. Water_Elev is required without a pressure BC; Water_Elev+Pressure
                and Pressure are required with one. Raised before any output dataset is created.
        """
        # special dsets
        # used to compute the actual WSE
        pressure_index_key = ''
        depth_key = ''
        elev_key = ''
        elev_with_pressure_key = ''
        # velocities at walls reduced to zero
        vel_key = ''
        vel_mag_key = ''
        not_special_dset_keys = []
        max_dset_keys = []
        time_units = None
        for key in cell_dsets:
            if not time_units and cell_dsets[key]:  # Snag the time units if we haven't
                time_units = cell_dsets[key].time_units

            # special
            if key.startswith('Water_Depth'):
                depth_key = key
            elif key.startswith('Water_Elev+Pressure'):  # Should not exist in the end if no pressure zone exists.
                elev_with_pressure_key = key
            elif key.startswith('Water_Elev'):  # If pressure zone exists, WSE = node_z + depth.
                elev_key = key
            elif key.startswith('Vel_Mag'):
                vel_mag_key = key
            elif key.startswith('Velocity'):
                vel_key = key
            elif key.startswith('Pressure'):
                pressure_index_key = key
            elif key.startswith('Max_') or key.startswith('time_'):
                max_dset_keys.append(key)
            else:
                not_special_dset_keys.append(key)

        # Fail early, with a message naming the missing dataset, before any output dataset is created.
        required = [(vel_key, 'Velocity'), (depth_key, 'Water_Depth')]
        if self.has_pressure_bc:
            required.append((elev_with_pressure_key, 'Water_Elev+Pressure'))
            required.append((pressure_index_key, 'Pressure'))
        else:
            required.append((elev_key, 'Water_Elev'))
        for key, prefix in required:
            if key not in cell_dsets:
                raise KeyError(f"Required cell dataset starting with '{prefix}' was not found in the solution.")

        # Output datasets are created in this fixed order.
        vel_dset = self._new_point_dataset(pt_data_file, vel_key, time_units, num_components=2)
        depth_dset = self._new_point_dataset(pt_data_file, depth_key, time_units)
        vel_mag_dset = self._new_point_dataset(pt_data_file, vel_mag_key, time_units)
        elev_with_pressure_dset = None
        if self.has_pressure_bc:
            elev_with_pressure_dset = self._new_point_dataset(pt_data_file, elev_with_pressure_key, time_units)
        elev_without_pressure_dset = self._new_point_dataset(pt_data_file, elev_key, time_units)

        not_special_dsets = []  # [[point DatasetWriter, cell DatasetReader, timestep values, IDW timestep values]]
        for key in not_special_dset_keys:
            name = key
            if name.startswith('Strs'):
                name = name.replace('Strs', 'B_Stress')
            writer = self._new_point_dataset(pt_data_file, name, time_units)
            not_special_dsets.append([writer, cell_dsets[key], None, None])

        max_dsets = []  # [[point DatasetWriter, cell DatasetReader, timestep values, IDW timestep values]]
        for key in max_dset_keys:
            writer = self._new_point_dataset(pt_data_file, key, time_units)
            max_dsets.append([writer, cell_dsets[key], None, None])

        cell_vel_dset = cell_dsets[vel_key]
        cell_elev_dset = cell_dsets[elev_with_pressure_key if self.has_pressure_bc else elev_key]
        cell_depth_dset = cell_dsets[depth_key]
        cell_pressure_index_dset = cell_dsets[pressure_index_key] if self.has_pressure_bc else None

        num_times = len(self.time_dset)
        for time_idx in range(num_times):
            # The max datasets are only interpolated and written once, with the first timestep.
            process_max_dsets = time_idx == 0 and len(max_dsets) > 0
            max_dset_act_vals = None
            if process_max_dsets:
                for item in max_dsets:
                    item[TS_VALUES] = TimestepValues(item[CELL_DSET].values[time_idx], self.null_val, self.num_pts)
                # activity will be the same for all max datasets
                max_dset_act_vals = max_dsets[0][TS_VALUES].cell_value_data

            # Get the source cell data
            vel_ts_values = VectorTimestepValues(cell_vel_dset.values[time_idx], self.null_val, self.num_pts)
            depth_values_array = cell_depth_dset.values[time_idx]
            depth_ts_values = TimestepValues(depth_values_array, self.null_val, self.num_pts)
            for item in not_special_dsets:
                item[TS_VALUES] = TimestepValues(item[CELL_DSET].values[time_idx], self.null_val, self.num_pts)

            elev_with_pressure_ts_values = None
            if self.has_pressure_bc:
                elev_with_pressure_array = cell_elev_dset.values[time_idx]
                elev_with_pressure_ts_values = TimestepValues(elev_with_pressure_array, self.null_val, self.num_pts)
                elev_without_pressure_array = self._wse_without_pressure(
                    elev_with_pressure_array, cell_pressure_index_dset.values[time_idx], depth_values_array
                )
                elev_without_pressure_ts_values = TimestepValues(
                    elev_without_pressure_array, self.null_val, self.num_pts
                )
                # A cell is wet if its WSE (pressure included) is not null.
                activity_values = elev_with_pressure_ts_values.cell_value_data
            else:
                elev_without_pressure_ts_values = TimestepValues(
                    cell_elev_dset.values[time_idx], self.null_val, self.num_pts
                )
                activity_values = elev_without_pressure_ts_values.cell_value_data

            for i in range(self.num_pts):
                neighbor_cells = self.pt_to_poly[i]
                max_neighbor_cell_elev_with_pressure = -sys.float_info.max
                max_neighbor_cell_elev_without_pressure = -sys.float_info.max
                has_dry_cell = False
                has_wet_cell = False
                if self.has_pressure_bc:
                    elev_with_pressure_ts_values.reset()
                elev_without_pressure_ts_values.reset()
                depth_ts_values.reset()
                for item in not_special_dsets:
                    item[TS_VALUES].reset()

                # do interpolation for max data sets
                if process_max_dsets:
                    self._interpolate_max_datasets_at_point(i, neighbor_cells, max_dsets, max_dset_act_vals)

                weight_sum = 0.0
                for cell in neighbor_cells:
                    if activity_values[cell] == self.null_val:
                        has_dry_cell = True
                        continue

                    # One of the node's neighboring polygons is wet.
                    has_wet_cell = True
                    # Area weighted. Other weightings that were tried made little difference in the case examined:
                    #   IDW:     1.0 / sqrt(dx**2 + dy**2) between the cell center and the node
                    #   Average: 1.0
                    weight = self.cell_areas[cell]
                    weight_sum += weight

                    if self.has_pressure_bc:
                        neighbor_cell_elev_with_pressure = elev_with_pressure_ts_values.append_weighted_value(
                            weight, cell
                        )
                        max_neighbor_cell_elev_with_pressure = max(
                            neighbor_cell_elev_with_pressure, max_neighbor_cell_elev_with_pressure
                        )
                    neighbor_cell_elev_without_pressure = elev_without_pressure_ts_values.append_weighted_value(
                        weight, cell
                    )
                    max_neighbor_cell_elev_without_pressure = max(
                        neighbor_cell_elev_without_pressure, max_neighbor_cell_elev_without_pressure
                    )
                    depth_ts_values.append_weighted_value(weight, cell)
                    vel_ts_values.append_weighted_value(weight, cell)
                    for item in not_special_dsets:
                        item[TS_VALUES].append_weighted_value(weight, cell)

                # Interpolate from cell centers to all nodes that belong to at least one wet cell. Nodes without a
                # wet neighbor keep the null value.
                if not has_wet_cell:
                    continue

                # compute weighted values
                idw_depth = depth_ts_values.interpolate_weighted_neighbors(weight_sum)
                idw_vel = vel_ts_values.interpolate_weighted_neighbors(weight_sum)
                if self.has_pressure_bc:
                    idw_elev_with_pressure = elev_with_pressure_ts_values.interpolate_weighted_neighbors(weight_sum)
                idw_elev_without_pressure = elev_without_pressure_ts_values.interpolate_weighted_neighbors(weight_sum)

                # Nothing adjusts these, so they can be set immediately.
                for item in not_special_dsets:
                    item[TS_VALUES_IDW] = item[TS_VALUES].interpolate_weighted_neighbors(weight_sum)
                    item[TS_VALUES].set_value(i, item[TS_VALUES_IDW])

                # Adjust WSE, depth, velocity for edge cases:
                #     1) Adjust depth and velocity for points on wet/dry or domain boundaries. Adjust depth so
                #        depth = idw_elev - point_z. Scale velocities on wet/dry boundaries by
                #        (idw_elev - point_z) / idw_depth. Set velocities on domain boundaries to 0.0.
                #     2) If going over a lip, set the WSE to be either the node Z plus the idw_depth OR the
                #        maximum of the neighboring cell's elevations (don't let the node go higher than neighbors).
                #     3) Ensure that depth always equals WSE - node Z in all cases.

                # Sediment transport is not accounted for yet (waiting until we have confidence in it). When it is,
                # the effective bed elevation would be the node Z adjusted by the interpolated erosion.
                bed_effective = self.point_zs[i]
                alt_depth = idw_elev_without_pressure - bed_effective
                if self.has_pressure_bc:
                    elev_with_pressure_reset = idw_elev_with_pressure
                elev_without_pressure_reset = idw_elev_without_pressure
                vel_reset = idw_vel
                is_border_point = self.border_flags[i] == 1
                # ATC 10-15-2018: We used to only adjust the depth and velocity of points on the edge of the mesh
                # domain if the interpolated WSE for that point is less than its Z. We now adjust the depth of
                # all points on the mesh domain boundary and set their velocities to 0.0.
                if has_dry_cell or is_border_point:
                    # This is a point on a wet dry boundary OR This is a mesh border point
                    # Recompute the depth to be idw_elev_without_pressure - point_z
                    # We were ensuring a positive depth, but this leads to inconsistency between depth + Z and
                    # WSE. It has been decided that the depth should go negative in this situation. (Mantis 11473).
                    depth_reset = alt_depth
                    if is_border_point:
                        # Set velocities to 0.0 on the boundary of the mesh domain.
                        vel_reset = [0.0, 0.0]
                    else:
                        # Scale velocity at the wet/dry boundary edge to be alt_depth/idw_depth but clamped to be
                        # between 0.0 and 1.0
                        scale_factor = max(min(1.0, depth_reset / idw_depth), 0.0)
                        vel_reset = [idw_vel[0] * scale_factor, idw_vel[1] * scale_factor]
                elif alt_depth < idw_depth * 0.5:
                    # This point is on the lip of a dam or levy
                    elev_without_pressure_reset = min(
                        max_neighbor_cell_elev_without_pressure, bed_effective + idw_depth
                    )
                    # If this is a lip in a pressure zone, we think we would need to add the pressure back in here
                    # to the elev_with_pressure_reset. However, we do not think that it is possible to be in a
                    # pressure zone while going over a lip, so we didn't bother.
                    if self.has_pressure_bc:
                        elev_with_pressure_reset = min(
                            max_neighbor_cell_elev_with_pressure, bed_effective + idw_depth
                        )
                    depth_reset = elev_without_pressure_reset - bed_effective
                else:  # Not an edge case, just ensure depth + node Z = WSE
                    depth_reset = elev_without_pressure_reset - bed_effective

                depth_ts_values.set_value(i, depth_reset)
                vel_ts_values.set_value(i, vel_reset)
                if self.has_pressure_bc:
                    elev_with_pressure_ts_values.set_value(i, elev_with_pressure_reset)
                elev_without_pressure_ts_values.set_value(i, elev_without_pressure_reset)

            timestep_values = [
                (vel_ts_values, vel_dset),
                (depth_ts_values, depth_dset),
                (elev_without_pressure_ts_values, elev_without_pressure_dset),
            ]
            if self.has_pressure_bc:
                timestep_values.append((elev_with_pressure_ts_values, elev_with_pressure_dset))
            for item in not_special_dsets:
                timestep_values.append((item[TS_VALUES], item[PT_DSET]))
            if time_idx == 0:
                for item in max_dsets:
                    timestep_values.append((item[TS_VALUES], item[PT_DSET]))

            ts_time = self.time_dset[time_idx]
            for dset_timestep, point_dset in timestep_values:
                if point_dset.name.startswith(('Max_', 'time_of_Max_')):
                    dset_timestep.set_timestep_dataset(point_dset, self.max_time_dset[time_idx])
                else:
                    dset_timestep.set_timestep_dataset(point_dset, ts_time)
            # Write the vector magnitude dataset
            vel_ts_values.set_timestep_mag_dataset(vel_mag_dset, ts_time)
            print('Processed timestep: {} out of {}'.format(time_idx + 1, num_times))
            sys.stdout.flush()

        # We used to delete the "Water_Elev+Pressure" dataset from the file if none of the timesteps had pressure.
        # That should not be needed anymore because there is a new output dataset (the pressure index) to tell us
        # if pressure really exists in the solution.

        # Flush data to disk and close file handles.
        vel_dset.appending_finished()
        depth_dset.appending_finished()
        elev_without_pressure_dset.appending_finished()
        vel_mag_dset.appending_finished()
        if self.has_pressure_bc:
            elev_with_pressure_dset.appending_finished()
        for non_special in not_special_dsets:
            non_special[PT_DSET].appending_finished()
        for max_dset in max_dsets:
            max_dset[PT_DSET].appending_finished()

    def _new_point_dataset(self, pt_data_file, name, time_units, **writer_kwargs):
        """Create the writer for a new point-based dataset on the mesh.

        Args:
            pt_data_file (:obj:`str`): Path to the point-based dataset file that will be written
            name (:obj:`str`): Name of the new dataset
            time_units: Time units passed through to the dataset writer
            **writer_kwargs: Additional keyword arguments for the dataset writer (e.g. num_components)

        Returns:
            (:obj:`DatasetWriter`): The new dataset writer
        """
        return DatasetWriter(
            h5_filename=pt_data_file,
            name=name,
            geom_uuid=self.mesh_uuid,
            time_units=time_units,
            null_value=self.null_val,
            overwrite=False,
            dset_uuid=self._uuid_for_dataset(),
            **writer_kwargs
        )

    def _wse_without_pressure(self, elev_with_pressure_array, pressure_index_array, depth_values_array):
        """Build the cell water surface elevations with the pressure removed for one timestep.

        Args:
            elev_with_pressure_array: Cell water surface elevations that include pressure. Not modified.
            pressure_index_array: Cell pressure index values. A value of 1.0 marks a cell that is pressurized.
            depth_values_array: Cell water depths

        Returns:
            A copy of elev_with_pressure_array where the value of every wet, pressurized cell has been replaced
            with cell Z + cell depth.
        """
        elev_without_pressure_array = copy.deepcopy(elev_with_pressure_array)
        for cell_idx, cell_wse in enumerate(elev_with_pressure_array):
            if cell_wse == self.null_val:
                continue  # Cell is dry, don't worry about it.
            # If sediment transport is ever supported, the transient cell bed elevation should be used in place of
            # the cell Z here.
            if pressure_index_array[cell_idx] == 1.0:
                elev_without_pressure_array[cell_idx] = self.cell_zs[cell_idx] + depth_values_array[cell_idx]
        return elev_without_pressure_array

    def _interpolate_max_datasets_at_point(self, point_index, neighbor_cells, max_dsets, max_dset_act_vals):
        """Area-weight interpolate every max dataset to a single point and store the result.

        Args:
            point_index (:obj:`int`): Zero-based index of the target point
            neighbor_cells (:obj:`list[int]`): Zero-based ids of the cells the point belongs to
            max_dsets (:obj:`list`): Bookkeeping lists of the max datasets. TS_VALUES must already be set.
            max_dset_act_vals: Cell values used to determine activity (null value = inactive) for all max datasets
        """
        for item in max_dsets:
            item[TS_VALUES].reset()

        area_sum = 0.0
        for cell in neighbor_cells:
            if max_dset_act_vals[cell] != self.null_val:
                cell_area = self.cell_areas[cell]
                area_sum += cell_area
                for item in max_dsets:
                    item[TS_VALUES].append_weighted_value(cell_area, cell)

        for item in max_dsets:
            if area_sum == 0.0:  # No active neighbors
                item[TS_VALUES].set_value(point_index, self.null_val)
                continue
            item[TS_VALUES_IDW] = item[TS_VALUES].interpolate_weighted_neighbors(area_sum)
            item[TS_VALUES].set_value(point_index, item[TS_VALUES_IDW])

    def _uuid_for_dataset(self):
        """Get a random UUID for a new dataset or a hard-coded one if testing.

        Returns:
            str: See description
        """
        if self.testing:
            return '11111111-1111-1111-1111-111111111111'
        return str(uuid.uuid4())  # pragma no cover


def get_polygon_centroid(points):
    """Compute the average vertex location and the 2D area of a 3D polygon.

    The "centroid" reported here is the arithmetic mean of the vertex coordinates (x, y and z), not the
    area-weighted centroid. The area is computed from the x and y coordinates only.

    Args:
        points: (:obj:`list[tuple(float)]`): Polygon node locations, e.g. [(x, y, z)]

    Returns:
        (:obj:`tuple`): Two items. The first is the x, y, z vertex average as a list of floats, e.g.
        [0.0, 1.0, 2.0]. The second is the 2-D area of the polygon. An empty polygon yields
        ([0.0, 0.0, 0.0], 0.0) so callers can always unpack the result the same way.
    """
    # len() rather than truthiness so that numpy arrays of points are accepted as well.
    num_points = len(points)
    if num_points < 1:  # pragma: no cover
        return [0.0, 0.0, 0.0], 0.0

    # Accumulate the coordinate sums and collect the x/y coordinates for the area in a single pass.
    sum_x = 0.0
    sum_y = 0.0
    sum_z = 0.0
    x = []
    y = []
    for pt in points:
        sum_x += pt[0]
        sum_y += pt[1]
        sum_z += pt[2]
        x.append(pt[0])
        y.append(pt[1])
    centroid = [sum_x / num_points, sum_y / num_points, sum_z / num_points]

    return centroid, get_polygon_area(x, y)


def get_polygon_area(x_coords, y_coords):
    """Compute the 2-D area enclosed by a polygon with the Shoelace formula.

    See accepted answer: https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates

    Args:
        x_coords (:obj:`list[float]`): X coordinates of the polygon points, ordered either clockwise or
            counterclockwise.
        y_coords (:obj:`list[float]`): Y coordinates of the polygon points. Parallel with x_coords.

    Returns:
        float: 2-D area of the polygon (never negative, regardless of the point ordering)
    """
    # Pair every vertex with its predecessor by rotating the coordinate arrays one position.
    prev_y = np.roll(y_coords, 1)
    prev_x = np.roll(x_coords, 1)
    # Twice the signed area. The sign depends on the winding direction, so take the magnitude.
    twice_signed_area = np.dot(x_coords, prev_y) - np.dot(y_coords, prev_x)
    return 0.5 * np.abs(twice_signed_area)


def post_process(data_dict):
    """Top-level function that starts the post processing of an SRH-2D solution file.

    Args:
        data_dict: Dict containing info from the srh_post.json file. The following keys are read:

            OUT_FILE (:obj:`str`): Path and filename of the cell based solution XMDF file. The node-based output
            filename is derived from it by replacing '_XMDFC.h5' with '_XMDF.h5'.

            GEOM_FILE (:obj:`str`): Filename of the *.srhgeom input file of the simulation run. Only the base name is
            used; the file is always assumed to be in the current working directory.

            ARC_IDS (:obj:`list[int]`): Ids of the non-wall boundary condition arcs. SRH-2D ids, not XMS.

            GRID_UUID (:obj:`str`): UUID of the mesh in the simulation

            testing (:obj:`bool`): Optional. Passed on to the writer's ``testing`` flag. Defaults to False.

    Returns:
        (:obj:`str`): Filepath to the node-based XMDF solution dataset file
    """
    xmdfc_filename = data_dict['OUT_FILE']
    srhgeom_filename = os.path.basename(data_dict['GEOM_FILE'])  # Always assumed to be in current working directory.
    bc_arc_ids = data_dict['ARC_IDS']
    # Whether the solution has pressure is no longer passed in by the caller. As of version 3.3.0, SRH writes a
    # dataset that tells which cells are under pressure, and that dataset is checked instead.
    mesh_uuid = data_dict['GRID_UUID']
    update_xmdfc_file(xmdfc_filename, mesh_uuid)

    geom_reader = GeomReader()
    geom_reader.read(srhgeom_filename)
    bc_arc_nodes = geom_reader.data['node_strings']
    ugrid = geom_reader.cogrid.ugrid

    # Keep the xmdfc.h5 file open for the whole interpolation, but guarantee that the handle is released even if
    # something below raises.
    with h5py.File(xmdfc_filename, "r") as cell_data_file:
        cell_dsets, times, max_times, has_pressure_bc = _read_cell_dataset_readers(xmdfc_filename, cell_data_file)
        num_pts, arr_pt_zs, arr_cell_zs, arr_cell_areas, pt_to_poly, all_cell_pts = _collect_mesh_geometry(ugrid)

        # find the border
        border_points = find_mesh_boundary(num_pts, all_cell_pts, pt_to_poly)
        # Don't consider nodes on boundary condition nodestrings to be on the border. This would result in
        # velocities being scaled to zero in these locations. We do not want to make this adjustment where water is
        # entering or exiting the domain.
        for arc_id, node_string in bc_arc_nodes.items():
            if len(node_string) >= 2 and arc_id in bc_arc_ids:
                for node_id in node_string:
                    border_points[node_id] = 0

        xmdf_filename = xmdfc_filename.replace('_XMDFC.h5', '_XMDF.h5')
        filesystem.removefile(xmdf_filename)  # Clean out old one if it exists
        writer = InterpolationWriter(
            num_pts, arr_pt_zs, arr_cell_zs, pt_to_poly, border_points, arr_cell_areas, has_pressure_bc, mesh_uuid
        )
        writer.testing = data_dict.get('testing', False)
        writer.time_dset = times
        writer.max_time_dset = max_times
        writer.interpolate_and_write(xmdf_filename, cell_dsets)
    return xmdf_filename


def _read_cell_dataset_readers(xmdfc_filename, cell_data_file):
    """Create a DatasetReader for every cell dataset listed in the open cell-based solution file.

    Args:
        xmdfc_filename (:obj:`str`): Path of the cell based solution XMDF file. Passed to every DatasetReader.
        cell_data_file (:obj:`h5py.File`): Open handle on the same file. Used to list the dataset groups and to read
            the Pressure_Index maximums.

    Returns:
        (:obj:`tuple`): Four items - the {group_name: DatasetReader} dict, the times of the first dataset found,
        a one-item array holding the last of those times, and a bool that is True if the solution has pressure.
    """
    # The datasets are listed from the 'Datasets' group if there is one, otherwise from the file root.
    datasets_grp = cell_data_file['Datasets'] if 'Datasets' in cell_data_file else cell_data_file
    dataset_groups = datasets_grp.keys()

    # first we want to see if we have a Pressure_Index dataset. This will tell us if we have pressure.
    has_pressure_bc = False
    if 'Pressure_Index' in dataset_groups:
        # Read from the same group that was listed so the lookup cannot point at a different location.
        has_pressure_bc = bool(max(datasets_grp['Pressure_Index']['Maxs']) > 0)

    cell_dsets = {}  # {group_name: DatasetReader}
    times = np.array([])
    max_times = np.array([])
    for key in dataset_groups:
        if key in ('File Type', 'File Version', 'Guid'):  # Not a dataset
            continue
        reader = DatasetReader(xmdfc_filename, dset_name=key)
        if key.startswith('Water_Elev') and has_pressure_bc:
            # The source data is registered under the "+Pressure" name. The plain name is kept as a placeholder
            # entry with no reader of its own.
            cell_dsets[key] = None
            cell_dsets[key.replace('Water_Elev', 'Water_Elev+Pressure')] = reader
        else:
            cell_dsets[key] = reader
        if len(times) == 0:
            # Use the reader itself: cell_dsets[key] is None for the water elevation placeholder entry.
            times = reader.times[:]
            max_times = np.array([times[-1]])
    return cell_dsets, times, max_times, has_pressure_bc


def _collect_mesh_geometry(ugrid):
    """Gather the point and cell information of the mesh that the interpolation needs.

    Args:
        ugrid: The unstructured grid read from the *.srhgeom file.

    Returns:
        (:obj:`tuple`): Six items - the number of points, array of point z values, array of cell center z values
        (average z of each cell's points), array of 2-D cell areas, dict of point index to adjacent cell indices,
        and dict of cell index to the cell's point indices.
    """
    locs = ugrid.locations
    num_pts = len(locs)
    pt_zs = [loc[2] for loc in locs]
    pt_to_poly = {idx: ugrid.get_point_adjacent_cells(idx) for idx in range(num_pts)}

    cell_center_zs = []
    cell_areas = []
    all_cell_pts = {}
    for cell_idx in range(ugrid.cell_count):
        cell_pt_ids = ugrid.get_cell_points(cell_idx)
        pts = [locs[p] for p in cell_pt_ids]
        z_sum = 0.0
        for p in pts:
            z_sum += p[2]
        cell_center_zs.append(z_sum / len(pts))
        cell_areas.append(gm.polygon_area_2d(pts))
        all_cell_pts[cell_idx] = cell_pt_ids

    # Copy data to numpy arrays before interpolating. We noticed speed and memory performance improvements when
    # using numpy arrays compared to native Python lists.
    return num_pts, np.array(pt_zs), np.array(cell_center_zs), np.array(cell_areas), pt_to_poly, all_cell_pts


def find_mesh_boundary(num_pts, all_cell_pts, pt_to_poly):
    """Flag the nodes that lie on the boundary of a mesh.

    Args:
        num_pts (:obj:`int`): The number of nodes in the mesh
        all_cell_pts (:obj:`dict`): Dict of list of ints where key is the polygon id. The value is a list of the
            zero-based point ids of the polygon's points in either clockwise or counterclockwise order.
        pt_to_poly(:obj:`dict`): Dictionary where key = zero-based id of a node and value = a list of zero-based
            polygon ids that the point belongs to.

    Returns:
        (:obj:`numpy.ndarray`): Point boundary flag array. Boundaries = 1, interiors = 0. In point id order.
    """
    flags = np.array([0] * num_pts)
    # An edge is on the mesh boundary when only one polygon uses it. For every node, count how many of its
    # polygons reach each neighboring node; a neighbor reached exactly once is across a boundary edge.
    for node in range(num_pts):
        adjacent_polys = pt_to_poly[node]
        if len(adjacent_polys) == 1:
            # Node only belongs to one polygon. It is a corner.
            flags[node] = 1
            continue
        if flags[node] == 1:
            # Already flagged while processing one of its neighbors.
            continue

        reached_once = {}  # {neighbor node: True while it has been reached exactly once}
        for poly in adjacent_polys:
            ring = all_cell_pts[poly]
            ring_size = len(ring)
            for position, ring_pt in enumerate(ring):
                if ring_pt != node:
                    continue
                # Previous and next nodes around the polygon, wrapping at both ends.
                for neighbor in (ring[position - 1], ring[(position + 1) % ring_size]):
                    reached_once[neighbor] = neighbor not in reached_once
                break

        for neighbor, on_boundary_edge in reached_once.items():
            if on_boundary_edge:
                flags[neighbor] = 1
                flags[node] = 1
    return flags


def update_xmdfc_file(fname, geom_uuid):
    """Reformat file for XMS reading.

    Every dataset at the root of the file is moved into a new 'Datasets' group and tagged with a 'DatasetLocation'
    attribute. Dataset values are overwritten with the null value (-999.0) wherever the water elevation dataset
    ('Water_Elev_ft' or 'Water_Elev_m') is null. Nothing is done if the file does not exist or if it already
    contains a 'Datasets' group.

    Args:
        fname (:obj:`str`): h5 filename
        geom_uuid (:obj:`str`): string for geometry UUID. No 'Guid' dataset is written if it is empty.
    """
    if not os.path.isfile(fname):
        return

    null_value = -999.0
    with h5py.File(fname, 'r+') as solution_file:
        if 'Datasets' in solution_file:
            return  # Already reformatted

        grp_datasets = solution_file.create_group('Datasets')
        group_type = ['MULTI DATASETS'.encode("ascii", "ignore")]
        grp_datasets.attrs.create('Grouptype', data=group_type, shape=(1, ), dtype='S15')

        if geom_uuid:
            guid = [geom_uuid.encode("ascii", "ignore")]
            grp_datasets.create_dataset('Guid', shape=(1, ), dtype='S37', data=guid)

        # get the activity from the water elevation data sets
        we_str = next((name for name in ('Water_Elev_ft', 'Water_Elev_m') if name in solution_file), '')

        we_values = None  # 'Values' of the water elevation dataset, if the file has one
        never_active = None  # mask of locations that are null in every water elevation timestep
        if we_str:
            we_path = f'Datasets/{we_str}'
            solution_file.move(we_str, we_path)
            we_ds = solution_file[we_path]
            we_ds.attrs['DatasetLocation'] = [0]
            we_values = we_ds['Values']
            if f'Max_{we_str}' in solution_file:  # get activity of all timesteps to be used with max datasets
                ever_active = we_values[0] != null_value
                for i in range(1, len(we_ds['Times'])):
                    ever_active |= we_values[i] != null_value
                never_active = ~ever_active

        skip_items = {'file version', 'file type', 'datasets', we_str.lower()}
        # Take a snapshot of the names first because the loop moves datasets out of the group being listed.
        for ds in list(solution_file):
            if ds.lower() in skip_items:
                continue

            solution_file[ds].attrs['DatasetLocation'] = [0]
            new_name = 'B_Stress_lb_p_ft2' if ds == 'B_Stress_lb_p_ft' else ds
            new_path = f'Datasets/{new_name}'
            solution_file.move(ds, new_path)
            moved = solution_file[new_path]

            # update the data set values with the null value from water elevation
            if ds.lower().startswith(('max_', 'time_')):
                # These are nulled only where the water elevation is null in every timestep. Without that
                # information they are left untouched.
                if never_active is not None:
                    ds_val = moved['Values'][0]
                    ds_val[never_active] = null_value
                    moved['Values'][0] = ds_val
                continue

            if we_values is None:
                # No water elevation dataset in the file, so there is no activity to apply.
                continue

            num_times = len(moved['Times'])
            for i in range(num_times):
                # A dataset with a single timestep is matched with the last water elevation timestep.
                we = we_values[-1 if num_times == 1 else i]
                ds_val = moved['Values'][i]
                ds_val[we == null_value] = null_value
                moved['Values'][i] = ds_val
