"""PestObsCoverageBuilder class."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math
from pathlib import Path
import uuid

# 2. Third party modules
import orjson

# 3. Aquaveo modules
from xms.core.filesystem import filesystem as fs
from xms.coverage.components.obs_target_component_display import OBS_CATEGORY_NAMES, reg_key_from_obs_category
from xms.coverage.data.obs_target_data import ObsTargetData
from xms.coverage.file_io.obs_target_component_builder import ObsTargetComponentBuilder
from xms.coverage.grid.polygon_coverage_builder import PolygonCoverageBuilder
from xms.data_objects.parameters import Arc, Component, Coverage, Point
from xms.guipy import file_io_util
from xms.guipy.settings import SettingsManager

# 4. Local modules
from xms.mf6.file_io import io_util
from xms.mf6.file_io.pest import pest_obs_util
from xms.mf6.file_io.pest.pest_obs_results_reader import B2Map, Features, Obs, ObsResults
from xms.mf6.misc.settings import Settings

# Type aliases
BackupDispOptsDict = dict[str, dict[int, str]]  # {map cov uuid: {obs_category: display options filepath}}


def build_coverages(pest_obs_dir: str, obs_results: ObsResults, min_max_ugrid_pt_z: tuple[float, float]) -> list[tuple]:
    """Create the obs solution coverages from the PEST obs results.

    Convenience wrapper that constructs a PestObsCoverageBuilder and runs it.

    Args:
        pest_obs_dir: Path to directory containing all the PEST files.
        obs_results: All the data in the various pest obs files.
        min_max_ugrid_pt_z: Tuple of the min and max ugrid point z.

    Returns:
        list[tuple[]]: One (obs coverage, data objects component, component data) tuple per coverage created.
    """
    return PestObsCoverageBuilder(pest_obs_dir, obs_results, min_max_ugrid_pt_z).build_coverages()


class PestObsCoverageBuilder:
    """Creates coverages as part of a solution.

    Each feature type (points, arcs, arc groups, polygons) of each coverage listed in the .b2map file is put in its
    own obs coverage, together with an obs target component holding that feature type's observations.
    """
    def __init__(self, pest_obs_dir: str, obs_results: ObsResults, min_max_ugrid_pt_z: tuple[float, float]):
        """Initializer.

        Args:
            pest_obs_dir: Path to directory containing all the PEST files.
            obs_results: All the data in the various pest obs files.
            min_max_ugrid_pt_z: Tuple of the min and max ugrid point z.
        """
        self._pest_obs_dir = pest_obs_dir
        self._obs_results: ObsResults = obs_results

        self._z = _compute_z(min_max_ugrid_pt_z)  # World z given to the feature points we create
        self._pt_hash = {}  # Dict of point xyz tuple -> Point, so arcs in one coverage share their Points
        self._next_id = 1  # Next feature id handed out by _next_feature_id()
        self._default_disp_opt_files: list[str] = []  # display option json files created from the registry defaults

    def build_coverages(self) -> list[tuple]:
        """Builds the coverages.

        Returns:
            list[tuple[]]: One (obs coverage, data objects component, component data) tuple per coverage created.
        """
        coverages = []

        # One coverage from GMS may have contained multiple feature types, but they were saved to separate
        # shapefiles. Or (in the future) the user may have started with shapefiles. So now we will put each
        # feature type in its own coverage. Order here is the order coverages are created.
        creators = (
            ('points', self._create_point_coverage),
            ('arcs', self._create_arc_coverage),
            ('arc_groups', self._create_arc_group_coverage),
            ('polygons', self._create_polygon_coverage),
        )

        try:
            # Go through b2map in order of the coverage tree path
            b2map: B2Map = self._obs_results.b2map  # From .b2map file.
            for cov_uuid, cov_data in b2map.items():
                # cov_uuid is the coverage that was used to generate the PEST obs data (not the solution coverage)
                features: Features = cov_data['features']
                cov_name = pest_obs_util.cov_name_from_path(cov_data.get('coverage_tree_path', ''))
                for feature_type, create_coverage in creators:
                    if features[feature_type]:
                        # Each new coverage starts with fresh point sharing and feature ids starting at 1
                        self._reset_point_hash_and_starting_feature_id()
                        coverages.append(create_coverage(cov_uuid, cov_name, features))
        finally:
            # Clean up the temp display options files even if building a coverage failed
            self._delete_temp_files()
        return coverages

    def _delete_temp_files(self) -> None:
        """Delete all temp files that we no longer need."""
        # Delete the display options files we created from the registry
        for file in self._default_disp_opt_files:
            fs.removefile(file)
        self._default_disp_opt_files.clear()  # So a later call doesn't try to delete them again

    def _reset_point_hash_and_starting_feature_id(self) -> None:
        """Forget all hashed points and restart feature ids at 1 (done before building each coverage)."""
        self._pt_hash.clear()
        self._next_id = 1

    def _next_feature_id(self) -> int:
        """Just counts up from 1 and returns the next number."""
        id_ = self._next_id
        self._next_id += 1
        return id_

    def _display_options_file(self, obs_category: int, cov_uuid: str) -> str:
        """Return filepath of display options file.

        The file comes from the first of these that exists:
            1) the default in the registry (copied to a temp file)
            2) the xms.mf6 default

        NOTE(review): an earlier version of this docstring listed "the previous solution" as the first choice, but
        that lookup is not done here and cov_uuid is currently unused. Confirm whether it is handled elsewhere.

        Args:
            obs_category: ObsTargetData.OBS_CATEGORY_POINT etc. (0=point, 1=arc, 2=arc_group, 3=poly).
            cov_uuid: Uuid of the coverage used to generate the PEST obs data.

        Returns:
            The filepath.
        """
        return self._default_disp_opts_file(obs_category) or _default_disp_opts_filepath(obs_category)

    def _default_disp_opts_file(self, obs_category: int) -> str:
        """Return a temporary file containing the default display options in the registry, if found.

        The temp file is remembered so _delete_temp_files() can remove it.

        Args:
            obs_category: ObsTargetData.OBS_CATEGORY_POINT etc. (0=point, 1=arc, 2=arc_group, 3=poly).

        Returns:
            Path to temp file, or '' if the registry has no setting for this category.
        """
        settings = SettingsManager()
        reg_key = reg_key_from_obs_category(obs_category)
        json_text = settings.get_setting('xmscoverage', reg_key)
        if not json_text:
            return ''
        json_dict = orjson.loads(json_text)
        temp_filepath = io_util.get_temp_filename(suffix='.json')
        file_io_util.write_json_file(json_dict, temp_filepath)
        self._default_disp_opt_files.append(temp_filepath)
        return temp_filepath

    def _add_observations(self, alias_ids_and_intervals: dict, builder: ObsTargetComponentBuilder, flow: bool) -> None:
        """Adds the observation data to the component builder.

        Args:
            alias_ids_and_intervals: Dict of feature alias -> {'id': feature id, 'interval': interval}.
            builder: The component builder the observations are added to.
            flow: If True, only flow observations are added; if False, only non-flow (head) observations.
        """
        obs: Obs = self._obs_results.obs
        disp_category = 1 if flow else 0  # 1 is for flows, 0 is for heads
        for alias, id_and_interval in alias_ids_and_intervals.items():
            for date_time, obs_vals in obs[alias].items():
                if obs_vals['flow'] != flow:
                    continue
                builder.current_disp_category = disp_category
                builder.add_observation(
                    feature_id=id_and_interval['id'],
                    time=date_time,
                    interval=id_and_interval['interval'],
                    observed=obs_vals['observed'],
                    computed=obs_vals['computed'],
                    feature_name=alias,
                    label_text=alias
                )

    def _build_obs_component(
        self, obs_category: int, alias_ids_and_intervals: dict, obs_cov_uuid: str, cov_uuid: str
    ) -> tuple:
        """Build the obs component and add the observations.

        Args:
            obs_category: ObsTargetData.OBS_CATEGORY_POINT etc. (0=point, 1=arc, 2=arc_group, 3=poly).
            alias_ids_and_intervals: Dict of feature alias -> {'id': feature id, 'interval': interval}.
            obs_cov_uuid: The obs solution coverage uuid.
            cov_uuid: Uuid of the coverage used to generate the PEST obs data.

        Returns:
            Tuple of (data objects component, component data) from the component builder.
        """
        builder = ObsTargetComponentBuilder(coverage_feature_type=obs_category)
        builder.custom_display_options[obs_category] = self._display_options_file(obs_category, cov_uuid)
        builder.current_feature_type = obs_category
        builder.dset_uuid = self._obs_results.dset_uuid  # Uuid of dataset associated with the observations
        # Point observations are heads; every other feature type holds flow observations
        flow = obs_category != ObsTargetData.OBS_CATEGORY_POINT
        self._add_observations(alias_ids_and_intervals, builder, flow)
        do_comp, comp_data = builder.build_obs_target_component(obs_cov_uuid, True)
        return do_comp, comp_data

    def _create_point_coverage(self, cov_uuid: str, cov_name: str,
                               features: Features) -> tuple[Coverage, Component, list[dict]]:
        """Create obs coverage and component.

        Args:
            cov_uuid: Uuid of the coverage used to generate the PEST obs data.
            cov_name: Name of the coverage used to generate the PEST obs data.
            features: Dict of feature type -> {alias -> {'geometry': ..., 'interval': ...}}.

        Returns:
            tuple of the results.
        """
        # Geometry. Points are not hashed, and all are placed at self._z.
        do_points = []
        alias_ids_and_intervals = {}
        for alias, data in features['points'].items():
            do_point = Point(data['geometry'][0], data['geometry'][1], self._z, self._next_feature_id())
            do_points.append(do_point)
            alias_ids_and_intervals[alias] = self._id_interval_dict(do_point.id, data['interval'])

        obs_cov = Coverage()
        obs_cov.name = f'{cov_name} obs points' if cov_name else 'obs points'
        obs_cov.uuid = str(uuid.uuid4())
        obs_cov.set_points(do_points)
        obs_cov.complete()

        # Observations
        obs_category = ObsTargetData.OBS_CATEGORY_POINT
        do_comp, comp_data = self._build_obs_component(obs_category, alias_ids_and_intervals, obs_cov.uuid, cov_uuid)

        _save_map_cov_uuid(cov_uuid, do_comp.main_file)
        return obs_cov, do_comp, comp_data

    def _create_arc_coverage(self, cov_uuid: str, cov_name: str, features: Features):
        """Create obs coverage and component.

        Args:
            cov_uuid: Uuid of the coverage used to generate the PEST obs data.
            cov_name: Name of the coverage used to generate the PEST obs data.
            features: Dict of feature type -> {alias -> {'geometry': ..., 'interval': ...}}.

        Returns:
            tuple of the results.
        """
        # Geometry
        do_arcs, alias_ids_and_intervals = self._create_arcs(features['arcs'])
        obs_cov = Coverage()
        obs_cov.name = f'{cov_name} obs arcs' if cov_name else 'obs arcs'
        obs_cov.uuid = str(uuid.uuid4())
        obs_cov.arcs = do_arcs
        obs_cov.complete()

        # Observations
        obs_category = ObsTargetData.OBS_CATEGORY_ARC
        do_comp, comp_data = self._build_obs_component(obs_category, alias_ids_and_intervals, obs_cov.uuid, cov_uuid)

        _save_map_cov_uuid(cov_uuid, do_comp.main_file)
        return obs_cov, do_comp, comp_data

    def _create_arc_group_coverage(self, cov_uuid: str, cov_name: str, features: Features):
        """Create obs coverage and component.

        Note: We don't do anything to hash arcs.

        Duplicate arcs will be duplicated. We could try to match by geometry, but that will suffer from round-off
        and ordering differences that would be challenging. We don't have the arc IDs. Getting them is very hard
        because the arc_group coverage is exported to a shapefile and the arcs to another shapefile, so we don't have
        the arc ids.

        Args:
            cov_uuid: Uuid of the coverage used to generate the PEST obs data.
            cov_name: Name of the coverage used to generate the PEST obs data.
            features: Dict of feature type -> {alias -> {'geometry': ..., 'interval': ...}}.

        Returns:
            tuple of the results.
        """
        # Geometry
        self._next_id = 1
        do_arcs = []  # Arcs of all the groups
        arc_groups = {}  # Arc group id -> ids of the arcs in that group
        alias_ids_and_intervals = {}
        for alias, data in features['arc_groups'].items():
            # Create a dict to pass to _create_arcs(). The keys are placeholder aliases that we don't use.
            arc_dict = {
                str(i): {'geometry': point_list, 'interval': 0.0}
                for i, point_list in enumerate(data['geometry'], start=1)
            }
            arcs, _ = self._create_arcs(arc_dict)
            do_arcs.extend(arcs)
            # Only the arcs just created belong to this group (do_arcs also holds the previous groups' arcs)
            arc_ids = [arc.id for arc in arcs]
            arc_group_id = max(arc_ids) + 1  # Use the max arc ID plus one
            arc_groups[arc_group_id] = arc_ids
            alias_ids_and_intervals[alias] = self._id_interval_dict(arc_group_id, data['interval'])

        obs_cov = Coverage()
        obs_cov.name = f'{cov_name} obs arc_groups' if cov_name else 'obs arc_groups'
        obs_cov.uuid = str(uuid.uuid4())
        obs_cov.arcs = do_arcs
        obs_cov.arc_groups = arc_groups
        obs_cov.complete()

        # Observations
        obs_category = ObsTargetData.OBS_CATEGORY_ARC_GROUP
        do_comp, comp_data = self._build_obs_component(obs_category, alias_ids_and_intervals, obs_cov.uuid, cov_uuid)

        _save_map_cov_uuid(cov_uuid, do_comp.main_file)
        return obs_cov, do_comp, comp_data

    def _create_polygon_coverage(self, cov_uuid: str, cov_name: str, features: Features):
        """Create obs coverage and component.

        Args:
            cov_uuid: Uuid of the coverage used to generate the PEST obs data.
            cov_name: Name of the coverage used to generate the PEST obs data.
            features: Dict of feature type -> {alias -> {'geometry': ..., 'interval': ...}}.

        Returns:
            tuple of the results.
        """
        # Geometry
        # We have to assume polygons are created in same order as multipolys and numbered 1 to n for the ids.
        point_locs, multipolys, alias_ids_and_intervals = self._convert_xyz_lists_to_indexed_lists(features['polygons'])
        coverage_name = f'{cov_name} obs polygons' if cov_name else 'obs polygons'
        poly_builder = PolygonCoverageBuilder(point_locs, projection=None, coverage_name=coverage_name, logger=None)
        obs_cov = poly_builder.build_coverage(multipolys)

        # Observations
        obs_category = ObsTargetData.OBS_CATEGORY_POLY
        do_comp, comp_data = self._build_obs_component(obs_category, alias_ids_and_intervals, obs_cov.uuid, cov_uuid)

        _save_map_cov_uuid(cov_uuid, do_comp.main_file)
        return obs_cov, do_comp, comp_data

    def _create_arcs(self, arcs: dict) -> tuple[list[Arc], dict]:
        """Create an Arc for each entry in arcs, sharing Points through the point hash.

        Args:
            arcs: Dict of alias -> {'geometry': list of xyz lists, 'interval': interval}.

        Returns:
            (list of the new Arcs, dict of alias -> {'id': arc id, 'interval': interval}).
        """
        do_arcs = []
        alias_ids_and_intervals = {}
        for alias, data in arcs.items():
            arc_points = arc_geom(data['geometry'], self._z)
            do_arc = self._add_arc(self._next_feature_id(), arc_points)
            do_arcs.append(do_arc)
            alias_ids_and_intervals[alias] = self._id_interval_dict(do_arc.id, data['interval'])
        return do_arcs, alias_ids_and_intervals

    def _convert_xyz_lists_to_indexed_lists(self, polygons: dict) -> tuple[list, dict, dict]:
        """Converts polygons from lists of lists of xyz to lists of integer indexes into a list of points.

        Args:
            polygons: Dict of alias -> {'geometry': list of arcs (each a list of xyz lists), 'interval': interval}.

        Returns:
            (point_locs, multi_polys, alias_ids_and_intervals):
                point_locs: The unique xyz tuples, ordered by their index.
                multi_polys: {1: [polygon, ...]} where each polygon is a list of arcs of point indexes.
                alias_ids_and_intervals: alias -> {'id': polygon id, 'interval': interval}; ids number from 1.
        """
        point_locs_hash = {}  # Dict of xyz -> index into point_locs
        value = 1  # Arbitrary value that we ignore but is required by PolygonCoverageBuilder
        polys = []
        alias_ids_and_intervals = {}
        for polygon_id, (alias, data) in enumerate(polygons.items(), start=1):
            new_poly = []
            for arc in poly_geom(data['geometry'], self._z):
                # A point not seen before gets the next index (the hash's current size)
                new_poly.append([point_locs_hash.setdefault(point, len(point_locs_hash)) for point in arc])
            polys.append(new_poly)
            alias_ids_and_intervals[alias] = self._id_interval_dict(polygon_id, data['interval'])

        # Indexes were assigned in insertion order and dicts keep insertion order, so keys are in index order
        point_locs = list(point_locs_hash)
        return point_locs, {value: polys}, alias_ids_and_intervals

    def _id_interval_dict(self, feature_id, interval) -> dict:
        """Return {'id': feature_id, 'interval': interval}."""
        return {'id': feature_id, 'interval': interval}

    def _add_point(self, feature_id, point):
        """Adds a Point to the hash and returns the new Point."""
        do_point = Point(point[0], point[1], point[2], feature_id)
        self._pt_hash[point] = do_point
        return do_point

    def _add_arc(self, feature_id, arc_points):
        """Create an Arc from xyz tuples, reusing hashed Points for the nodes and vertices.

        The start node, end node, then vertices are hashed in that order, which fixes the ids of any new Points.
        """
        start_node = self._get_or_hash_point(arc_points[0])
        end_node = self._get_or_hash_point(arc_points[-1])
        vertices = [self._get_or_hash_point(point) for point in arc_points[1:-1]]
        return Arc(feature_id, start_node, end_node, vertices)

    def _get_or_hash_point(self, point):
        """Returns a Point by getting it from the hash or creating a new one and adding it to the hash."""
        if point in self._pt_hash:
            return self._pt_hash[point]
        return self._add_point(self._next_feature_id(), point)


def point_geom(geom: list[float], z: float | None) -> tuple[float, float, float]:
    """Return an xyz tuple, e.g. (1.0, 2.0, 3.0) given an xyz list, e.g. [1.0, 2.0, 3.0], and possible z value.

    Args:
        geom: The xyz values.
        z: If provided, z value will be set to this.

    Returns:
        See description.
    """
    if z is not None:
        return geom[0], geom[1], z
    else:
        return geom[0], geom[1], geom[2]


def arc_geom(geom: list[list[float]], z: float | None):
    """Convert a list of xyz lists into a list of xyz tuples, optionally overriding every z.

    Example: [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] -> [(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)]

    Args:
        geom: The points, each an xyz list.
        z: If not None, every point's z value is set to this.

    Returns:
        A list of xyz tuples in the same order as geom.
    """
    return [point_geom(xyz, z) for xyz in geom]


def poly_geom(geom: list[list[list[float]]], z: float | None):
    """Convert a polygon's arcs from nested xyz lists to lists of xyz tuples, optionally overriding every z.

    Example input: [[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]]
    Example output: [[(1.0, 2.0, 3.0), (1.0, 2.0, 3.0)], [(1.0, 2.0, 3.0), (1.0, 2.0, 3.0)]]

    Args:
        geom: List of arcs, each a list of xyz lists.
        z: If not None, every point's z value is set to this.

    Returns:
        List of arcs, each a list of xyz tuples.
    """
    return [arc_geom(arc_points, z) for arc_points in geom]


def _compute_z(min_max_ugrid_pt_z: tuple[float, float]) -> float:
    """Compute and return the world Z where all feature objects will be located, which will be above the UGrid.

    Tries to come up with a nice, round number, at least 5% to 10% above the grid.

    Args:
        min_max_ugrid_pt_z: Tuple of the min and max ugrid point z.

    Returns:
        See description.
    """
    diff = (min_max_ugrid_pt_z[1] - min_max_ugrid_pt_z[0])
    if diff < 0.0 or math.isclose(diff, 0.0):
        return 0.0

    # Each step is spelled out to help in understanding
    percent_of_total_z = 0.05  # 5%
    raw_z = min_max_ugrid_pt_z[1] + (diff * percent_of_total_z)
    log_result = math.log10(diff)
    log_rounded_down = math.floor(log_result)
    n = log_rounded_down - 1.0  # n to use in 10^n
    ten_to_the_n = math.pow(10.0, n)  # 1.0, 10.0, 100.0 etc.
    nice_z = math.ceil(raw_z / ten_to_the_n) * ten_to_the_n  # Round the number to the nearest ten_to_the_n
    return nice_z


def _save_map_cov_uuid(cov_uuid: str, main_file: str) -> None:
    """Record the source map coverage uuid in the obs coverage component's settings.json file.

    Storing it lets the next read of the solution back up and restore the display options.

    Args:
        cov_uuid: Uuid of the coverage used to generate the PEST obs data.
        main_file: The data objects component main file.
    """
    setting_name = 'MAP_COVERAGE_UUIDS'
    Settings.set(main_file, setting_name, cov_uuid)


def _default_disp_opts_filepath(obs_category: int) -> str:
    """Return the path of the display options json file shipped with xms.mf6 for the obs_category.

    The file is components/resources/display_options/obs_<feature type>_disp.json, relative to the directory three
    levels above this file.

    Args:
        obs_category: ObsTargetData.OBS_CATEGORY_POINT etc. (0=point, 1=arc, 2=arc_group, 3=poly).

    Returns:
        See description.
    """
    file_name = f'obs_{OBS_CATEGORY_NAMES[obs_category]}_disp.json'
    package_dir = Path(__file__).parent.parent.parent
    return str(package_dir.joinpath('components', 'resources', 'display_options', file_name))
