"""Utilities for file I/O."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
from typing import Iterable

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.tree import tree_util, TreeNode
from xms.constraint import read_grid_from_file, UGrid3d

# 4. Local modules
from xms.hgs.components import dmi_util


def find_and_read_ugrids(sim_node, query) -> tuple['UGrid3d | None', 'UGrid3d | None']:
    """Reads the grids linked under the simulation and sorts them into 3D and 2D.

    Every 'TI_UGRID_PTR' descendant of the simulation is read from file. A grid whose check_all_cells_2d() is true
    is taken as the 2D grid; any other grid is taken as the 3D grid. If more than one grid of a kind is found, the
    last one read is the one returned.

    Args:
        sim_node (TreeNode): The simulation tree node.
        query (Query): Object for communicating with GMS

    Returns:
        (tuple[UGrid3d | None, UGrid3d | None]): The 3D grid and the 2D grid, in that order. Either is None if no
        grid of that kind was found.
    """
    pointers = tree_util.descendants_of_type(tree_root=sim_node, xms_types=['TI_UGRID_PTR'], allow_pointers=True)
    grid_3d = grid_2d = None
    # Nothing linked -> the loop is skipped and (None, None) is returned
    for pointer in pointers or ():
        grid_path = dmi_util.get_ugrid_filename(pointer.uuid, query)
        grid = read_grid_from_file(grid_path)
        if not grid.check_all_cells_2d():
            grid_3d = grid
        else:
            grid_2d = grid
    return grid_3d, grid_2d


def get_coverage_pointers(sim_node: TreeNode, coverage_type: str) -> list[TreeNode]:
    """Returns the HydroGeoSphere coverage pointers under the sim.

    Args:
        sim_node (TreeNode): The simulation tree node.
        coverage_type (str): Coverage type to look for. An empty value is passed on as None.

    Returns:
        (list[TreeNode]): The matching 'TI_COVER_PTR' descendants of sim_node.
    """
    type_filter = coverage_type or None
    search_args = {
        'tree_root': sim_node,
        'xms_types': ['TI_COVER_PTR'],
        'allow_pointers': True,
        'only_first': False,
        'coverage_type': type_filter,
        'model_name': 'HydroGeoSphere',
    }
    return tree_util.descendants_of_type(**search_args)


def get_coverage_component(coverage_ptr_node, component_class, unique_name, query):
    """Returns the coverage component object.

    Args:
        coverage_ptr_node: Tree node whose uuid identifies the coverage.
        component_class: Callable that builds the component from the component's main file.
        unique_name: Unique name of the component to look up.
        query: Object with an item_with_uuid() method, used to find the component.

    Returns:
        The object built by component_class, or None if the query finds no component.
    """
    found = query.item_with_uuid(coverage_ptr_node.uuid, unique_name=unique_name, model_name='HydroGeoSphere')
    return component_class(found.main_file) if found else None


def make_string_unique(name: str, name_set: set[str]) -> str:
    """Returns name, or name with a '-N' suffix, such that it is not already in name_set.

    The returned name is also added to name_set.

    Args:
        name (str): The name to make unique.
        name_set (set[str]): Names already in use. Modified in place.

    Returns:
        (str): A unique name.
    """
    unique = name
    suffix = 0
    # Try name-1, name-2, ... until one is free
    while unique in name_set:
        suffix += 1
        unique = f'{name}-{suffix}'
    name_set.add(unique)
    return unique


def skip_comments(file, line: str | None) -> str:
    """Skips comment lines and returns the first line after the comment lines.

    Args:
        file (TextIOWrapper): An open file.
        line (str): Starting line, or None.

    Returns:
        (str): First line after the comments.
    """
    while line.lower().startswith('#'):
        line = next(file)
    return line


def skip_to(file, starting_strs: Iterable[str], line: str | None) -> tuple[str, str]:
    """Skips lines that don't start with starting_str  - case-insensitive - and returns the first line that does.

    Args:
        file (TextIOWrapper): An open file.
        starting_strs (Iterable[str]): List-like set of strings at the start of the line to look for.
        line (str): Starting line, or None.

    Returns:
        (tuple[str, str]): First line after, and the string in starting_strs that was found.
    """
    if isinstance(starting_strs, str):
        raise RuntimeError('skip_to starting_strs should be a list-like object, not a str.')  # pragma no cover

    if line is None:
        line = next(file)
    title_str = ''
    done = False
    while line and not done:
        for starting_str in starting_strs:
            if line.lower().startswith(starting_str.lower()):
                title_str = starting_str
                done = True
                break
        if not done:
            line = next(file)
    return line, title_str


def read_units(file) -> tuple[dict[int, str], str]:
    """Reads the units from the run of VARAUXDATA lines found before any ZONE line.

    Lines are skipped until one starts with VARAUXDATA or ZONE (case-insensitive). If ZONE comes first there are no
    units. Otherwise consecutive VARAUXDATA lines are parsed until a line that is not VARAUXDATA is reached.

    Args:
        file (TextIOWrapper): An open file.

    Returns:
        (tuple[dict[int, str], str]): Dict of VARAUXDATA number -> units string (entries whose units are '(-)' are
        left out), and the last line read.

    Raises:
        StopIteration: If the file ends before a line following the VARAUXDATA lines is read.
        ValueError: If the text after the first space of a VARAUXDATA line is not an integer.
    """
    line, _ = skip_to(file, ['VARAUXDATA', 'ZONE'], line=None)
    if line and line.lower().startswith('zone'):
        return {}, line

    units = {}
    # skip_to matches case-insensitively, so the keyword must be matched the same way here or a lowercase
    # 'varauxdata' line would be found but never parsed.
    while line.lower().startswith('varauxdata'):
        # I tried using a regex but then I had two problems
        # Read the number: the text between the first and second space
        pos1 = line.find(' ')
        pos2 = line.find(' ', pos1 + 1)
        number = int(line[pos1:pos2])

        # Read the units string: the text between the first pair of double quotes
        pos1 = line.find('"')
        pos2 = line.find('"', pos1 + 1)
        unit_str = line[pos1 + 1:pos2]
        if unit_str != '(-)':
            units[number] = unit_str
        line = next(file)
    return units, line
