"""Read and write the _gen2dm_extra.h5 file that goes with a .2dm file."""

# 1. Standard Python modules
from pathlib import Path
from typing import Optional

# 2. Third party modules
import h5py
import numpy as np

# 3. Aquaveo modules
from xms.data_objects.parameters import Projection

# 4. Local modules

# Vertical datum names. The extra file stores a datum as its integer position in this list: the reader indexes into
# the list with the 'VerticalDatum' attribute and the writer stores `datums.index(...)`. Reordering or inserting
# entries would therefore change how existing files are interpreted.
datums = [
    'LOCAL',
    'NGVD29',
    'NAVD88',
]

# Unit names shared by the horizontal and vertical units. The extra file stores a unit as its integer position in
# this list ('HorizontalUnits' and 'VerticalUnits' attributes), so the order of the entries is part of the file
# format as read and written by this module.
units = [
    'FEET (U.S. SURVEY)',
    'FEET (INTERNATIONAL)',
    'METERS',
    'INCHES',
    'CENTIMETERS',
]


def read_extra_file(mesh_name: str, mesh_path: Path) -> tuple[Optional[str], Optional[Projection]]:
    """
    Read the contents of a *_gen2dm_extra.h5 file that belongs to a *.2dm file.

    This function returns `(None, None)` instead of raising when the extra file does not exist, cannot be opened
    as an HDF5 file, or does not belong to the given mesh and .2dm file. Other structural problems inside a
    matching file (missing datasets/attributes, out-of-range unit or datum indexes) are not caught here.

    The file is opened read-only and is always closed before this function returns.

    Args:
        mesh_name: Name of the mesh the file is being read for. Should match the MESHNAME card in the .2dm file.
        mesh_path: Path to the .2dm file being read.

    Returns:
        Tuple of (uuid, projection).
            - uuid: The UUID to assign to the mesh, or `None` if not found.
            - projection: The projection to assign to the mesh, or `None` if no projection info.
    """
    h5_path = mesh_path.parent / (mesh_path.stem + '_gen2dm_extra.h5')
    if not h5_path.exists():
        return None, None  # We don't have an extra file.

    try:
        file = h5py.File(h5_path, mode='r')
    except OSError:
        return None, None  # The file exists but h5py can't open it, so it isn't a usable extra file.

    # The context manager closes the file on every return path below. A leaked handle would keep the file locked,
    # which can make a later rewrite of the same extra file fail on platforms that lock open files.
    with file:
        if '/Extra' not in file:
            return None, None  # This isn't an extra file.
        if f'/{mesh_name}' not in file:
            return None, None  # This file isn't for the expected mesh.

        group = file[f'/{mesh_name}']
        control_file = str(group['Control File'][0], encoding='latin-1')
        if control_file != mesh_path.name:
            return None, None  # This file isn't for our mesh file.

        mesh_uuid = str(group['PROPERTIES/GUID'][0], encoding='latin-1')

        coordinates = group['Coordinates']
        version = coordinates.attrs['Version'][0]
        if version != 2:
            return mesh_uuid, None  # We don't support versions other than 2.

        # Everything assigned to the projection is copied out of the file while it is still open.
        projection = Projection()
        have_projection = False
        if 'WKT' in coordinates.attrs:
            # NOTE(review): unlike the GUID above, this value is assigned without decoding, although the writer
            # stores it as latin-1 bytes - confirm that Projection accepts it in that form.
            projection.well_known_text = coordinates.attrs['WKT'][0]
            have_projection = True

        if 'HorizontalUnits' in coordinates.attrs:
            # Units and datums are stored as indexes into the module-level lists.
            horizontal_units_index = coordinates.attrs['HorizontalUnits'][0]
            projection.horizontal_units = units[horizontal_units_index]
            have_projection = True

        if 'VerticalUnits' in coordinates.attrs:
            vertical_units_index = coordinates.attrs['VerticalUnits'][0]
            projection.vertical_units = units[vertical_units_index]
            have_projection = True

        if 'VerticalDatum' in coordinates.attrs:
            vertical_datum_index = coordinates.attrs['VerticalDatum'][0]
            projection.vertical_datum = datums[vertical_datum_index]
            have_projection = True

        return mesh_uuid, projection if have_projection else None


def write_extra_file(
    mesh_name: str, mesh_path: Path, mesh_uuid: str, projection: Optional[Projection], xms_version: str
):
    """
    Write a *_gen2dm_extra.h5 file for a *.2dm file.

    The extra file contains information about the mesh projection and UUID. Any existing extra file next to the
    .2dm file is replaced. The file is closed before this function returns, including when an error is raised
    part-way through writing.

    Args:
        mesh_name: Name of the mesh to write the file for.
        mesh_path: Path to the .2dm file the mesh was written to.
        mesh_uuid: UUID of the mesh being written.
        projection: Projection the mesh is currently in. If falsy, no projection attributes are written.
        xms_version: The version of XMS the file is being written by, e.g. 'major.minor.build'.

    Raises:
        ValueError: If `xms_version` cannot be converted to a file version number, or if the projection's units or
            vertical datum are not in the module-level `units`/`datums` lists.
    """
    # The file version is the XMS version with its last dot-separated component removed. A version string without
    # any dot is used as-is (slicing at rfind's -1 would silently drop its last character). This is parsed before
    # touching the disk so that a bad version string doesn't remove an existing extra file.
    version_prefix, dot, _ = xms_version.rpartition('.')
    version = float(version_prefix if dot else xms_version)

    h5_path = mesh_path.parent / (mesh_path.stem + '_gen2dm_extra.h5')
    h5_path.unlink(missing_ok=True)

    # The context manager guarantees the file is closed on the early return and on errors.
    with h5py.File(h5_path, mode='w') as file:
        _write_value(file, '/File Type', 'Xmdf')
        _write_value(file, '/File Version', version)

        extra_group = file.create_group('Extra')
        _write_attribute(extra_group, 'Grouptype', 'GEN2DM Model Params Extra')

        mesh_group = file.create_group(f'/{mesh_name}')
        _write_attribute(mesh_group, 'Grouptype', 'Generic')

        _write_value(file, f'/{mesh_name}/Control File', mesh_path.name)

        properties_group = file.create_group(f'/{mesh_name}/PROPERTIES')
        _write_attribute(properties_group, 'Grouptype', 'PROPERTIES')
        _write_value(file, f'/{mesh_name}/PROPERTIES/GUID', mesh_uuid)
        _write_value(file, f'/{mesh_name}/PROPERTIES/Visible', True)

        coords = file.create_group(f'/{mesh_name}/Coordinates')
        _write_attribute(coords, 'Grouptype', 'Coordinates')
        _write_attribute(coords, 'Version', 2)

        if not projection:
            return  # No projection, so don't write anything.

        if projection.well_known_text:
            _write_attribute(coords, 'WKT', projection.well_known_text)
        else:
            _write_attribute(coords, 'Local', True)

        # Units and datums are stored as indexes into the module-level lists.
        if projection.horizontal_units:
            _write_attribute(coords, 'HorizontalUnits', units.index(projection.horizontal_units))
        if projection.vertical_datum:
            _write_attribute(coords, 'VerticalDatum', datums.index(projection.vertical_datum))
        if projection.vertical_units:
            _write_attribute(coords, 'VerticalUnits', units.index(projection.vertical_units))


def _write_value(group: h5py.Group, path: str, value):
    """
    Create a one-element dataset at `path` holding `value`.

    Strings are encoded as latin-1 bytes and h5py picks the dataset type. Floats are stored as float32 and
    integers as int32 (this includes `bool`, which is a subclass of `int`).

    Args:
        group: The file or group the dataset is created in.
        path: Path of the dataset to create, as passed to `create_dataset`.
        value: The `str`, `float` or `int` to store.

    Raises:
        AssertionError: If `value` is of any other type.
    """
    if isinstance(value, str):
        group.create_dataset(path, dtype=None, data=[value.encode('latin-1')])
        return

    # float is checked before int, matching the order the numeric storage types are chosen in.
    for python_type, storage_name in ((float, 'float32'), (int, 'int32')):
        if isinstance(value, python_type):
            group.create_dataset(path, dtype=np.dtype(storage_name), data=[value])
            return

    raise AssertionError('Unknown type')  # pragma: no cover


def _write_attribute(group: h5py.Group, name: str, value):
    """
    Create a one-element attribute called `name` on `group`.

    Strings are encoded as latin-1 bytes and h5py picks the attribute type. Booleans are converted to 0/1 and,
    like other integers, stored as int32.

    Args:
        group: The group the attribute is attached to.
        name: Name of the attribute to create.
        value: The `str`, `bool` or `int` to store.

    Raises:
        AssertionError: If `value` is of any other type.
    """
    int32 = np.dtype('int32')

    if isinstance(value, str):
        stored, stored_dtype = value.encode('latin-1'), None
    elif isinstance(value, bool):
        stored, stored_dtype = int(value), int32
    elif isinstance(value, int):
        stored, stored_dtype = value, int32
    else:  # pragma: no cover
        raise AssertionError('Unknown type')

    attributes = group.attrs
    attributes.create(name, [stored], dtype=stored_dtype)
