"""ArcsFromDatasetCountoursTool whitebox tool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint.contour import UGrid2dContour
from xms.constraint.ugrid_activity import CellToPointActivityCalculator
from xms.tool_core import ALLOW_ONLY_SCALARS, Argument, IoDirection, Tool
from xms.tool_core.table_definition import FloatColumnType, TableDefinition

# 4. Local modules
from xms.tool.utilities.coverage_conversion import convert_lines_to_coverage


def _create_table_definition():
    """Build the default table definition used for the contour-values table.

    Returns:
        (TableDefinition): A definition with a single float 'Value' column defaulting to 0.0.
    """
    value_column = FloatColumnType(header='Value', tool_tip='Contour interval value', default=0.0)
    return TableDefinition([value_column])


class ArcsFromDatasetCountoursTool(Tool):
    """Whitebox tool that converts contours of a scalar dataset into arcs in a new coverage.

    The misspelled class name is left unchanged so existing references to it keep working.
    """
    # Positions of the tool arguments in the list returned by initial_arguments().
    ARG_INPUT_DATASET = 0
    ARG_INPUT_TIMESTEP = 1
    ARG_CONTOUR_VALUES = 2
    ARG_OUTPUT_COVERAGE = 3
    # Column header of the contour-values table (matches the column built by _create_table_definition()).
    VALUE_HEADER = 'Value'
    # No-data sentinel handed to the contourer; null and inactive dataset values are replaced with it.
    NO_DATA = -999999.0

    def __init__(self):
        """Initializes the class."""
        super().__init__('Arcs from Dataset Contours')

    def initial_arguments(self):
        """Get initial arguments for tool.

        Returns:
            (list): A list of the initial tool arguments, ordered to match the ARG_* indices.
        """
        table_def = _create_table_definition()
        # Blank DataFrame built from the definition; used as the table argument's initial value.
        df = table_def.to_pandas()
        arguments = [
            self.dataset_argument(name='input_dataset', description='Input dataset', filters=ALLOW_ONLY_SCALARS),
            self.timestep_argument(name='time_step', description='Timestep', value=Argument.NONE_SELECTED),
            self.table_argument(name='contour_values', description='Contour values', value=df, optional=False,
                                table_definition=table_def),
            self.coverage_argument(name='output_coverage', description='Output coverage',
                                   io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        Args:
            arguments (list): The tool arguments.
        """
        # The timestep argument is driven by the currently selected dataset argument.
        arguments[self.ARG_INPUT_TIMESTEP].enable_timestep(arguments[self.ARG_INPUT_DATASET])

    def _ts_data(self, ugrid, dset, timestep):
        """Get the data scalar values for the requested timestep, with all activity converted to null values.

        Args:
            ugrid (UGrid): The input dataset's UGrid
            dset (DatasetReader): The input dataset
            timestep (int): The timestep to get

        Returns:
            (np.ndarray): The values for the timestep with all inactive values set to self.NO_DATA
        """
        if dset.num_activity_values is not None and dset.num_activity_values != dset.num_values:
            # Activity array size differs from the value count: point dataset with cell activity.
            dset.activity_calculator = CellToPointActivityCalculator(ugrid)
        ts_data, _ = dset.timestep_with_activity(timestep, nan_null_values=True, nan_activity=True)
        # If the dataset had a null value or activity array, all inactive values are now nan. Set them to a no data
        # value for the contourer.
        ts_data[np.isnan(ts_data)] = self.NO_DATA
        return ts_data

    def _contourer(self, ugrid, dset, ts_data, intervals):
        """Build a contourer.

        Args:
            ugrid (UGrid): The input dataset's UGrid
            dset (DatasetReader): The input dataset
            ts_data (np.ndarray): Scalar data values for the timestep
            intervals (array-like): The contour values to extract (e.g. the 'Value' column of the contour table).
                Non-finite entries (such as blank table cells) are ignored and duplicates are collapsed.

        Returns:
            (UGrid2dContour): The contourer
        """
        contours = UGrid2dContour(ugrid)
        contours.set_no_data_value(self.NO_DATA)
        values = np.asarray(intervals, dtype=float)
        # Drop NaN/inf rows and duplicate values so each contour level is extracted once; np.unique also returns
        # the values sorted ascending, which is the order the original code passed to the contourer.
        values = np.unique(values[np.isfinite(values)])
        contours.set_extract_scalars(values.tolist())
        contours.set_contour_length_threshold(0.0)
        if dset.location == 'cells':
            contours.set_grid_cell_scalars(ts_data, [], 1)  # DataLocationEnum::LOC_CELLS == 1
        else:
            contours.set_grid_point_scalars(ts_data, [], 0)  # DataLocationEnum::LOC_POINTS == 0
        return contours

    def _build_cov(self, contour_lines, cov_name):
        """Create the output coverage.

        Args:
            contour_lines (dict): Mapping of contour value -> list of segments extracted at that value; each segment
                is a sequence of points whose first two components are x and y.
            cov_name (str): Name of the output coverage

        Returns:
            (GeoDataFrame): A GeoDataFrame containing the given lines
        """
        # Each segment becomes one arc whose points carry the contour value as their z coordinate.
        arc_lines = [
            [(pt[0], pt[1], z_val) for pt in segment]
            for z_val, segments in contour_lines.items()
            for segment in segments
        ]
        return convert_lines_to_coverage(arc_lines, cov_name, self.default_wkt)

    def run(self, arguments):
        """Runs the tool.

        Reads the selected timestep of the input dataset, extracts contour lines at the requested values and writes
        them as arcs to the output coverage. Reports failure through self.fail() when no contours are produced.

        Args:
            arguments (list): A list of the tool's arguments
        """
        dataset_arg = arguments[self.ARG_INPUT_DATASET]
        output_arg = arguments[self.ARG_OUTPUT_COVERAGE]

        # Get the input dataset and geometry.
        timestep = arguments[self.ARG_INPUT_TIMESTEP].get_timestep(dataset_arg)
        dset = self.get_input_dataset(dataset_arg.text_value)
        co_grid = self.get_input_dataset_grid(dataset_arg.text_value)
        ugrid = co_grid.ugrid

        # Contour the timestep values and convert the resulting segments into coverage arcs.
        ts_data = self._ts_data(ugrid, dset, timestep)
        intervals = arguments[self.ARG_CONTOUR_VALUES].value[self.VALUE_HEADER]
        contourer = self._contourer(ugrid, dset, ts_data, intervals)
        contour_lines = contourer.extract_contour_segments()
        new_cov = self._build_cov(contour_lines, output_arg.value)
        if new_cov is not None and not new_cov.empty:
            self.set_output_coverage(new_cov, output_arg)
        else:
            self.fail('No contours extracted at specified values.')
