"""CreateWseDepthRaster class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
from pathlib import Path
from xml.sax.saxutils import escape, quoteattr

# 2. Third party modules
import numpy as np
from scipy.ndimage import maximum_filter, minimum_filter

# 3. Aquaveo modules
from xms.constraint.ugrid_activity import CellToPointActivityCalculator
from xms.core.filesystem import filesystem
from xms.gdal.rasters import RasterInput, RasterOutput
from xms.gdal.utilities import gdal_utils as gu
from xms.gdal.utilities import gdal_wrappers as gw
from xms.gdal.vectors import VectorOutput
from xms.mesher.meshing import mesh_utils

# 4. Local modules
from xms.tool.algorithms.coverage.grid_cell_to_polygon_coverage_builder import GridCellToPolygonCoverageBuilder
from xms.tool.algorithms.datasets.merge_datasets import get_output_dataset_values


class CreateWseDepthRaster():
    """Algorithm to make a WSE raster."""

    def __init__(self, elevation_raster, co_grid, wse_dataset, timestep, num_cells_extrapolate,
                 output_wse_filename, output_depth_filename, wkt, vertical_datum, vertical_units, logger):
        """Initializes the class."""
        self._elevation_raster = elevation_raster
        self._co_grid = co_grid
        self._ugrid = None
        if self._co_grid:
            self._ugrid = self._co_grid.ugrid
        self._wse_dataset = wse_dataset
        self._timestep = timestep
        self._num_cells_extrapolate = num_cells_extrapolate
        self._output_wse_filename = output_wse_filename
        self._output_depth_filename = output_depth_filename
        self._horizontal_wkt = wkt
        self._vertical_datum = vertical_datum
        self._vertical_units = vertical_units
        self._logger = logger

    def create(self):
        """Create the WSE raster and the depth raster.

        The WSE dataset values of the requested time step are interpolated onto the grid of the elevation raster with
        GDAL's ``invdistnn`` gridding algorithm, limited to the active area of the dataset, and kept only where the
        water surface is above the ground. When ``num_cells_extrapolate`` is nonzero the wet area is then grown
        into neighboring dry pixels. The depth raster holds the WSE minus the ground elevation. Both rasters use
        -99999 as their NODATA value.

        Returns:
            (bool): True if successful.

        Raises:
            RuntimeError: If the WSE dataset has no usable values to interpolate from.
        """
        no_data = -99999
        elevation_values = self._elevation_raster.get_raster_values()
        e_no_data = self._elevation_raster.nodata_value
        # Scale the ground elevations by the vertical unit multiplier, leaving NODATA pixels untouched.
        vm = gu.get_vertical_multiplier(gu.get_vert_unit_from_wkt(self._elevation_raster.wkt), self._vertical_units)
        if e_no_data is not None:
            elevation_values = np.where(np.isclose(elevation_values, e_no_data), e_no_data,
                                        elevation_values * vm)
            has_ground = elevation_values != e_no_data
        else:
            elevation_values = elevation_values * vm
            has_ground = np.ones(np.shape(elevation_values), dtype=bool)

        if self._wse_dataset.num_activity_values is not None and self._wse_dataset.num_activity_values != \
                self._wse_dataset.num_values:
            self._wse_dataset.activity_calculator = CellToPointActivityCalculator(self._ugrid)
        ts_data, ts_activity = self._wse_dataset.timestep_with_activity(self._timestep)
        # Prefer the activity returned with the time step, then the dataset's activity, then "everything active".
        if ts_activity is not None:
            temp_activity = ts_activity
        elif self._wse_dataset.activity is not None:
            temp_activity = self._wse_dataset.activity[self._timestep]
        else:
            temp_activity = [1] * self._ugrid.point_count
        null_value = self._wse_dataset.null_value
        point_ts_values, _ = get_output_dataset_values(values=ts_data, activity=temp_activity,
                                                       grid=self._ugrid,
                                                       activity_type='point activity',
                                                       null_value=null_value)
        # Interpolate WSE from the dataset:
        # NOTE(review): a null value of 0 is falsy and is therefore treated the same as "no null value" - confirm.
        pts = self._ugrid.locations
        interp_pts = []
        ds_min = self._wse_dataset.mins[self._timestep]
        ds_max = self._wse_dataset.maxs[self._timestep]
        for idx, pt in enumerate(pts):
            if ts_data is None:
                interp_pts.append((pt[0], pt[1], pt[2]))  # pragma no cover
            elif null_value and point_ts_values[idx] != null_value:
                interp_pts.append((pt[0], pt[1], point_ts_values[idx]))
            elif not null_value and ds_max >= point_ts_values[idx] >= ds_min:
                interp_pts.append((pt[0], pt[1], point_ts_values[idx]))
        if not interp_pts:
            raise RuntimeError('Could not interpolate WSE dataset')

        # Identify "active pixels" = inside the active area of the WSE dataset
        _, cell_activity = get_output_dataset_values(values=ts_data, activity=temp_activity,
                                                     grid=self._ugrid, activity_type='cell activity',
                                                     null_value=null_value)
        active_values = self._read_raster_mask_values(cell_activity, no_data)

        # Get the raster mask of the mesh boundary (every cell treated as active)
        mesh_boundary_values = self._read_raster_mask_values([1] * self._ugrid.cell_count, no_data)

        # Use gdal Grid to do the IDW interpolation, then clip the values to the active area of the dataset.
        self._logger.info('Running interpolation...')
        wse_data = self._interpolate_wse(interp_pts, no_data)
        wse_data = np.where(active_values == no_data, no_data, wse_data)

        # Compare the WSE to ground elevation from the elevation raster
        self._logger.info('Calculating WSE raster values...')
        # Use our nodata value, not the orig
        new_wse = np.where(has_ground & (wse_data > elevation_values), wse_data, no_data)

        if self._num_cells_extrapolate != 0:
            self._logger.info('Extrapolating...')
            new_wse = self._extrapolate_wse(new_wse, elevation_values, has_ground, mesh_boundary_values, no_data)

        # Save the WSE raster, then the depth raster, and return
        self._create_raster(self._output_wse_filename, new_wse, no_data)
        is_wet = (new_wse != no_data) & has_ground & (new_wse > elevation_values)
        depths = np.where(is_wet, new_wse - elevation_values, no_data)
        self._create_raster(self._output_depth_filename, depths, no_data)
        return True

    def _read_raster_mask_values(self, dataset_vals, no_data):
        """Build a temporary raster mask for the given cell values and return its pixel values.

        Args:
            dataset_vals: Cell values passed on to _convert_dataset_to_raster_mask.
            no_data: The NODATA value of the mask raster.

        Returns:
            The pixel values read from the mask raster. The temporary raster file is deleted before returning.
        """
        mask_file = self._convert_dataset_to_raster_mask(dataset_vals, no_data)
        mask_ds = RasterInput(mask_file)
        mask_values = mask_ds.get_raster_values()
        del mask_ds  # Drop the raster object before deleting its file
        os.remove(mask_file)
        return mask_values

    def _interpolate_wse(self, interp_pts, no_data):
        """Interpolate scattered WSE points onto the grid of the elevation raster with gdal's Grid.

        The points are written to a temporary CSV file that is exposed to GDAL through an OGR VRT file. The
        temporary files are removed again once the gridded values have been read.

        Args:
            interp_pts (list): The (x, y, wse) tuples to interpolate from.
            no_data: The NODATA value for pixels that receive no interpolated value.

        Returns:
            The interpolated raster values.
        """
        size_func = np.array(mesh_utils.size_function_from_edge_lengths(self._ugrid))
        out_gdal_grid = filesystem.temp_filename(suffix='.tif')
        temp_dir = os.path.dirname(out_gdal_grid)
        stem_val = Path(out_gdal_grid).stem
        out_csv_file = os.path.join(temp_dir, stem_val + '.csv')
        out_vrt_file = os.path.join(temp_dir, stem_val + '.vrt')
        try:
            # GDAL reads the VRT (XML without an encoding declaration) and the file names in it as UTF-8.
            with open(out_csv_file, 'w', encoding='utf-8') as f:
                f.write('X,Y,Z\n')
                f.writelines(f'{x},{y},{z}\n' for x, y, z in interp_pts)
            # The attribute value must be quoted and the text content escaped for the VRT to be well-formed XML.
            with open(out_vrt_file, 'w', encoding='utf-8') as f:
                f.write('<OGRVRTDataSource>\n'
                        f'  <OGRVRTLayer name={quoteattr(stem_val)}>\n'
                        f'    <SrcDataSource>{escape(out_csv_file)}</SrcDataSource>\n'
                        f'    <LayerSRS>{escape(str(self._horizontal_wkt))}</LayerSRS>\n'
                        '    <GeometryType>wkbPoint</GeometryType>\n'
                        '    <GeometryField encoding="PointFromColumns" x="X" y="Y" z="Z"/>\n'
                        '  </OGRVRTLayer>\n'
                        '</OGRVRTDataSource>\n')
            e_bounds = self._elevation_raster.get_raster_bounds()
            # Search radius is five times the mean of the size function computed from the grid's edge lengths
            radius = 5.0 * np.mean(size_func)
            algorithm = f"invdistnn:power=2.0:smoothing=0.0:radius={radius}:max_points=12:min_points=0:nodata={no_data}"
            output_srs = None
            if gu.valid_wkt(self._horizontal_wkt):
                output_srs = gu.wkt_to_sr(self._horizontal_wkt)
            gw.gdal_grid(out_vrt_file, out_gdal_grid, format='GTiff', outputType=RasterInput.GDT_Float32,
                         width=self._elevation_raster.resolution[0], height=self._elevation_raster.resolution[1],
                         outputBounds=[e_bounds[0][0], e_bounds[1][1], e_bounds[1][0], e_bounds[0][1]],
                         noData=no_data, layers=[stem_val], algorithm=algorithm, outputSRS=output_srs)
            grid_ds = RasterInput(out_gdal_grid)
            wse_data = grid_ds.get_raster_values()
            del grid_ds  # Drop the raster object before deleting its file
        finally:
            self._remove_temp_files(out_csv_file, out_vrt_file, out_gdal_grid)
        return wse_data

    def _extrapolate_wse(self, new_wse, elevation_values, has_ground, mesh_boundary_values, no_data):
        """Grow the wet area of a WSE raster into neighboring dry pixels.

        Each pass gives every NODATA pixel the average of the minimum and maximum WSE found in its eight
        neighbors. The new value is kept only inside the mesh boundary, where the ground elevation is known, and
        where it is above the ground. The loop stops after ``num_cells_extrapolate`` passes or as soon as a pass
        changes nothing.

        Args:
            new_wse: The WSE raster values, with no_data marking dry pixels.
            elevation_values: The ground elevations.
            has_ground: Boolean array that is True where the ground elevation is known.
            mesh_boundary_values: Raster mask of the mesh boundary, with no_data outside of the mesh.
            no_data: The NODATA value of the WSE raster.

        Returns:
            The extrapolated WSE raster values.
        """
        footprint = np.asarray([  # Footprint is the touching 8 cells
            [1, 1, 1],
            [1, 0, 1],
            [1, 1, 1]
        ])
        can_be_wet = (mesh_boundary_values != no_data) & has_ground
        iter_count = 0
        while iter_count < self._num_cells_extrapolate:
            iter_count += 1
            self._logger.info(f'Processing iteration {iter_count}...')
            is_dry = new_wse == no_data
            max_val = new_wse.max()
            # The minimum filter needs a stand-in for NODATA that is larger than every real value. The added 1.0
            # keeps the stand-in distinct from the data when the largest WSE is exactly 0.
            modified_no_data = abs(max_val) * 1.1 + 1.0 if max_val != no_data else no_data
            modified_wse = np.where(is_dry, modified_no_data, new_wse)
            neighbor_min_vals = np.where(is_dry,
                                         minimum_filter(modified_wse, footprint=footprint, mode='constant',
                                                        cval=modified_no_data), new_wse)
            neighbor_max_vals = np.where(is_dry,
                                         maximum_filter(new_wse, footprint=footprint, mode='constant',
                                                        cval=no_data), new_wse)
            has_wet_neighbor = (neighbor_min_vals != modified_no_data) & (neighbor_max_vals != no_data)
            neighbor_average_vals = np.where(has_wet_neighbor, (neighbor_min_vals + neighbor_max_vals) / 2.0, no_data)
            neighbor_masked = np.where(is_dry, neighbor_average_vals, new_wse)

            # Compare the WSE to ground elevation from the elevation raster
            cur_wse = np.where(can_be_wet & (neighbor_masked > elevation_values), neighbor_masked, no_data)
            converged = np.array_equal(cur_wse, new_wse)
            new_wse = cur_wse
            if converged:
                break
            # Uncomment the following lines if you want to write a raster for each iteration for debugging purposes.
            # cur_raster_name, extension = os.path.splitext(self._output_wse_filename)
            # self._create_raster(cur_raster_name + f'{iter_count}' + extension, new_wse, no_data)
        return new_wse

    @staticmethod
    def _remove_temp_files(*filenames):
        """Delete temporary files, ignoring files that are missing or cannot be removed.

        Args:
            *filenames (str): Paths of the files to delete.
        """
        for filename in filenames:
            try:
                os.remove(filename)
            except OSError:
                pass  # Best-effort cleanup: a leftover temporary file must not fail the tool

    def _convert_dataset_to_raster_mask(self, dataset_vals, no_data):
        """Burn the polygons built from the grid cell values into a raster shaped like the elevation raster.

        The polygons that GridCellToPolygonCoverageBuilder reports under key 1 (presumably the cells whose value
        is 1) are written to a temporary shapefile and rasterized with a burn value of 1. The temporary shapefile
        is removed before returning.

        Args:
            dataset_vals: The grid cell values handed to the polygon builder.
            no_data: The NODATA value given to the mask raster.

        Returns:
            (str): Path to the temporary mask raster. The caller is responsible for deleting it.
        """
        vo = VectorOutput()
        active_geom_file = filesystem.temp_filename(suffix='.shp')
        try:
            vo.initialize_file(active_geom_file, self._horizontal_wkt)
            # Use the dataset activity values to get the boundary
            poly_builder = GridCellToPolygonCoverageBuilder(co_grid=self._co_grid,
                                                            dataset_values=dataset_vals,
                                                            wkt=None, coverage_name='temp')
            ug_locs = self._ugrid.locations
            out_polys = poly_builder.find_polygons()
            if 1 in out_polys:
                for cur_poly in out_polys[1]:
                    # The first ring is the outer boundary, any following rings are holes
                    points = [[ug_locs[i][0], ug_locs[i][1], ug_locs[i][2]] for i in cur_poly[0]]
                    holes = []
                    for k in range(1, len(cur_poly)):
                        holes.append([[ug_locs[i][0], ug_locs[i][1], ug_locs[i][2]] for i in cur_poly[k]])
                    vo.write_polygon(points, holes)
            # Close the handles associated with the polygon shapefile
            vo = None
            # Initialize a new raster that is identical in size/shape/resolution as the input raster.
            raster_mask_file = filesystem.temp_filename(suffix='.tif')
            raster_template = RasterOutput(width=self._elevation_raster.resolution[0],
                                           height=self._elevation_raster.resolution[1],
                                           nodata_value=no_data, wkt=self._elevation_raster.wkt,
                                           data_type=RasterInput.GDT_Int32,
                                           geotransform=self._elevation_raster.geotransform)
            active_ds = raster_template.write_template_raster(raster_mask_file)
            # Next, use gdal RasterizeLayer to burn in the active polygon region into the raster
            gw.gdal_rasterize(active_geom_file, active_ds, burnValues=[1], bands=[1])
            active_ds = None
        finally:
            # The shapefile is only an intermediate product. Remove it together with the files that may accompany it.
            vo = None
            shapefile_base = os.path.splitext(active_geom_file)[0]
            for extension in ('.shp', '.shx', '.dbf', '.prj', '.cpg'):
                try:
                    os.remove(shapefile_base + extension)
                except OSError:
                    pass  # Best-effort cleanup: not every file exists, and a leftover file must not fail the tool
        return raster_mask_file

    def _create_raster(self, filename, raster_array, no_data):
        """Write raster values to a file that lines up with the elevation raster.

        The output takes its origin, size, and pixel dimensions from the elevation raster. Its WKT is the elevation
        raster's WKT with the vertical datum and vertical units added.

        Args:
            filename (str): Path to the output file to write
            raster_array (numpy.array): The raster data.
            no_data (float): The NODATA value to use.
        """
        template = self._elevation_raster
        output_wkt = gu.add_vertical_to_wkt(template.wkt, self._vertical_datum, self._vertical_units)
        # Gather the raster definition in one place before handing it to the writer
        raster_definition = {
            'xorigin': template.xorigin,
            'yorigin': template.yorigin,
            'width': template.resolution[0],
            'height': template.resolution[1],
            'pixel_width': template.pixel_width,
            'pixel_height': template.pixel_height,
            'nodata_value': no_data,
            'wkt': output_wkt,
        }
        writer = RasterOutput(**raster_definition)
        writer.write_raster(filename, raster_array)
