"""LeveeGroundElevationsTool class."""

__copyright__ = "(C) Copyright Aquaveo 2020"
__license__ = "All rights reserved"

# 1. Standard Python modules
import uuid

# 2. Third party modules

# 3. Aquaveo modules
from xms.core.filesystem import filesystem as xfs
from xms.datasets.dataset_writer import DatasetWriter
from xms.tool_core import ALLOW_ONLY_COVERAGE_TYPE, ALLOW_ONLY_MODEL_NAME, IoDirection

# 4. Local modules
from xms.adcirc.tools.levee_check_tool_base import LeveeCheckToolBase
from xms.adcirc.tools.levee_check_tool_consts import TOOL_TYPE_GROUND_ELEVATION


class LeveeGroundElevationsTool(LeveeCheckToolBase):
    """Tool to check and adjust levee pair mesh ground elevations against the defined levee crest elevation curve.

    Ground elevations at the snapped node pairs on either side of a levee are lowered, when necessary, so that the
    levee crest elevation is at least a user-defined minimum height above the ground. Unless no adjustments were
    recorded, two datasets are sent back: the adjusted ground elevations and the adjusted-minus-original difference.
    """
    # Indices into the argument list built by initial_arguments()
    ARG_INPUT_MIN_DELTA_Z = 2
    ARG_OUTPUT_ADJUSTED_Z = 3
    ARG_OUTPUT_Z_DIFF = 4

    def __init__(self):
        """Initializes the class."""
        super().__init__(TOOL_TYPE_GROUND_ELEVATION)
        # Store data for adjusted levees so we can plot them after we are done
        self._results.update(
            {
                'Crest Elevation': [],  # Crest elevation defined on the levee
                'Side 1 Adjusted': [],  # Adjusted ground elevation on first levee arc side.
                # Original ground elevation on first levee arc side. Same as adjusted if unchanged.
                'Side 1 Original': [],
                'Side 2 Adjusted': [],  # Adjusted ground elevation on second levee arc side.
                # Original ground elevation on second levee arc side. Same as adjusted if unchanged.
                'Side 2 Original': [],
            }
        )
        self._reset_candidate_levee()  # Initialize the temp storage
        self.ground_elevations = None  # Output array of node elevations
        self.original_z = None  # Input node elevations, kept so the adjusted-original diff can be computed
        self.elevation_diff = None  # Output array of adjusted elevation - original Z diff, only stored for testing

    def _get_inputs(self, arguments):
        """Get the inputs to the tool from XMS.

        Args:
            arguments (:obj:`list`): The tool arguments.

        Returns:
            (:obj:`tuple(xarray.Dataset,numpy.array)`): The levee dataset and the levee comp_ids. None to indicate
            there were no levees in the input BC coverage or no levee types selected and checks should not continue.
        """
        levee_data, levee_comp_ids = super()._get_inputs(arguments)
        # Explicit None/length test: a plain truth test on a multi-element numpy array raises ValueError.
        if levee_comp_ids is not None and len(levee_comp_ids) > 0:
            # Don't bother setting up output dataset arrays if no levees to check
            self._threshold = arguments[self.ARG_INPUT_MIN_DELTA_Z].value
            # Copy the Z column explicitly so in-place adjustments never write back into the grid's location array
            # (a column slice may be a view). Keep a second copy of the originals for the diff output.
            self.ground_elevations = self._cogrid.ugrid.locations[:, -1].copy()
            self.original_z = self.ground_elevations.copy()
        return levee_data, levee_comp_ids

    def _required_ground_elevation(self, node_idx, elevation, crest_elevation):
        """Lower a node's ground elevation if the crest is not at least the threshold height above it.

        Args:
            node_idx (:obj:`int`): 0-based index of the mesh node in the output elevation array
            elevation (:obj:`float`): Original ground elevation of the node
            crest_elevation (:obj:`float`): Levee crest elevation at this node pair

        Returns:
            (:obj:`float`): The ground elevation this pair requires. This is the original elevation if unchanged.
        """
        if (crest_elevation - elevation) < self._threshold:
            new_elevation = crest_elevation - self._threshold
            # A node can be shared by more than one pair (e.g. at levee junctions). Keep the lowest required value
            # so every crest constraint the node participates in stays satisfied, instead of last-write-wins.
            self.ground_elevations[node_idx] = min(self.ground_elevations[node_idx], new_elevation)
            return new_elevation
        return elevation

    def _check_levee(self, snap_data):
        """Map levee coverage data attributes to snapped node locations, updating ground elevations if needed.

        Args:
            snap_data (:obj:`dict`): The stored snap data for the levee from the first pass

        Returns:
            (:obj:`tuple(float,float,list[str])`): Minimum crest elevation of levee at nodes, maximum crest
            elevation of levee at nodes, list of report lines for the adjusted ground elevations to print to the
            log window
        """
        min_z = float('inf')  # Range of levee crest elevations for report summary
        max_z = float('-inf')
        adjustments = []  # Reports of the adjusted ground elevations
        # Unpack the applied data we stored in the first pass of checks
        node_pairs = zip(
            snap_data['t_length'],
            snap_data['zcrest'],
            snap_data['snap1']['id'],
            snap_data['snap2']['id'],
            snap_data['snap1']['location'],
            snap_data['snap2']['location'],
        )
        # Map the levee's attribute curves to each node pair
        for t_len, crest_elevation, node1_idx, node2_idx, node1_loc, node2_loc in node_pairs:
            min_z = min(min_z, crest_elevation)  # Update min/max crest elevations for summary report
            max_z = max(max_z, crest_elevation)
            # Check if ground elevations are below the crest elevation (by at least the height of user defined min)
            elev1 = node1_loc[2]
            elev2 = node2_loc[2]
            new_elev1 = self._required_ground_elevation(node1_idx, elev1, crest_elevation)
            new_elev2 = self._required_ground_elevation(node2_idx, elev2, crest_elevation)
            if new_elev1 != elev1 or new_elev2 != elev2:  # Store adjustment report if adjusted
                # node1ID, node2ID, new_elev1, orig_elev1, new_elev2, orig_elev2 (IDs reported 1-based)
                adjustments.append(
                    f'{node1_idx + 1:<12}{node2_idx + 1:<12}{new_elev1:<25.6f}{elev1:<25.6f}'
                    f'{new_elev2:<25.6f}{elev2:<.6f}'
                )
            self._store_node_pair_row(t_len, crest_elevation, new_elev1, elev1, new_elev2, elev2)
        return min_z, max_z, adjustments

    def _send_dataset(self, name, values):
        """Write a single-timestep dataset on the domain grid to a temp file and queue it as tool output.

        Args:
            name (:obj:`str`): Name of the output dataset
            values: Per-node values to write for the single 0.0 timestep
        """
        writer = DatasetWriter(
            h5_filename=xfs.temp_filename(), name=name, dset_uuid=str(uuid.uuid4()), geom_uuid=self._cogrid.uuid
        )
        writer.write_xmdf_dataset([0.0], [values])
        self.set_output_dataset(writer)

    def _create_adjusted_z_dataset(self, dset_name, diff_name):
        """Create the new ground elevation dataset to send back to SMS.

        Args:
            dset_name (:obj:`str`): Name of the adjusted ground elevation dataset
            diff_name (:obj:`str`): Name of the adjusted-original diff dataset
        """
        if not self._num_adjusted and not self._num_adjusted_issue:
            self.logger.info(
                'No ground elevation nodes were below the crest elevation threshold. No new dataset will '
                'be created.'
            )
            return  # No updated output to send to SMS
        self.logger.info('Creating new ground elevation dataset...')
        self._send_dataset(dset_name, self.ground_elevations)
        self.logger.info('Creating adjusted-original elevation dataset...')
        # Now create the adjusted-original diff dataset
        self.elevation_diff = self.ground_elevations - self.original_z  # Store the diff so we can check it in tests
        self._send_dataset(diff_name, self.elevation_diff)

    def _run_second_pass(self):
        """Run the second pass of checks if there were no fatal errors on the first pass."""
        self.logger.info('Comparing levee crest elevations to ground elevations at mesh nodes.')
        for comp_id, snap_data in self._snap_data.items():
            num_pairs = 0
            adjustments = []  # Reports of the adjusted ground elevations
            min_z = float('inf')  # Range of levee crest elevations for report summary
            max_z = float('-inf')
            arc_ids = snap_data['arc_id']
            if snap_data['snap1'] is not None:  # Don't continue checking if error occurred during snap
                min_z, max_z, adjustments = self._check_levee(snap_data)  # Check the ground elevations at this levee
                num_pairs = len(snap_data['snap1']['id'])
            if arc_ids:  # Don't report non-existent levees
                self._report_levee_check(comp_id, arc_ids[0], arc_ids[1], adjustments, num_pairs, min_z, max_z)

    def initial_arguments(self):
        """Get initial arguments for tool - must override.

        The order of this list must match the ARG_* index constants on the class.

        Returns:
            (:obj:`list`): A list of the initial tool arguments.
        """
        bc_filter = {ALLOW_ONLY_MODEL_NAME: 'ADCIRC', ALLOW_ONLY_COVERAGE_TYPE: 'Boundary Conditions'}
        arguments = [
            self.coverage_argument(
                name='input_bc_coverage', description='Input ADCIRC Boundary Conditions coverage', filters=bc_filter
            ),
            self.grid_argument(name='domain_grid', description='Domain grid'),
            self.float_argument(
                name='min_crest_deltaz',
                description='Minimum height of crest elevation above ground',
                value=0.0,
                min_value=0.0
            ),
            self.dataset_argument(
                name='output_dataset',
                description='Output dataset',
                value='Adjusted Elevation',
                io_direction=IoDirection.OUTPUT
            ),
            self.dataset_argument(
                name='output_dataset_diff',
                description='Output diff dataset',
                value='Adjusted Elevation - Z',
                io_direction=IoDirection.OUTPUT
            ),
        ]
        return arguments

    def run(self, arguments):
        """Run the tool - must override.

        Args:
            arguments (:obj:`list`): The tool arguments.
        """
        if not self._map_all_levees_for_check(arguments):  # This applies levees to mesh and runs first pass of checks
            return  # Fatal error encountered and logged on the first pass, abort checks
        self._run_second_pass()  # Run levee ground elevation checks
        self._create_adjusted_z_dataset(
            arguments[self.ARG_OUTPUT_ADJUSTED_Z].value, arguments[self.ARG_OUTPUT_Z_DIFF].value
        )
        self._report_results()  # Print summary to log window and write results data for the results dialog
