"""Class for updating adh friction."""

# 1. Standard Python modules
from dataclasses import dataclass
import logging
import os
import shutil
import uuid

# 2. Third party modules
from adhparam.material_properties import MaterialProperties
import pandas as pd
from shapely import geometry

# 3. Aquaveo modules
from xms.adh.components.material_conceptual_component import MaterialConceptualComponent
from xms.components.display.display_options_io import read_display_options_from_json
from xms.components.display.xms_display_message import XmsDisplayMessage
from xms.coverage.grid.polygon_coverage_builder import PolygonCoverageBuilder
from xms.coverage.polygons.polygon_orienteer import get_polygon_point_lists
from xms.data_objects.parameters import Coverage, FilterLocation
from xms.guipy.data.category_display_option_list import CategoryDisplayOptionList
from xms.guipy.data.target_type import TargetType

# 4. Local modules


@dataclass
class ShapelyPolygonInfo:
    """Bookkeeping record for a single shapely polygon while two coverages are merged."""
    # Identifier of this record (build_stamped_coverage renumbers it for stamp polygons).
    id: int
    # The shapely geometry; `locations` expects it to expose `exterior` and `interiors`.
    polygon: geometry.Polygon
    # Not assigned anywhere in this module - presumably a component id; verify against callers.
    comp_id: int = None
    # Polygon id in the EWN (stamp) coverage, or None when the polygon came from the AdH coverage.
    ewn_id: int = None
    # Polygon id in the AdH (base) coverage, or None when the polygon came from the EWN coverage.
    adh_id: int = None

    @property
    def locations(self):
        """Return the unique coordinate tuples of the exterior ring and every interior ring.

        Returns:
            set: The distinct coordinate tuples found on all rings of ``polygon``.
        """
        rings = [self.polygon.exterior, *self.polygon.interiors]
        return {point for ring in rings for point in ring.coords[:]}


class AdhFrictionUpdater:
    """Builds an updated AdH materials coverage by stamping EWN polygons into an existing one.

    The existing AdH materials component is duplicated, one material is added for each distinct
    (classification, Manning's n) pair found on the EWN polygons, and a new coverage is created whose
    polygons come from build_stamped_coverage(). Call update_adh_friction() and then read
    out_adh_mat_coverage for the result.
    """
    def __init__(self, target_cov, source_cov):
        """Initializes the helper class.

        Args:
            target_cov: The target (AdH materials) coverage. It is read as ``target_cov[0][0]`` for the
                coverage geometry and ``target_cov[0][1]`` for its component.
            source_cov: The source (EWN) coverage, indexed the same way.
        """
        self._logger = logging.getLogger('xms.ewn')

        self._adh_materials_coverage = target_cov
        self._ewn_coverage = source_cov
        # Not read anywhere else in this class; the public result is the `out_adh_mat_coverage` property.
        self._out_adh_mat_coverage = source_cov

        self._new_mat_comp = None  # Duplicated materials component that belongs to the new coverage
        self._new_mat_geom = None  # Geometry of the new coverage

        self._mat_component_ids = None
        self._out_cov_uuid = None
        self._ewn_poly_to_new_mat_id = None  # EWN polygon id -> id of the material created for it
        self._polygon_merge_coverage = None
        self._merge_polygon_mapper = None  # (poly_id, ewn_poly_id) -> merged_cov_poly_id

    @property
    def out_adh_mat_coverage(self):
        """The output adh materials coverage.

        Returns:
            dict: ``{'cov_geom': <new coverage geometry>, 'cov_comp': <new materials component>}``. Both
            values are None until update_adh_friction() has run.
        """
        return {
            'cov_geom': self._new_mat_geom,
            'cov_comp': self._new_mat_comp,
        }

    def _create_new_adh_materials_coverage(self):
        """Duplicate the materials component, build the new coverage, and refresh the display files."""
        self._logger.info('Creating new ADH materials coverage')
        self._copy_adh_materials_component()
        # Material ids are captured before and after the EWN materials are added so the component can
        # update its display id files.
        old_ids = list(self._new_mat_comp.data.materials.material_properties.keys())
        self._setup_new_adh_materials_coverage()
        new_ids = list(self._new_mat_comp.data.materials.material_properties.keys())
        self._new_mat_comp.update_display_id_files(old_ids, new_ids)
        self._new_mat_comp.update_display_options_file()
        self._new_mat_comp.display_option_list.append(
            XmsDisplayMessage(file=self._new_mat_comp.disp_opts_file, edit_uuid=self._new_mat_comp.cov_uuid)
        )
        self._new_mat_comp.data.commit()

    def _copy_adh_materials_component(self):
        """Copy AdH materials component for new coverage.

        The component folder is copied to a sibling folder named with a fresh UUID, the original component
        is asked to save itself there as a 'DUPLICATE', and a MaterialConceptualComponent is opened on the
        copied main file.
        """
        adh_component = self._adh_materials_coverage[0][1]
        old_dir = os.path.dirname(adh_component.main_file)
        components_dir = os.path.dirname(old_dir)
        component_uuid = str(uuid.uuid4())
        new_dir = os.path.join(components_dir, component_uuid)
        shutil.copytree(old_dir, new_dir)
        adh_component.save_to_location(new_dir, 'DUPLICATE')
        new_main_file = os.path.join(new_dir, os.path.basename(adh_component.main_file))
        self._new_mat_comp = MaterialConceptualComponent(new_main_file)
        self._new_mat_comp.comp_to_xms = adh_component.comp_to_xms

    def _setup_new_adh_materials_coverage(self):
        """Set up a new AdH materials coverage.

        Stamps the EWN polygons into the AdH polygons, creates the new coverage geometry, renumbers the
        merged polygons, and fills the new component's ``update_ids`` so every polygon is tied to a
        material component id.
        """
        adh_coverage = self._adh_materials_coverage[0][0]
        ewn_coverage = self._ewn_coverage[0][0]

        merge_info = build_stamped_coverage(adh_coverage, ewn_coverage, self._logger)
        self._polygon_merge_coverage, self._merge_polygon_mapper = merge_info

        if self._out_cov_uuid is None:
            self._out_cov_uuid = str(uuid.uuid4())
        # NOTE(review): the keyword used here is `project`, not `projection` - confirm against Coverage.
        new_cov = Coverage(
            name=f'{adh_coverage.name}-updated', project=adh_coverage.projection, uuid=self._out_cov_uuid
        )
        self._new_mat_comp.cov_uuid = self._out_cov_uuid
        self._new_mat_comp.data.info.attrs['cov_uuid'] = self._out_cov_uuid
        self._new_mat_comp.data.commit()  # save the coverage uuid

        new_cov.set_points(adh_coverage.get_points(FilterLocation.PT_LOC_DISJOINT))

        self._add_new_ewn_materials_to_coverage()
        self._mat_component_ids = []

        # Polygons that came from the AdH coverage keep their AdH id; polygons that came from the EWN
        # coverage get fresh ids above the largest AdH id (starting at 1 if there are no AdH polygons).
        ewn_id_mapper = {}  # new polygon id -> original EWN polygon id
        max_poly_id = max((x.id for x in adh_coverage.polygons), default=0)
        merge_coverage_poly_mapper = {p.id: p for p in self._polygon_merge_coverage.polygons}
        for (adh_id, ewn_id), polygon_ids in self._merge_polygon_mapper.items():
            new_polygon = merge_coverage_poly_mapper[polygon_ids[0]]
            if adh_id is not None:
                new_polygon.id = adh_id
            elif ewn_id is not None:
                max_poly_id += 1
                new_polygon.id = max_poly_id
                ewn_id_mapper[max_poly_id] = ewn_id
        new_cov.polygons = self._polygon_merge_coverage.polygons
        new_cov.arcs = self._polygon_merge_coverage.arcs

        mat_poly_ids = {x.id for x in new_cov.polygons}
        cov_dict = {}
        if adh_coverage.uuid in self._new_mat_comp.comp_to_xms:
            cov_dict = self._new_mat_comp.comp_to_xms[adh_coverage.uuid]
        if len(cov_dict) < 1:
            cov_dict = {TargetType.polygon: {}}

        # Invert the copied mapping (component id -> attribute ids) into attribute id -> component id.
        self._new_mat_comp.update_ids[self._out_cov_uuid] = {}
        cov_update_ids = self._new_mat_comp.update_ids[self._out_cov_uuid]
        for target_type, comp_id_dict in cov_dict.items():
            cov_update_ids[target_type] = {}
            for comp_id, att_ids in comp_id_dict.items():
                for att_id in att_ids:
                    cov_update_ids[target_type][att_id] = comp_id
                    # NOTE(review): a component id is tested against polygon ids here - confirm intent.
                    if comp_id in mat_poly_ids:
                        self._mat_component_ids.append(comp_id)

        # The polygon entry must exist before the EWN polygons are recorded; the copied mapping may hold
        # only other target types, and writing into a missing entry would raise KeyError.
        polygon_update_ids = cov_update_ids.setdefault(TargetType.polygon, {})
        for new_poly_id, old_poly_id in ewn_id_mapper.items():
            if old_poly_id not in self._ewn_poly_to_new_mat_id:
                continue
            comp_id = self._ewn_poly_to_new_mat_id[old_poly_id]
            polygon_update_ids[new_poly_id] = comp_id
            self._mat_component_ids.append(comp_id)

        new_cov.complete()
        self._new_mat_geom = new_cov

    def _add_new_ewn_materials_to_coverage(self):
        """Add one material per distinct (classification, Manning's n) pair used by the EWN polygons.

        Each new material gets a friction row (cards 'FR' / 'MNG'), a name taken from the EWN component's
        'Feature type/Region' meta data, and display options taken from the EWN display categories.
        ``self._ewn_poly_to_new_mat_id`` is set to map each EWN polygon id to its material id.
        """
        self._logger.info('Adding new materials to materials list...')
        mats = self._new_mat_comp.data.materials
        ewn_component = self._ewn_coverage[0][1]
        json_disp_dict = read_display_options_from_json(ewn_component.disp_opts_file)
        cat_display_option_list = CategoryDisplayOptionList()
        cat_display_option_list.from_dict(json_disp_dict)

        cov_uuid = ewn_component.cov_uuid
        comp_id_to_poly_id = ewn_component.comp_to_xms[cov_uuid][TargetType.polygon]
        comp_ids = ewn_component.data.polygons.comp_id.values.tolist()
        manning_ns = ewn_component.data.polygons.manning_n.values.tolist()
        classifications = ewn_component.data.polygons.classification.values.tolist()
        old_poly_id_to_new_mat_id = {}
        new_mat_dict = {}  # (classification, manning_n) -> material id, so equal pairs share a material
        for mat_classify, new_manning_n, old_comp_id in zip(classifications, manning_ns, comp_ids):
            key = (mat_classify, new_manning_n)
            # Component ids with no polygon in the coverage are skipped.
            old_poly_id = comp_id_to_poly_id.get(old_comp_id, [-1])[0]
            if old_poly_id < 0:
                continue

            if key not in new_mat_dict:
                new_mat = MaterialProperties()
                mat_id = mats.add_material(new_mat)
                new_mat_dict[key] = mat_id
                new_mat_friction_df = pd.DataFrame(
                    {
                        'CARD': 'FR',
                        'CARD_2': 'MNG',
                        'STRING_ID': float(mat_id),
                        'REAL_01': float(new_manning_n),
                        'REAL_02': float('nan'),
                        'REAL_03': float('nan'),
                        'REAL_04': float('nan'),
                        'REAL_05': float('nan')
                    },
                    index=[len(mats.friction)]
                )
                mats.friction = pd.concat([mats.friction, new_mat_friction_df])
                # NOTE(review): the name lookup uses classification - 1 while the display category lookup
                # uses the classification directly - confirm both offsets are intended.
                mat_class_index = int(mat_classify - 1)
                mats.material_properties[mat_id].material_name =\
                    ewn_component.data.meta_data['Feature type/Region'].tolist()[mat_class_index]
                mats.material_display[mat_id] = cat_display_option_list.categories[int(mat_classify)].options
            old_poly_id_to_new_mat_id[old_poly_id] = new_mat_dict[key]

        self._ewn_poly_to_new_mat_id = old_poly_id_to_new_mat_id
        self._new_mat_comp.data.commit()  # save the newly added materials

    def update_adh_friction(self):
        """Update the adh friction and create a new coverage.

        The result is available from ``out_adh_mat_coverage`` afterwards.
        """
        self._create_new_adh_materials_coverage()


def _coverage_to_shapely_infos(coverage, id_field):
    """Convert every polygon of a coverage into a ShapelyPolygonInfo keyed by polygon id.

    Args:
        coverage: The coverage whose ``polygons`` are converted.
        id_field: Name of the ShapelyPolygonInfo field ('adh_id' or 'ewn_id') that records the polygon id.

    Returns:
        dict: Polygon id -> ShapelyPolygonInfo.
    """
    infos = {}
    for polygon in coverage.polygons:
        poly_points_list = get_polygon_point_lists(polygon)
        # The first point list is used as the shell; any remaining lists are passed as holes.
        poly_holes = poly_points_list[1:] if len(poly_points_list) > 1 else []
        s_poly = geometry.Polygon(poly_points_list[0], poly_holes)
        infos[polygon.id] = ShapelyPolygonInfo(id=polygon.id, polygon=s_poly, **{id_field: polygon.id})
    return infos


def _ring_to_point_ids(ring_coords, point_index):
    """Translate a ring's coordinates into indices of the shared point list.

    Args:
        ring_coords: Sequence of coordinate tuples for one ring.
        point_index: Mapping of coordinate tuple -> index in the shared point list.

    Returns:
        list: The point indices for the ring, with the final index repeated once.
    """
    ids = [point_index[pt] for pt in ring_coords]
    # NOTE(review): shapely rings already repeat their first point as the last one, so this adds a second
    # copy of the closing index - confirm that PolygonCoverageBuilder expects this.
    ids.append(ids[-1])
    return ids


def build_stamped_coverage(base_coverage, stamp_coverage, logger):
    """
    Builds a coverage by stamping the polygons from one into another.

    Args:
        base_coverage: The ADH coverage with polygons.
        stamp_coverage: The EWN coverage with polygons.
        logger: Logger for logging information.

    Returns:
        tuple: ``(polygon_merge_coverage, dataset_polygon_ids)``. The first item is the coverage produced by
        PolygonCoverageBuilder. The second is the builder's ``dataset_polygon_ids``; AdhFrictionUpdater
        consumes it as a mapping of ``(adh_id, ewn_id)`` -> polygon ids in the merged coverage.
    """
    logger.info('Merging polygons from ADH and EWN coverages.')
    # Steps 1 and 2: Create Shapely polygons from the base and stamp coverages
    base_shapely_polys = _coverage_to_shapely_infos(base_coverage, 'adh_id')
    stamp_shapely_polys = _coverage_to_shapely_infos(stamp_coverage, 'ewn_id')

    # Step 3: Subtract overlapping areas between base and stamp polygons. The info objects are mutated in
    # place, so the dictionaries see the updated geometry without being re-assigned.
    # NOTE(review): `difference` can return a geometry that is not a single Polygon (for example when a
    # stamp splits or fully covers a base polygon); Step 6 reads `.exterior` and does not handle that.
    for base_info in base_shapely_polys.values():
        for stamp_info in stamp_shapely_polys.values():
            # Subtract the stamp polygon from the base polygon
            base_info.polygon = base_info.polygon.difference(stamp_info.polygon)
            # Union the part of the stamp outside the (already cut) base back with the stamp itself.
            # NOTE(review): by area this equals the stamp polygon; presumably it is kept for its effect on
            # the stamp's vertices - confirm.
            stamp_remainder = stamp_info.polygon.difference(base_info.polygon)
            stamp_info.polygon = stamp_remainder.union(stamp_info.polygon)

    # Step 4: Merge stamp polygons into the base polygon list, numbering them after the largest base id
    # (or from 1 when the base coverage has no polygons).
    merged_polygons = dict(base_shapely_polys)
    max_id = max((p.id for p in merged_polygons.values()), default=0)
    for stamp_info in stamp_shapely_polys.values():
        max_id += 1
        stamp_info.id = max_id
        merged_polygons[max_id] = stamp_info

    # Step 5: Gather the unique point locations used by all polygons
    coverage_point_locations = []
    for poly_info in merged_polygons.values():
        coverage_point_locations.extend(poly_info.locations)
    coverage_point_locations = list(set(coverage_point_locations))  # Unique points
    # The locations are unique, so a dict gives the same index as list.index() without a scan per vertex.
    point_index = {pt: idx for idx, pt in enumerate(coverage_point_locations)}

    # Step 6: Build the multipolygon geometry as {(adh_id, ewn_id): [[exterior_ids, *interior_ids]]}
    multipolys = {}
    for poly_info in merged_polygons.values():
        mapper_id = (poly_info.adh_id, poly_info.ewn_id)
        rings = [_ring_to_point_ids(poly_info.polygon.exterior.coords[:], point_index)]
        rings.extend(_ring_to_point_ids(interior.coords[:], point_index) for interior in poly_info.polygon.interiors)
        multipolys[mapper_id] = [rings]

    # Step 7: Use the PolygonCoverageBuilder to create the final merged coverage
    cov_builder = PolygonCoverageBuilder(coverage_point_locations, base_coverage.projection, 'temp', logger)
    polygon_merge_coverage = cov_builder.build_coverage(multipolys)
    dataset_polygon_ids = cov_builder.dataset_polygon_ids

    # Return the merged coverage geometry and the polygon id mapping
    return polygon_merge_coverage, dataset_polygon_ids


def get_shapely_polygons(coverage):
    """Convert the polygons of a coverage into shapely polygons.

    Args:
        coverage: The coverage.

    Returns:
        dict: Coverage polygon id -> ShapelyPolygonInfo, with both ``id`` and ``ewn_id`` set to that id.
    """
    shapely_by_id = {}
    for cov_poly in coverage.polygons:
        rings = get_polygon_point_lists(cov_poly)
        outer_ring = rings[0]
        # Any point lists after the first one are passed to shapely as holes.
        hole_rings = [] if len(rings) <= 1 else rings[1:]
        shapely_poly = geometry.Polygon(outer_ring, hole_rings)
        info = ShapelyPolygonInfo(id=cov_poly.id, polygon=shapely_poly, ewn_id=cov_poly.id)
        shapely_by_id[cov_poly.id] = info
    return shapely_by_id


def get_polygons_as_text(coverage):
    """Render the shapely polygons of a coverage as text.

    Args:
        coverage: The coverage.

    Returns:
        str: For each polygon, a ``Polygon: <id>`` line followed by a line holding the string form of its
        shapely geometry. Empty when the coverage has no polygons.
    """
    shapely_polys = get_shapely_polygons(coverage)
    return "".join(
        f"Polygon: {poly_id}\n" f"{info.polygon}\n"
        for poly_id, info in shapely_polys.items()
    )
