"""CMS-Flow Hard bottom tool."""

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
from xms.datasets.dataset_writer import DatasetWriter
from xms.guipy.data.target_type import TargetType
from xms.snap.snap_polygon import SnapPolygon
from xms.tool_core import IoDirection, Tool
from xms.tool_core.argument import Argument

# 4. Local modules
from xms.cmsflow.components.hb_component import HardBottomCoverageComponent

# Positions of the tool arguments in the list built by HardBottomDatasetTool.initial_arguments();
# run() uses these to index into the arguments it receives.
ARG_INPUT_UGRID = 0  # target grid (input)
ARG_INPUT_COVERAGE = 1  # hard bottom coverage containing polygons (input)
ARG_OUTPUT_DATASET = 2  # name of the hard bottom dataset to create (output)


class HardBottomDatasetTool(Tool):
    """
    Creates Hard Bottom dataset input for CMS-Flow.

    The tool reads a CMS-Flow Hard Bottom coverage and snaps its polygons to the cells of the
    selected grid. It then writes a cell-based dataset that holds, per cell, the hard bottom
    elevation derived from the polygon's hard bottom type. Cells not covered by a recognized
    hard bottom polygon keep the value -999.0.
    """
    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Generate CMS-Flow Hard Bottom Dataset')
        self._co_grid = None  # input grid; must provide cell elevations
        self._hb_cov = None  # input hard bottom coverage (dataframe-like; presumably a GeoDataFrame)
        self._hb_comp = None  # HardBottomCoverageComponent attached to the coverage
        self._hb_dataset = None
        self._poly_snap = None  # SnapPolygon used to map coverage polygons to grid cells
        self._out_ds_name = None  # name of the output dataset
        self._poly_data = None  # list of (poly_id, cell_idxs, hb_type, hb_values) tuples
        self._query = None  # query object taken from the data handler, if it has one

    def set_data_handler(self, data_handler):
        """Set the data handler and, if it exposes one, cache its query object.

        Args:
            data_handler: The tool's data handler. If it has a ``_query`` attribute (as an
                XMSDataHandler does), that query is stored and used to look up component data.
        """
        super().set_data_handler(data_handler)
        if hasattr(self._data_handler, "_query"):
            self._query = self._data_handler._query

    def _component_main_file(self):
        """Get the main file of the CMS-Flow Hard Bottom component attached to the input coverage.

        Returns:
            The component's main file.

        Raises:
            RuntimeError: If no Hard Bottom component (or no main file) is found for the coverage.
        """
        do_comp = self._query.item_with_uuid(
            self._hb_cov.attrs['uuid'], model_name='CMS-Flow', unique_name='HardBottomCoverageComponent'
        )
        # The lookup presumably yields None when no matching component exists. Report that with
        # the same clear error as a component without a main file, not an AttributeError.
        if do_comp is None or not do_comp.main_file:
            raise RuntimeError('Selected coverage must be a CMS-Flow Hard Bottom Coverage. Aborting.')
        return do_comp.main_file

    def initial_arguments(self):
        """Get the initial arguments for the tool (MUST OVERRIDE).

        The order of the returned list matches ARG_INPUT_UGRID, ARG_INPUT_COVERAGE and
        ARG_OUTPUT_DATASET.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.grid_argument(name='Input UGrid', description='Target UGrid', io_direction=IoDirection.INPUT),
            self.coverage_argument(
                name='input_hard_bottom', description='Input hard bottom coverage containing polygons'
            ),
            self.dataset_argument(
                name='hb_dataset',
                description='Hard bottom dataset name',
                value='hard_bottom',
                io_direction=IoDirection.OUTPUT
            ),
        ]
        return arguments

    def run(self, arguments: list[Argument]):
        """Override to run the tool.

        Args:
            arguments (list): The tool arguments, ordered as returned by initial_arguments().

        Raises:
            RuntimeError: If the grid has no cell elevations or the coverage is not a CMS-Flow
                Hard Bottom coverage.
        """
        self.logger.info('Running Hard Bottom Tool')
        self._co_grid = self.get_input_grid(arguments[ARG_INPUT_UGRID].text_value)
        # Every hard bottom type is derived from (or replaces) cell elevations, so they are required.
        if self._co_grid.cell_elevations is None:
            raise RuntimeError('Selected grid must have cell elevations. Aborting.')
        self._hb_cov = self.get_input_coverage(arguments[ARG_INPUT_COVERAGE].text_value)
        self._out_ds_name = arguments[ARG_OUTPUT_DATASET].text_value
        self._get_hard_bottom_component_data()
        self._generate_hard_bottom_dataset()

    def _load_component_ids(self):
        """Load the polygon component ids of the hard bottom component through the query."""
        self._query.load_component_ids(self._hb_comp, polygons=True)

    def _get_hard_bottom_component_data(self):
        """Collect, per coverage polygon, the snapped grid cells and hard bottom attributes.

        Fills ``self._poly_data`` with ``(poly_id, cell_idxs, hb_type, hb_values)`` tuples. A polygon
        is skipped if it has no valid component id, covers no cells, or has a type that is not
        one of the component's polygon parameter groups.
        """
        self.logger.info('Loading hard bottom attributes...')
        main_file = self._component_main_file()
        self._hb_comp = HardBottomCoverageComponent(main_file)
        self._load_component_ids()

        # Snap the coverage polygons onto grid cells.
        self._poly_snap = SnapPolygon()
        self._poly_snap.set_grid(self._co_grid, target_cells=True)
        self._poly_snap.add_gdf_polygons(self._hb_cov)

        pp = self._hb_comp.generic_model().polygon_parameters
        hb_types = pp.group_names
        self._poly_data = []
        polygons = self._hb_cov[self._hb_cov['geometry_types'] == 'Polygon']
        for poly in polygons.itertuples():
            comp_id = self._hb_comp.get_comp_id(TargetType.polygon, poly.id)
            if comp_id is None or comp_id < 0:
                continue  # polygon has no attributes assigned
            cell_idxs = self._poly_snap.get_cells_in_polygon(poly.id)
            if len(cell_idxs) < 1:
                continue  # polygon does not cover any grid cell
            hb_type, hb_val = self._hb_comp.data.feature_type_values(TargetType.polygon, comp_id)
            if hb_type in hb_types:
                self._poly_data.append((poly.id, cell_idxs, hb_type, hb_val))

    def _generate_hard_bottom_dataset(self):
        """Generate the hard bottom dataset and register it as the tool's output.

        Per hard bottom type, cells inside the polygon receive:
            - 'non_erodible_cell': the cell elevation.
            - 'specified_distance_erodible_cell': the cell elevation minus 'distance_below_elevation'.
            - 'specified_elevation_erodible_cell': the 'specified_elevation' value.
        All other cells keep -999.0. Where polygons overlap, the later polygon in
        ``self._poly_data`` wins.
        """
        self.logger.info('Generating hard bottom dataset...')
        # default the dataset to -999.0 (no hard bottom)
        ug = self._co_grid.ugrid
        ds = [-999.0] * ug.cell_count
        cell_elev = self._co_grid.cell_elevations

        # loop through polygons and set the dataset value
        pp = self._hb_comp.generic_model().polygon_parameters
        for poly_id, cell_idxs, hb_type, hb_values in self._poly_data:
            self.logger.info(f'Processing polygon id: {poly_id}')
            # Load this polygon's stored values into the shared parameter model before reading them.
            pp.restore_values(hb_values)
            grp = pp.group(hb_type)
            if hb_type == 'non_erodible_cell':
                for idx in cell_idxs:
                    ds[idx] = cell_elev[idx]
            elif hb_type == 'specified_distance_erodible_cell':
                distance = grp.parameter('distance_below_elevation').value
                for idx in cell_idxs:
                    ds[idx] = cell_elev[idx] - distance
            elif hb_type == 'specified_elevation_erodible_cell':
                elevation = grp.parameter('specified_elevation').value
                for idx in cell_idxs:
                    ds[idx] = elevation

        # save the dataset
        ds_writer = DatasetWriter(name=self._out_ds_name, geom_uuid=self._co_grid.uuid, location='cells')
        ds_writer.append_timestep(0.0, ds)
        ds_writer.appending_finished()
        self.set_output_dataset(ds_writer)
