"""Read and write AdH data for testing."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from io import TextIOWrapper
import json
from pathlib import Path
import shutil
import sys
from typing import Callable, Optional, TextIO

# 2. Third party modules
import pandas as pd
import param
import xarray as xr

# 3. Aquaveo modules
from xms.constraint import read_grid_from_file
from xms.data_objects.parameters import Coverage
from xms.datasets.dataset_reader import DatasetReader
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.adh.components.adh_component import AdHComponent
from xms.adh.components.bc_conceptual_component import BcConceptualComponent
from xms.adh.components.material_conceptual_component import MaterialConceptualComponent
from xms.adh.components.sediment_constituents_component import SedimentConstituentsComponent
from xms.adh.components.sediment_material_conceptual_component import SedimentMaterialConceptualComponent
from xms.adh.components.sim_component import SimComponent
from xms.adh.components.transport_constituents_component import TransportConstituentsComponent
from xms.adh.components.vessel_component import VesselComponent
from xms.adh.data.adh_query_data import AdhQueryData
from xms.adh.data.xms_data import XmsData
from xms.adh.file_io.io_util import ensure_empty_directory


def _save_info(path: Path, saved_info: dict):
    """
    Serialize the collected metadata to ``info.json`` inside a folder.

    Args:
        path: The folder that receives the ``info.json`` file.
        saved_info: The metadata dictionary; written as JSON indented by four spaces.
    """
    (path / 'info.json').write_text(json.dumps(saved_info, indent=4))


def _save_component(path: Path, component: AdHComponent, name: str, saved_info: dict):
    """
    Copy a component's main file into the save folder and record it in the metadata.

    When the component has a non-empty ``comp_to_xms`` mapping, the mapping is also written as JSON next to the
    copied main file (same stem, ``.json`` suffix) and that file's name is recorded under ``'comp_to_xms'``.

    Args:
        path: The folder to copy the component files into.
        component: The component to save. Nothing is saved or recorded when it is None.
        name: The key under which the component is recorded in ``saved_info``.
        saved_info: The metadata dictionary; updated in place with an entry for ``name``.
    """
    if component is None:
        return
    old_file = Path(component.main_file)
    new_file = path / old_file.name
    # Only the file name is recorded, not the full path of the original file.
    info = {'file_path': old_file.name, 'uuid': component.uuid}
    shutil.copy(old_file, new_file)
    comp_to_xms = component.comp_to_xms
    if comp_to_xms:
        json_file = new_file.with_suffix('.json')
        # Serialize before opening so a failure does not leave a partial file behind.
        json_text = json.dumps(comp_to_xms, indent=4)
        with open(json_file, 'w') as f:
            f.write(json_text)
        info['comp_to_xms'] = json_file.name
    saved_info[name] = info


def _save_coverage(path: Path, coverage: Coverage, name: str, saved_info: dict):
    """
    Write a coverage to an H5 file in the save folder and record it in the metadata.

    Args:
        path (Path): The folder in which the coverage file is created.
        coverage (Coverage): The coverage to write. Nothing is written
            or recorded when it is None.
        name: The coverage name; used for the file name (``<name>.h5``) and as the key in ``saved_info``.
        saved_info: The metadata dictionary; updated in place with the coverage's file name.
    """
    if coverage is not None:
        h5_file = path / f'{name}.h5'
        coverage.write_h5(str(h5_file))
        saved_info[name] = {'file_path': h5_file.name}


def _save_datasets(path: Path, datasets: dict[str, DatasetReader], saved_info: dict):
    """
    Copy each dataset's H5 file into a folder and record the copies in the metadata.

    Every dataset is copied to ``<uuid>.h5``. The file name and the reader's group path are stored per UUID
    under the ``'datasets'`` key of ``saved_info``.

    Args:
        path: The folder to copy the dataset files into.
        datasets: The dataset readers to save, keyed by dataset UUID.
        saved_info: The metadata dictionary; updated in place with a ``'datasets'`` entry.
    """
    recorded = {}
    for dataset_uuid, reader in datasets.items():
        file_name = f'{dataset_uuid}.h5'
        recorded[dataset_uuid] = {'file_path': file_name, 'group_path': reader.group_path}
        shutil.copyfile(Path(reader.h5_filename), path / file_name)
    saved_info['datasets'] = recorded


def _load_info(path: Path) -> dict:
    """
    Read the saved metadata back from ``info.json``.

    Args:
        path: The folder containing ``info.json``.

    Returns:
        The parsed contents of the JSON file.
    """
    return json.loads((path / 'info.json').read_text())


def _load_component(path: Path, saved_info: dict, name: str, component_class: Callable) -> Optional['AdHComponent']:
    """
    Load a component from the saved files.

    The saved main file is copied into a sub-folder of ``path`` named after the component's UUID, and the
    component is constructed from that copy. If a ``comp_to_xms`` JSON file was saved with the component, it is
    read back with its target-type and id keys converted from strings to integers.

    Args:
        path: The folder containing the saved files.
        saved_info: The saved metadata, keyed by component name.
        name: The key of the component to load in ``saved_info``.
        component_class: Callable that builds the component from the path of its main file (passed as a string).

    Returns:
        The loaded component, or None if ``name`` is not present in ``saved_info``.
    """
    if name not in saved_info:
        return None
    info = saved_info[name]
    file_path = path / info['file_path']
    uuid_path = path / info['uuid']
    # Tolerate an existing folder so the same saved data can be loaded more than once.
    uuid_path.mkdir(exist_ok=True)
    uuid_file_path = uuid_path / file_path.name
    shutil.copy(file_path, uuid_file_path)
    component = component_class(str(uuid_file_path))
    if 'comp_to_xms' in info:
        with open(path / info['comp_to_xms'], 'r') as f:
            comp_to_xms = json.load(f)
        # JSON object keys are always strings; restore the integer target types and ids.
        component.comp_to_xms = {
            uuid: {
                int(target_type): {int(key): value for key, value in values_dict.items()}
                for target_type, values_dict in target_values.items()
            }
            for uuid, target_values in comp_to_xms.items()
        }
        if isinstance(component, BcConceptualComponent):
            component.clean_attributes()
    return component


def _load_coverage(path: Path, saved_info: dict, name: str) -> Optional['Coverage']:
    """
    Load a saved coverage file and return the coverage instance.

    Args:
        path: The folder containing the saved coverage file.
        saved_info: The saved metadata, keyed by coverage name.
        name: The key of the coverage to load in ``saved_info``.

    Returns:
        The coverage built from the saved file, or None if ``name`` is not present in ``saved_info``.
    """
    if name not in saved_info:
        return None
    return Coverage(str(path / saved_info[name]['file_path']))


def _load_datasets(path: Path, saved_info: dict) -> dict[str, 'DatasetReader']:
    """
    Create a dataset reader for every dataset listed in the saved metadata.

    Args:
        path: The folder containing the saved dataset H5 files.
        saved_info: The saved metadata; its ``'datasets'`` entry maps each UUID to a file name and group path.

    Returns:
        The dataset readers, keyed by dataset UUID.
    """
    return {
        dataset_uuid: DatasetReader(str(path / entry['file_path']), group_path=entry['group_path'])
        for dataset_uuid, entry in saved_info['datasets'].items()
    }


def save_xms_data(base_path: str, components: dict, coverages: dict, xms_data: XmsData):
    """Write the simulation data to a folder so it can be loaded again later.

    The resolved folder is first passed to ``ensure_empty_directory``. The CoGrid (when present), the components,
    the coverages and the datasets are then written, followed by an ``info.json`` file describing what was saved.

    Args:
        base_path: The folder in which the data is saved.
        components: The components to save, keyed by name.
        coverages: The coverages to save, keyed by name.
        xms_data: The XMS data providing the simulation name, the CoGrid and the dataset readers.
    """
    out_dir = Path(base_path).resolve()
    ensure_empty_directory(out_dir)

    # Metadata collected while saving; written to disk as the very last step.
    metadata = {'sim_name': xms_data.sim_name}

    # CoGrid (written as 'co_grid.xmc' with binary_arrays=False)
    if xms_data.co_grid:
        xms_data.co_grid.write_to_file(str(out_dir / 'co_grid.xmc'), binary_arrays=False)

    # Components
    for component_name, component in components.items():
        _save_component(out_dir, component, component_name, metadata)

    # Coverages
    for coverage_name, coverage in coverages.items():
        _save_coverage(out_dir, coverage, coverage_name, metadata)

    # Datasets go into their own sub-folder
    datasets_dir = out_dir / "Datasets"
    datasets_dir.mkdir()
    _save_datasets(datasets_dir, xms_data.dataset_readers, metadata)

    # Metadata
    _save_info(out_dir, metadata)


def load_xms_data(path: Path) -> XmsData:
    """
    Rebuild an XmsData instance from a folder of saved files.

    Reads ``info.json`` for the metadata, the CoGrid from ``co_grid.xmc`` when that file exists, every component
    and coverage listed in the metadata, and the datasets from the ``Datasets`` sub-folder.

    Args:
        path (Path): The folder holding the saved XMS data.

    Returns:
        XmsData: The populated XMS AdH data.
    """
    xms_data = XmsData()
    xms_data.adh_data = AdhQueryData(xms_data)
    saved_info = _load_info(path)

    def load_component(key, component_class):
        """Load the component saved under ``key``, or None if it was not saved."""
        return _load_component(path, saved_info, key, component_class)

    def load_coverage(key):
        """Load the coverage saved under ``key``, or None if it was not saved."""
        return _load_coverage(path, saved_info, key)

    # Basic attributes
    xms_data.sim_name = saved_info.get('sim_name')

    # CoGrid is optional; it is only read when the file was saved
    grid_file = path / 'co_grid.xmc'
    if grid_file.exists():
        xms_data.co_grid = read_grid_from_file(str(grid_file))

    # Components
    xms_data.sim_component = load_component('sim_component', SimComponent)
    xms_data.bc_component = load_component('bc_component', BcConceptualComponent)
    xms_data.material_component = load_component('material_component', MaterialConceptualComponent)
    xms_data.sediment_material_component = load_component(
        'sediment_material_component', SedimentMaterialConceptualComponent
    )
    xms_data.transport_component = load_component('transport_component', TransportConstituentsComponent)
    xms_data.sediment_constituents_component = load_component(
        'sediment_constituents_component', SedimentConstituentsComponent
    )
    if 'vessel_components' in saved_info:
        vessel_components = []
        for vessel_info in saved_info['vessel_components']:
            # Each entry is wrapped in a one-item dictionary so the shared component loader can be reused.
            vessel = _load_component(path, {'vessel_component': vessel_info}, 'vessel_component', VesselComponent)
            # Target types are looked up on TargetType by attribute name; ids are converted to int.
            vessel.data.component_id_map = {
                getattr(TargetType, type_name): {int(k): int(v) for k, v in id_map.items()}
                for type_name, id_map in vessel_info.get('component_id_map', {}).items()
            }
            vessel_components.append(vessel)
        xms_data.vessel_components = vessel_components

    # Coverages
    xms_data.bc_coverage = load_coverage('bc_coverage')
    xms_data.material_coverage = load_coverage('material_coverage')
    xms_data.output_coverage = load_coverage('output_coverage')
    xms_data.sediment_material_coverage = load_coverage('sediment_material_coverage')
    if 'vessel_coverages' in saved_info:
        xms_data.vessel_coverages = [
            _load_coverage(path, {'vessel_coverage': coverage_info}, 'vessel_coverage')
            for coverage_info in saved_info['vessel_coverages']
        ]

    # Datasets
    xms_data.dataset_readers = _load_datasets(path / "Datasets", saved_info)
    return xms_data


def print_vars_recursive(
    obj: object, current_depth: int = 0, max_depth: int = 9, file: Optional[TextIO] = None, visited=None
) -> None:
    """
    Recursively prints the attributes of an object.

    Uses `vars()`, with special handling for param.Parameterized objects, pandas DataFrames, and xarray Datasets.
    An object that is reached again while it is still being expanded (a circular reference) is reported with a
    one-line marker instead of being expanded again. Objects that merely appear more than once are printed
    every time.

    Args:
        obj: The object to inspect.
        current_depth: The current recursion depth (default: 0).
        max_depth: The maximum recursion depth; deeper levels print "Reached max depth." (default: 9).
        file: A writable object (e.g., file, sys.stdout).
            If None, defaults to standard output.
        visited: Set of `id()` values of the objects currently being expanded, used to detect circular
            references. If None, a new set is created.

    Returns:
        None: Prints the attributes and their values to the specified output.
    """
    # Indentation for nested levels
    indent = "  " * current_depth

    # Default to standard output if no output is provided
    if file is None:
        file = sys.stdout

    if visited is None:
        visited = set()

    # Stop recursion if max depth is reached
    if current_depth > max_depth:
        print(f"{indent}Reached max depth.", file=file)
        return

    # Special handling for xarray.Dataset
    if isinstance(obj, xr.Dataset):
        print(f"{indent}xarray.Dataset with dimensions {obj.dims} and variables {list(obj.data_vars)}:", file=file)
        print(f"{indent}  Coordinates:", file=file)
        for coord_name, coord_values in obj.coords.items():
            print(f"{indent}    {coord_name}: {coord_values.values}", file=file)
        print(f"{indent}  Variables:", file=file)
        for var_name, var_data in obj.data_vars.items():
            print(f"{indent}    {var_name}:", file=file)
            print_vars_recursive(var_data.values, current_depth + 1, max_depth, file, visited)
        print(f"{indent}  Attributes:", file=file)
        for attr_name, attr_value in obj.attrs.items():
            print(f"{indent}    {attr_name}: {attr_value}", file=file)
        return

    # Special handling for pandas DataFrame
    if isinstance(obj, pd.DataFrame):
        print(f"{indent}pandas.DataFrame with shape {obj.shape}:", file=file)
        with pd.option_context('display.max_rows', None, 'display.max_columns', None):
            print(obj, file=file)
        return

    # If the object has nothing to expand (no parameters, items or __dict__), print it directly
    expandable = isinstance(obj, (param.Parameterized, list, dict, tuple, set)) or hasattr(obj, "__dict__")
    if not expandable:
        print(f"{indent}{repr(obj)}", file=file)
        return

    # Only objects on the current recursion path are tracked, so shared (non-circular) references still print.
    obj_id = id(obj)
    if obj_id in visited:
        print(f"{indent}<circular reference to {type(obj).__name__}>", file=file)
        return
    visited.add(obj_id)
    try:
        _print_children(obj, indent, current_depth, max_depth, file, visited)
    finally:
        # Leaving this object: it is no longer an ancestor of whatever is printed next.
        visited.discard(obj_id)


def _print_children(
    obj: object, indent: str, current_depth: int, max_depth: int, file: TextIO, visited: set
) -> None:
    """Print the members of a Parameterized object, container, or object with a __dict__, one level deeper."""
    next_depth = current_depth + 1

    # Special handling for param.Parameterized objects
    if isinstance(obj, param.Parameterized):
        print(f"{indent}param.Parameterized object with parameters:", file=file)
        for name, parameter in obj.param.objects("existing").items():
            value = getattr(obj, name)
            parameter_type = type(parameter).__name__
            if isinstance(value, param.Parameterized):
                # If the value is another Parameterized object, process it recursively
                print(f"{indent}  {name} ({parameter_type}):", file=file)
                print_vars_recursive(value, next_depth, max_depth, file, visited)
            else:
                value_repr = repr(value)
                if name == "name" and parameter_type == "String":
                    # Drop the quotes and any trailing digits from the object's name.
                    # NOTE(review): presumably removes an auto-generated numeric suffix so the output is
                    # stable between runs - confirm.
                    value_repr = value_repr.strip("'").rstrip("0123456789")
                print(f"{indent}  {name} ({parameter_type}): {value_repr}", file=file)
        return

    # Handle lists, tuples, and sets by iterating over their elements
    if isinstance(obj, (list, tuple, set)):
        for index, item in enumerate(obj):
            print(f"{indent}[{index}]", file=file)
            print_vars_recursive(item, next_depth, max_depth, file, visited)
        return

    # Handle dictionaries by iterating over key-value pairs
    if isinstance(obj, dict):
        for key, value in obj.items():
            print(f"{indent}{key}:", file=file)
            print_vars_recursive(value, next_depth, max_depth, file, visited)
        return

    # Handle objects with attributes (those with a __dict__)
    for attr, value in vars(obj).items():
        print(f"{indent}{attr}:", file=file)
        print_vars_recursive(value, next_depth, max_depth, file, visited)


def print_xms_data(base_path: str, components: dict[str, object], _coverages: dict[str, object], _xms_data: XmsData):
    """
    Write a text dump of the components to ``debug.txt`` for debugging purposes.

    The "Sim name" line reports the name of the resolved ``base_path`` folder. ``_coverages`` and ``_xms_data``
    are accepted but not used.

    Args:
        base_path: The folder in which ``debug.txt`` is written.
        components: The components to dump, keyed by name.
        _coverages: Unused.
        _xms_data: Unused.
    """
    folder = Path(base_path).resolve()
    with open(folder / "debug.txt", "w") as debug_file:
        debug_file: TextIOWrapper
        # print() appends the final newline, so the three header lines come out exactly as separate prints would.
        print(f"XMS data at {folder}:\n  Sim name: {folder.name}\n  Components:", file=debug_file)
        for label, item in components.items():
            print(f"    {label}:", file=debug_file)
            print_vars_recursive(item, file=debug_file)
