"""RasterFromDatasetTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import shutil

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint.ugrid_activity import CellToPointActivityCalculator
from xms.core.filesystem import filesystem
from xms.extractor.ugrid_2d_data_extractor import UGrid2dDataExtractor
from xms.gdal.rasters import raster_utils as ru, RasterReproject
from xms.gdal.rasters import RasterInput, RasterOutput
from xms.gdal.utilities import gdal_utils as gu
from xms.gdal.utilities import gdal_wrappers as gw
from xms.gdal.utilities import GdalRunner
from xms.gdal.vectors import VectorOutput
from xms.interp.interpolate import interp_util
from xms.interp.interpolate import InterpIdw
from xms.tool_core import ALLOW_ONLY_POINT_MAPPED, ALLOW_ONLY_SCALARS, Argument, IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.coverage.grid_cell_to_polygon_coverage_builder import GridCellToPolygonCoverageBuilder

# Positions of the tool's arguments in the list built by initial_arguments(); keep this order in sync with it.
(
    ARG_INPUT_DATASET,
    ARG_INPUT_TIMESTEP,
    ARG_INPUT_TEMPLATE_RASTER,
    ARG_TEMPLATE_ACTIVITY,
    ARG_TEMPLATE_OPTION,
    ARG_TEMPLATE_RESOLUTION,
    ARG_PIXEL_SIZE,
    ARG_NULL_VALUE,
    ARG_EXTRAPOLATION_WIDTH,
    ARG_OUTPUT_RASTER,
) = range(10)


class RasterFromDatasetTool(Tool):
    """Tool to convert dataset values to a raster.

    The raster extents and size come either from the dataset's UGrid extents plus a pixel size, or from an optional
    template raster. Dataset values are interpolated to the raster cell centers, optionally extrapolated outward
    from the UGrid by a given distance, and written to the output raster.
    """
    # Progress is logged every N rows, where N is a multiple of this value and never smaller than it.
    DEFAULT_OUTPUT_MULTIPLE = 50
    # N is chosen so that roughly this many progress messages are logged over all raster rows.
    DEFAULT_NUM_DIVISIONS = 10

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Raster from Dataset')
        self._ugrid = None
        self._width = None  # Raster width in cells
        self._height = None  # Raster height in cells
        self._wkt = None
        self._points = None
        self._ts_data = None
        self._min_pt = None
        self._max_pt = None
        self._size_delta = None  # Cell size (dx, dy)
        self._extractor = None
        self._row_locs = None
        self._raster_data = None  # Flattened raster values, bottom row first
        self._no_data = None
        self._extrapolate = None
        self._raster_activity = None
        # Read by _interp_to_raster(); generate_raster() resets them, but initialize them here so they always exist.
        self._specify_wse = False
        self._elev_values = None
        self._wse_values = None
        self._elev_nodata = -999999.0
        self._extractor_locations = []
        self._extractor_indices = []
        self._output_frequency = self.DEFAULT_OUTPUT_MULTIPLE

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override. The order of this list must match the module-level ARG_* indices.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.dataset_argument(name='dataset', description='Dataset',
                                  filters=[ALLOW_ONLY_SCALARS, ALLOW_ONLY_POINT_MAPPED]),
            self.timestep_argument(name='time_step', description='Timestep', value=Argument.NONE_SELECTED),
            self.raster_argument(name='template_raster', description='Optional template raster', optional=True),
            self.bool_argument(name='template_activity', description='Use raster activity', value=False,
                               optional=True),
            self.string_argument(name='template_option', description='Raster template resolution option',
                                 value='Specify maximum XY resolution',
                                 choices=['Use template raster native resolution',
                                          'Specify maximum XY resolution'], optional=True),
            self.integer_argument(name='template_resolution', description='Maximum XY resolution of template raster',
                                  value=1920, min_value=1, optional=True),
            self.float_argument(name='pixel_size', description='Pixel size', value=0.0, min_value=0.0),
            self.float_argument(name='null_value', description='Null value', value=-999.0),
            self.float_argument(name='extrapolation_width', description='Extrapolation width', value=0.0,
                                min_value=0.0),
            self.raster_argument(name='output_raster', description='Output raster', io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        The template options are shown only when a template raster is selected, and the pixel size only when one is
        not. The template resolution is defaulted from the template raster's native resolution.

        Args:
            arguments(list): The tool arguments.
        """
        if arguments[ARG_INPUT_DATASET].value:
            arguments[ARG_INPUT_TIMESTEP].enable_timestep(arguments[ARG_INPUT_DATASET])
        else:
            arguments[ARG_INPUT_TIMESTEP].hide = True

        use_templates = bool(arguments[ARG_INPUT_TEMPLATE_RASTER].value)
        arguments[ARG_TEMPLATE_ACTIVITY].hide = not use_templates
        arguments[ARG_PIXEL_SIZE].hide = use_templates
        arguments[ARG_TEMPLATE_OPTION].hide = not use_templates
        arguments[ARG_TEMPLATE_RESOLUTION].hide = not use_templates
        if not use_templates:
            return

        template_raster = self.get_input_raster(arguments[ARG_INPUT_TEMPLATE_RASTER].text_value)
        use_native = arguments[ARG_TEMPLATE_OPTION].text_value == 'Use template raster native resolution'
        # With the native resolution the template dictates the size, so there is nothing for the user to enter.
        arguments[ARG_TEMPLATE_RESOLUTION].hide = use_native
        if template_raster is not None:
            native_max = max(template_raster.resolution)
            # Cap the default at 1920 unless the native resolution was requested.
            arguments[ARG_TEMPLATE_RESOLUTION].value = native_max if use_native else min(native_max, 1920)

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}
        arguments[ARG_INPUT_TIMESTEP].validate_timestep(arguments[ARG_INPUT_DATASET], errors)

        # The raster is written in the display projection (default_wkt), so a coordinate system is required.
        if self.default_wkt is None:
            errors[arguments[ARG_OUTPUT_RASTER].name] = 'Must specify an output coordinate system.'

        # Make sure pixel size is not zero if the template raster option is not chosen
        if arguments[ARG_PIXEL_SIZE].value == 0.0 and not arguments[ARG_INPUT_TEMPLATE_RASTER].value:
            errors[arguments[ARG_PIXEL_SIZE].name] = 'Pixel size must be greater than zero.'

        return errors

    def run(self, arguments):
        """Override to run the tool.

        Args:
            arguments (list): The tool arguments.
        """
        cogrid = self.get_input_dataset_grid(arguments[ARG_INPUT_DATASET].text_value)
        if cogrid is None:
            self.fail("Unable to access the grid. Make sure the grid has been renumbered.")
        ugrid = cogrid.ugrid
        if ugrid.cell_count == 0:
            self.fail("The dataset's UGrid must have cells to be interpolated to a raster.")
        if ugrid.dimension_counts[3] > 0:
            self.fail("The dataset's UGrid must only have 2D cells to be interpolated to a raster.")
        dataset = self.get_input_dataset(arguments[ARG_INPUT_DATASET].text_value)
        timestep = arguments[ARG_INPUT_TIMESTEP].get_timestep(arguments[ARG_INPUT_DATASET])
        if dataset.num_activity_values is not None and dataset.num_activity_values != dataset.num_values:
            # The activity array does not line up with the values (presumably cell activity on point data), so
            # derive the point activity from the cells.
            dataset.activity_calculator = CellToPointActivityCalculator(ugrid)
        ts_data, _ = dataset.timestep_with_activity(timestep)
        tv = arguments[ARG_INPUT_TEMPLATE_RASTER].text_value
        template_raster = self.get_input_raster(tv) if arguments[ARG_INPUT_TEMPLATE_RASTER].value else None
        template_activity = arguments[ARG_TEMPLATE_ACTIVITY].value
        template_option = arguments[ARG_TEMPLATE_OPTION].text_value
        template_resolution = arguments[ARG_TEMPLATE_RESOLUTION].value
        pixel_size = arguments[ARG_PIXEL_SIZE].value
        null_value = arguments[ARG_NULL_VALUE].value
        extrapolation_width = arguments[ARG_EXTRAPOLATION_WIDTH].value

        if template_raster is not None:
            # Get the extents and size from the template raster
            template_raster = self._project_template_raster(template_raster)
            min_pt, max_pt, pixels = self._template_raster_extents(template_raster, template_option,
                                                                   template_resolution)
        else:
            # Get the extents and size from the ugrid extents and pixel size
            min_pt, max_pt, pixels = self._grid_raster_extents(ugrid, pixel_size, extrapolation_width)

        out_path = self.get_output_raster(arguments[ARG_OUTPUT_RASTER].value)
        self.generate_raster(out_path, cogrid, ugrid, min_pt, max_pt, pixels[0], pixels[1], self.default_wkt,
                             ts_data=ts_data, no_data=null_value, extrapolation_width=extrapolation_width,
                             template_raster=template_raster, template_activity=template_activity,
                             template_option=template_option)
        # Reproject into the final output using the display projection and vertical datum/units.
        out_path = ru.reproject_raster(out_path, self.get_output_raster(arguments[ARG_OUTPUT_RASTER].value),
                                       self.default_wkt, self.vertical_datum, self.vertical_units)
        self.set_output_raster_file(out_path, arguments[ARG_OUTPUT_RASTER].value)

    @staticmethod
    def _pixel_count(extent, pixel_size):
        """Return how many pixels of size pixel_size are needed to cover extent, rounding up.

        Args:
            extent (float): The length to cover.
            pixel_size (float): The pixel size along the same axis.

        Returns:
            int: ceil(extent / pixel_size)
        """
        return int(np.ceil(extent / pixel_size))

    def _project_template_raster(self, template_raster):
        """Warp the template raster into the display projection when both projections are valid and non-local.

        Args:
            template_raster (RasterInput): The template raster.

        Returns:
            RasterInput: The reprojected template raster, or the original one when no reprojection is done.
        """
        raster_sr = display_sr = None
        if gu.valid_wkt(template_raster.wkt):
            _, raster_wkt = gu.remove_epsg_code_if_unit_mismatch(template_raster.wkt)
            if gu.valid_wkt(raster_wkt):
                raster_sr = gu.wkt_to_sr(raster_wkt)
            if gu.valid_wkt(self.default_wkt):
                display_sr = gu.wkt_to_sr(self.default_wkt)
        if raster_sr is None or display_sr is None or raster_sr.IsLocal() or display_sr.IsLocal():
            return template_raster
        projected_raster = filesystem.temp_filename(suffix='.tif')
        gw.gdal_warp(template_raster.gdal_raster, projected_raster, srcSRS=raster_sr, dstSRS=display_sr,
                     resampleAlg=RasterReproject.GRA_Cubic)
        return RasterInput(projected_raster)

    def _template_raster_extents(self, template_raster, template_option, template_resolution):
        """Compute the raster bounds and size in pixels from a template raster.

        Args:
            template_raster (RasterInput): The (already reprojected) template raster.
            template_option (str): The template resolution option.
            template_resolution (int): Requested pixel count along the template's longer axis, used only with
                'Specify maximum XY resolution'.

        Returns:
            tuple: (min_pt, max_pt, [width, height]) with min_pt/max_pt as numpy arrays.
        """
        min_pt, max_pt = template_raster.get_raster_bounds()
        min_pt = np.array(min_pt)
        max_pt = np.array(max_pt)
        pixels = list(template_raster.resolution)  # Default pixel size to raster template
        if template_option == 'Specify maximum XY resolution':
            # Resize pixels based on specified resolution (could differ from template raster's). The axis with the
            # most pixels gets the requested count; the other axis is sized for roughly square pixels, rounded up.
            long_axis = 0 if pixels[0] > pixels[1] else 1
            short_axis = 1 - long_axis
            pixels[long_axis] = template_resolution
            pixel_size = (max_pt[long_axis] - min_pt[long_axis]) / template_resolution
            pixels[short_axis] = self._pixel_count(max_pt[short_axis] - min_pt[short_axis], pixel_size)
        return min_pt, max_pt, pixels

    def _grid_raster_extents(self, ugrid, pixel_size, extrapolation_width):
        """Compute the raster bounds and size in pixels from the UGrid extents and a pixel size.

        The extents are grown by the extrapolation width, and the upper bound is pushed out so that it lands on a
        whole number of pixels.

        Args:
            ugrid (xms.grid.ugrid.UGrid): The dataset's UGrid.
            pixel_size (float): The pixel size.
            extrapolation_width (float): Distance to grow the extents by on every side.

        Returns:
            tuple: (min_pt, max_pt, [width, height]) with min_pt/max_pt as numpy arrays.
        """
        min_pt, max_pt = ugrid.extents
        min_pt = np.array(min_pt)
        max_pt = np.array(max_pt)
        if extrapolation_width != 0:
            min_pt -= extrapolation_width
            max_pt += extrapolation_width

        pixels = [0, 0]
        for axis in range(2):
            pixels[axis] = self._pixel_count(max_pt[axis] - min_pt[axis], pixel_size)
            max_pt[axis] = pixel_size * pixels[axis] + min_pt[axis]
        return min_pt, max_pt, pixels

    def _log_progress(self, row):
        """Log an info message on the first row and then every self._output_frequency rows.

        Args:
            row (int): The one-based row currently being processed
        """
        # Log frequently enough to keep things interesting, but not so much that the dialog lags.
        if row % self._output_frequency == 0 or row == 1:
            if row == 1:  # Log the first iteration
                end_row = min(self._height, self._output_frequency)
            else:  # Log every self._output_frequency iteration
                end_row = min(self._height, row + self._output_frequency)
            self.logger.info(
                f'Processing points in rows {row}-{end_row} of {self._height}...'
            )

    def _set_cell_dimensions(self):
        """Compute the raster cell size (dx, dy) from the raster bounds and dimensions."""
        self.logger.info('Computing raster cell dimensions...')
        dim = np.array((self._width, self._height, 1.0))
        self._size_delta = (self._max_pt - self._min_pt) / dim
        self._size_delta = self._size_delta[:-1]  # Drop unused z coordinate

    def _create_extractor(self, null_value):
        """Set up the UGrid data extractor with one scalar per UGrid point.

        The scalars are the dataset values (self._ts_data) or, when there are none, the point z coordinates. Points
        whose scalar equals null_value are marked inactive.

        Args:
            null_value (float): The value marking missing data.
        """
        self.logger.info('Setting up dataset extractor...')
        ug_data = self._ts_data if self._ts_data is not None else [pt[2] for pt in self._points]
        ug_activity = [0 if value == null_value else 1 for value in ug_data]
        if self._extractor is None:
            self._extractor = UGrid2dDataExtractor(ugrid=self._ugrid)
        self._extractor.no_data_value = null_value
        self._extractor.set_grid_point_scalars(ug_data, ug_activity, 'points')

    def _load_raster_row_locations(self, row):
        """Set the extractor's locations to the raster cell centers (x, y) of a row.

        Rows are expected to be loaded in order starting at 0 (the bottom row): the first call builds the row's
        cell centers and each later call shifts the previous row up by one cell height.

        Args:
            row (int): The 0-based row index
        """
        if self._row_locs is None:
            # Cell centers sit half a cell in from the lower-left corner. The z coordinate is not used.
            xy_min = self._min_pt[:-1] + (self._size_delta[0] / 2, self._size_delta[1] / 2)
            col_row = np.column_stack((np.arange(self._width), np.full(self._width, row)))
            self._row_locs = xy_min + col_row * self._size_delta
        else:
            self._row_locs[:, 1] += self._size_delta[1]
        self._extractor.extract_locations = self._row_locs

    def _process_row(self, row):
        """Extract values at the cell centers of one raster row into self._raster_data.

        Args:
            row (int): The zero-based index of the row to process
        """
        one_based_row = row + 1
        self._log_progress(one_based_row)
        self._load_raster_row_locations(row)
        start_idx = row * self._width
        self._raster_data[start_idx:start_idx + self._width] = self._extractor.extract_data()

    def _interp_source_points(self, ts_data, null_value):
        """Build the (x, y, value) interpolation source points from the UGrid points.

        Args:
            ts_data: Values at the UGrid points, or None to use the point z coordinates.
            null_value (float): Value marking missing data in ts_data; such points (and NaN values) are skipped.

        Returns:
            list of tuple: The source points.
        """
        points = np.asarray(self._ugrid.locations, dtype=float).reshape(-1, 3)
        if ts_data is not None:
            values = np.asarray(ts_data, dtype=float)
            points = np.column_stack((points[:, 0], points[:, 1], values))
            points = points[(values != null_value) & ~np.isnan(values)]
        return [tuple(pt) for pt in points]

    def _interp_to_raster(self, ts_data, null_value, extrapolate_all):
        """Interpolate the dataset (or UGrid point z) values to the raster cells.

        Args:
            ts_data: Values at the UGrid points, or None to use the point z coordinates.
            null_value (float): Value marking missing data in ts_data.
            extrapolate_all (bool): If True, IDW-interpolate to the active raster cells both inside and outside the
                UGrid; otherwise only cells inside the UGrid are filled, using the UGrid data extractor.
        """
        # Initialize the source geometry extents.
        self._set_cell_dimensions()
        # Flip the activity rows so their order matches self._raster_data (bottom row first).
        activity = np.flipud(self._raster_activity).flatten()
        if extrapolate_all or (self._elev_values is not None and self._wse_values is not None):
            interp_pts = self._interp_source_points(ts_data, null_value)
            if not interp_pts:
                return
            # The raster activity comes from the template raster or the grid (see generate_raster).
            self.logger.info('Running interpolation...')
            self._raster_data = interp_util.generate_raster_idw(interp_pts, tuple(self._min_pt),
                                                                tuple(np.append(self._size_delta, 0.0)), self._height,
                                                                self._width, self._no_data, activity)
        else:
            # Use the xmsextractor to only interpolate values inside the mesh
            self._ts_data = ts_data
            self._create_extractor(null_value)
            # Process the data one row at a time. This allows us to give feedback messages during a long operation.
            self.logger.info('Extracting elevations at raster cell locations...')
            self._row_locs = None
            increment = (self._height + 1) / self.DEFAULT_NUM_DIVISIONS
            self._output_frequency = max(self.DEFAULT_OUTPUT_MULTIPLE * round(increment / self.DEFAULT_OUTPUT_MULTIPLE),
                                         self.DEFAULT_OUTPUT_MULTIPLE)
            for row in range(self._height):
                self._process_row(row)
            # NaN, null and inactive cells all become the raster's NODATA value.
            invalid = np.isnan(self._raster_data) | np.equal(null_value, self._raster_data) | np.equal(0, activity)
            self._raster_data[invalid] = self._no_data

    def _extrapolate_all(self, ts_data, no_data):
        """IDW-interpolate values to every active raster cell, inside and outside the UGrid.

        Args:
            ts_data: Values at the UGrid points, or None to use the point z coordinates.
            no_data (float): Value marking missing data in ts_data.
        """
        self._interp_to_raster(ts_data, no_data, True)

    def _extrapolate_raster_data(self, extrapolation_width: float):
        """Extrapolate out from the raster data.

        Cells with no value that lie within extrapolation_width of a cell that has one are filled by inverse distance
        weighting from the dataset values at the UGrid points (or the point z coordinates when there is no dataset).
        Cells that already have values are left unchanged.

        Args:
            extrapolation_width (float): The distance to extrapolate out from the UGrid.
        """
        gdal_runner = GdalRunner()
        try:
            # Mask raster: 1.0 where the raster has a value, 0.0 where it does not.
            raster_data = np.asarray(self._raster_data, dtype=float)
            mask_data = np.where(raster_data == self._no_data, 0.0, 1.0)
            mask_raster_file = gdal_runner.get_temp_file('mask_raster.tif')
            self._create_raster(mask_raster_file, mask_data)

            # Proximity raster: cells within extrapolation_width of a masked-on cell (those cells included) get the
            # fixed value 0; cells farther away get the NODATA value 1.
            args = [mask_raster_file, '$OUT_FILE$', '-distunits', 'GEO', '-maxdist', f'{extrapolation_width}',
                    '-fixed-buf-val', '0', '-nodata', '1']
            proximity_file = gdal_runner.run('gdal_proximity.py', 'proximity.tif', args)
            ri = RasterInput(proximity_file)
            proximity_array = ri.get_raster_values()
            ri = None  # Release the file before the temporary folder is removed

            self._raster_data = raster_data
            self._fill_within_proximity(proximity_array)
        finally:
            # Remove the folder containing any generated files, even if extrapolation failed.
            shutil.rmtree(gdal_runner.temp_file_path, ignore_errors=True)

    def _fill_within_proximity(self, proximity_array):
        """Fill no-data cells that the proximity raster marks (with 0) as within the extrapolation distance.

        Args:
            proximity_array (numpy.ndarray): 2D proximity raster values, top row first.
        """
        rows, cols = np.nonzero(np.asarray(proximity_array) == 0)
        # The proximity raster is read top row first, but self._raster_data is stored bottom row first.
        raster_rows = self._height - rows - 1
        indices = raster_rows * self._width + cols
        # Keep the values already interpolated inside the UGrid; only fill cells that have no value yet.
        to_fill = self._raster_data[indices] == self._no_data
        if not np.any(to_fill):
            return
        source_pts = self._interp_source_points(self._ts_data, self._no_data)
        if not source_pts:
            return

        # Cell-center locations of the cells to fill (z taken from the lower-left corner point).
        delta = np.append(self._size_delta, 0.0)
        xy_min = self._min_pt + delta / 2
        cell_ij = np.column_stack((cols[to_fill], raster_rows[to_fill], np.zeros(np.count_nonzero(to_fill))))
        locations = list(xy_min + delta * cell_ij)

        # extrapolate using IDW
        interp_values = InterpIdw(source_pts).interpolate_to_points(locations)
        self._raster_data[indices[to_fill]] = interp_values

    def _create_raster(self, filename, raster_array):
        """Write raster data to a file using this tool's raster bounds, cell size, NODATA value and projection.

        Args:
            filename (str): Path to the output file to write
            raster_array (numpy.array): The raster data.
        """
        raster_output = RasterOutput(xorigin=self._min_pt[0], yorigin=self._max_pt[1], width=self._width,
                                     height=self._height, pixel_width=self._size_delta[0],
                                     pixel_height=self._size_delta[1], nodata_value=self._no_data, wkt=self._wkt)
        raster_output.write_raster(filename, raster_array)

    @staticmethod
    def _resample_template_raster(template_raster, width, height):
        """Warp the template raster to a new size in pixels.

        Args:
            template_raster (RasterInput): The template raster.
            width (int): New width in pixels.
            height (int): New height in pixels.

        Returns:
            RasterInput: The resampled raster (a temporary file).
        """
        # NOTE(review): the temporary file is not deleted afterwards.
        resampled_file = filesystem.temp_filename(suffix='.tif')
        gw.gdal_warp(template_raster.gdal_raster, resampled_file, width=width, height=height,
                     resampleAlg=RasterReproject.GRA_Cubic)
        return RasterInput(resampled_file)

    @staticmethod
    def _write_template_activity_raster(template_raster, active_file):
        """Write an activity raster that is 1 where the template raster has data and 0 where it is NODATA.

        Args:
            template_raster (RasterInput): The template raster.
            active_file (str): Path of the activity raster to write.
        """
        raster_values = template_raster.get_raster_values()
        nodata = template_raster.nodata_value
        if nodata is None:
            # No NODATA value means every cell holds data.
            active_values = np.ones(np.shape(raster_values), dtype=int)
        else:
            # Tolerant comparison so float32 cells still match a float64 NODATA value; NaN NODATA matches NaN cells.
            active_values = np.where(np.isclose(raster_values, nodata, equal_nan=True), 0, 1)
        ro = RasterOutput(width=template_raster.resolution[0], height=template_raster.resolution[1],
                          wkt=template_raster.wkt, data_type=RasterInput.GDT_Int32,
                          geotransform=template_raster.geotransform)
        ro.write_raster(active_file, active_values)

    @staticmethod
    def _write_grid_activity_raster(active_file, cogrid, ugrid, min_pt, max_pt, width, height, wkt, template_raster):
        """Write an activity raster that is 1 inside the grid's "on" cells and 0 elsewhere.

        Args:
            active_file (str): Path of the activity raster to write.
            cogrid (xms.constraint.Grid): The constrained grid
            ugrid (xms.grid.ugrid.UGrid): 2d Unstructured grid
            min_pt (tuple of float): The coordinate of the lower left corner of the lower left cell
            max_pt (tuple of float): The coordinate of the upper right corner of the upper right cell
            width (int): Width of the output raster in cells
            height (int): Height of the output raster in cells
            wkt (str): Well-known text of the geometry's projection
            template_raster (RasterInput): Template raster whose geotransform is used, or None to derive the
                geotransform from the bounds and size.
        """
        if template_raster is not None:
            geotransform = template_raster.geotransform
        else:
            col_size = (max_pt[0] - min_pt[0]) / width
            row_size = (max_pt[1] - min_pt[1]) / height
            geotransform = [min_pt[0], col_size, 0.0, max_pt[1], 0.0, -row_size]

        # Polygonize the "on" cells (all cells when the grid has no on/off flags).
        on_off = cogrid.model_on_off_cells if len(cogrid.model_on_off_cells) > 0 else [1] * ugrid.cell_count
        poly_builder = GridCellToPolygonCoverageBuilder(co_grid=cogrid, dataset_values=on_off,
                                                        wkt=None, coverage_name='temp')
        # Set up a vector polygon layer for storing polygons for each grid cell
        active_geom_file = '/vsimem/active_polygons.shp'
        vo = VectorOutput()
        vo.initialize_file(active_geom_file, wkt)
        ug_locs = ugrid.locations

        def ring_coords(point_ids):
            """Return the xyz coordinates of the UGrid points in a polygon ring."""
            return [[ug_locs[i][0], ug_locs[i][1], ug_locs[i][2]] for i in point_ids]

        out_polys = poly_builder.find_polygons()
        if 1 in out_polys:
            for cur_poly in out_polys[1]:
                # The first ring is the outer boundary; any remaining rings are holes.
                holes = [ring_coords(cur_poly[k]) for k in range(1, len(cur_poly))]
                vo.write_polygon(ring_coords(cur_poly[0]), holes)
        # Close the handles associated with the polygon shapefile
        vo = None

        # Burn the polygons into an integer raster. NODATA is 0 so cells outside the polygons read as 0 (inactive)
        # rather than the tool's no-data value, which the interp_util functions need.
        ro = RasterOutput(width=width, height=height, nodata_value=0, wkt=wkt,
                          data_type=RasterInput.GDT_Int32, geotransform=geotransform)
        active_ds = ro.write_template_raster(active_file)
        gw.gdal_rasterize(active_geom_file, active_ds, burnValues=[1], bands=[1])
        active_ds = None  # Set to None to close and finish writing

        # Remove the in-memory shapefile
        gu.delete_vector_file(active_geom_file)

    def generate_raster(self, filename: str, cogrid, ugrid, min_pt, max_pt, width, height, wkt, ts_data=None,
                        no_data=-999.0, extrapolation_width=0.0, specify_wse=False,
                        extrapolate_all=False, template_raster=None, template_activity=False,
                        template_option=''):
        """Creates a raster and interpolates elevation values of the source geometry.

        Args:
            filename (str): Path to the output file to write
            cogrid (xms.constraint.Grid): The constrained grid
            ugrid (xms.grid.ugrid.UGrid): 2d Unstructured grid
            min_pt (tuple of float): The coordinate of the lower left corner of the lower left cell
            max_pt (tuple of float): The coordinate of the upper right corner of the upper right cell
            width (int): Width of the output raster in cells
            height (int): Height of the output raster in cells
            wkt (str): Well-known text of the geometry's projection
            ts_data (list): Data values for each point on the UGrid associated with this class; None uses the point
                z coordinates
            no_data (float): NODATA value for values specified in ts_data
            extrapolation_width (float): The distance to extrapolate out from the UGrid
            specify_wse (bool): Whether to specify WSE and elevation raster for computing active raster region
            extrapolate_all (bool): Whether to extrapolate all the values for the raster using IDW interpolation
            template_raster (RasterInput): Template raster for raster creation
            template_activity (bool): Flag for using activity on the template raster
            template_option (str): The string specifying whether to use a native template raster resolution
        """
        self._ugrid = ugrid
        self._width = width
        self._height = height
        self._wkt = wkt
        self._no_data = no_data
        self._points = self._ugrid.locations
        self._ts_data = ts_data
        self._min_pt = np.array(min_pt)
        self._max_pt = np.array(max_pt)
        self._extrapolate = extrapolation_width != 0.0
        self._raster_data = np.full(self._width * self._height, self._no_data)
        self._specify_wse = specify_wse
        self._elev_values = None
        self._wse_values = None
        self._elev_nodata = -999999.0
        self._raster_activity = np.array([])
        # An extractor is bound to one UGrid, so never reuse one left over from an earlier call.
        self._extractor = None
        self._row_locs = None

        # If using the template raster, but a specified alternate resolution, warp a new template
        # raster with the new resolution and use it from here on out.
        if template_raster is not None and template_option == 'Specify maximum XY resolution':
            template_raster = self._resample_template_raster(template_raster, width, height)

        # Build a raster of 1's (active) and 0's (inactive) from the template raster or the grid's "on" cells.
        active_file = '/vsimem/active_region.tif'
        if template_raster is not None and template_activity:
            self._write_template_activity_raster(template_raster, active_file)
        else:
            self._write_grid_activity_raster(active_file, cogrid, ugrid, min_pt, max_pt, width, height, wkt,
                                             template_raster)

        ri = RasterInput(active_file)
        self._raster_activity = ri.get_raster_values()  # Activity values from activity raster
        ri = None

        # Interpolate/extrapolate
        if extrapolate_all:
            self.logger.info('Interpolating and extrapolating WSE values to the raster cells...')
            self._extrapolate_all(ts_data, self._no_data)
        else:
            self.logger.info('Interpolating values to the raster cells...')
            self._interp_to_raster(ts_data, self._no_data, False)
        self.logger.info('Writing raster file...')
        if self._extrapolate:
            self._extrapolate_raster_data(extrapolation_width)
        self._create_raster(filename, self._raster_data)

        # NOTE(review): the in-memory activity raster (active_file) is not unlinked here (gdal.Unlink is not used).

# def main():
#     """Main function, for testing."""
#     pass
#     import os
#     from xms.tool_gui.tool_dialog import ToolDialog
#     from xms.guipy.dialogs.xms_parent_dlg import ensure_qapplication_exists
#     from xms.tool.utilities.test_utils import get_test_files_path
#     qapp = ensure_qapplication_exists()
#     tool = RasterFromDatasetTool()
#     tool.set_gui_data_folder(os.path.join(get_test_files_path(), 'raster_from_dataset_tool'))
#     arguments = tool.initial_arguments()
#     grid_path = os.path.join(get_test_files_path(), 'raster_from_dataset_tool', 'simple_2d_grid.xmc')
#     tool.set_grid_uuid_for_testing(grid_path, 'bbb6f43c-c76a-4e00-8bec-474b2ff19cb0')
#     arguments[ARG_PIXEL_SIZE].value = 0.1
#     arguments[ARG_NULL_VALUE].value = -9999999.0
#     arguments[ARG_EXTRAPOLATION_WIDTH].value = 10.0
#     arguments[ARG_OUTPUT_RASTER].value = os.path.join(get_test_files_path(), 'raster_from_dataset_tool',
#                                                       'simple_2d_grid.tif')
#     wkt = (
#         'GEOGCS["GCS_WGS_1984",'
#         'DATUM["WGS84",SPHEROID["WGS84",6378137,298.257223563]],'
#         'PRIMEM["Greenwich",0],UNIT["Degree",0.017453292519943295]]'
#     )
#     tool.default_wkt = wkt
#
#     tool_dialog = ToolDialog(None, arguments, tool.name, tool=tool)
#     if tool_dialog.exec():
#         tool.run_tool(tool_dialog.tool_arguments)
#     qapp = None
#
#
# if __name__ == "__main__":
#     main()
