"""IO for SCHISM horizontal grid files."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from logging import Logger
from typing import Optional, TextIO, Tuple

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint import UGrid2d
from xms.grid.ugrid import UGrid

# 4. Local modules


def write_horizontal_grid_file(file: TextIO, grid: UGrid, grid_name: str, logger: Logger) -> None:
    """Write a grid out as a horizontal grid file.

    The output is a description line, a counts line (cells, then points), one line per point, and
    one line per cell. Point and cell numbers are written 1-based, and the third coordinate of each
    point is written with its sign flipped.

    Args:
        file: The output file.
        grid: The input grid.
        grid_name: The name of the grid, written as the description line.
        logger: The logger used for progress messages.
    """
    logger.info('Writing horizontal grid file.')
    file.write(f'{grid_name}  ! description\n')
    file.write(f'{grid.cell_count} {grid.point_count}  ! number of elements and nodes\n')

    # Point lines: number, first coordinate, second coordinate, negated third coordinate.
    logger.info(f'Writing {grid.point_count} points.')
    file.writelines(
        f'{number} {point[0]} {point[1]} {-point[2]}\n'
        for number, point in enumerate(grid.locations, start=1)
    )

    # Cell lines: number, how many points the cell has, then the 1-based point numbers.
    logger.info(f'Writing {grid.cell_count} cells.')
    for number in range(1, grid.cell_count + 1):
        point_indices = grid.get_cell_points(number - 1)
        point_numbers = ' '.join(str(point_index + 1) for point_index in point_indices)
        file.write(f'{number} {len(point_indices)} {point_numbers}\n')


def read_horizontal_grid_file(file: TextIO, logger: Logger) -> Optional[UGrid2d]:
    """
    Read a horizontal grid file and build a grid from it.

    The header, the points, and the cells are read in that order from the same file object.

    Args:
        file: The file to read.
        logger: The logger used for progress messages.

    Returns:
        The grid built from the points and cells in the file.
    """
    num_cells, num_points = read_horizontal_grid_header(file)

    logger.info(f'Reading {num_points} points.')
    locations = read_horizontal_grid_points(file, num_points)

    logger.info(f'Reading {num_cells} cells.')
    stream = read_horizontal_grid_cells(file, num_cells)

    # Wrap the raw points and cellstream in a UGrid, then in the constrained 2D grid.
    return UGrid2d(ugrid=UGrid(points=locations, cellstream=stream))


def remove_comment(line: str) -> str:
    """
    Strip a trailing '!' comment and surrounding whitespace from a line.

    Args:
        line: The line to remove the comment from.

    Returns:
        The text before the first '!', with leading and trailing whitespace removed.
    """
    # Everything from the first '!' onward is the comment.
    content, _, _ = line.partition('!')
    return content.strip()


def read_horizontal_grid_header(file: TextIO) -> Tuple[int, int]:
    """
    Read the header of a horizontal grid file.

    The header is two lines: a description line, which is discarded, followed by a counts line
    holding the number of cells (elements) and then the number of points (nodes).

    Args:
        file: The file to read. Reading starts at its current position.

    Returns:
        The number of cells and the number of points, in that order.

    Raises:
        RuntimeError: If the counts line holds fewer than two values.
        ValueError: If a value on the counts line is not an integer.
    """
    # The first line is only a description; its contents are not needed.
    file.readline()
    # The counts line lists the number of elements first, then the number of nodes.
    counts = [int(word) for word in remove_comment(file.readline()).split()]
    if len(counts) < 2:
        raise RuntimeError('Missing number of nodes and elements.')
    cell_count, point_count = counts[:2]
    return cell_count, point_count


def read_horizontal_grid_points(file: TextIO, num_nodes: int) -> list[Tuple[float, float, float]]:
    """
    Read the point lines of a horizontal grid file.

    Each line is expected to hold a node number, x, y, and depth. The node number is ignored and
    the depth is stored with its sign flipped.

    Args:
        file: The file to read.
        num_nodes: The number of nodes.

    Returns:
        List of the points as (x, y, negated depth) tuples, in file order.

    Raises:
        RuntimeError: If a line holds fewer than four values.
    """
    nodes = []
    for _ in range(num_nodes):
        fields = remove_comment(file.readline()).split()
        values = list(map(float, fields))
        if len(values) < 4:
            raise RuntimeError('Missing node coordinates.')
        # Anything after the fourth value is ignored.
        _node_number, x, y, depth = values[:4]
        nodes.append((x, y, -depth))
    return nodes


def read_horizontal_grid_cells(file: TextIO, cell_count: int) -> list[int]:
    """
    Read the cell lines of a horizontal grid file into a cellstream.

    Each line is expected to hold a cell number, the number of nodes in the cell, and then that
    many 1-based node numbers. Only cells with 3 nodes (triangles) and 4 nodes (quads) are accepted.

    Args:
        file: The file to read.
        cell_count: The number of cells.

    Returns:
        The cellstream: for each cell, its cell type, its number of nodes, and its 0-based node indices.

    Raises:
        RuntimeError: If a line has too few values for the number of nodes it declares, or if the
            declared number of nodes is neither 3 nor 4.
    """
    # Map the node count given in the file to the matching cell type value.
    cell_types = {
        3: UGrid.cell_type_enum.TRIANGLE.value,
        4: UGrid.cell_type_enum.QUAD.value,
    }
    cellstream = []
    for _ in range(cell_count):
        cell_items = [int(word) for word in remove_comment(file.readline()).split()]
        # line has cell number, number of nodes, node1, node2, ...
        # The smallest valid line is a triangle: 2 leading values plus 3 nodes.
        if len(cell_items) < 5:
            raise RuntimeError('Missing cell nodes.')
        node_count = cell_items[1]
        cell_type = cell_types.get(node_count)
        if cell_type is None:
            raise RuntimeError('Invalid cell type.')
        # Take exactly the declared number of nodes so the stream's node count always matches the
        # indices that follow it; a quad line with only three nodes must not slip through.
        node_numbers = cell_items[2:2 + node_count]
        if len(node_numbers) < node_count:
            raise RuntimeError('Missing cell nodes.')
        # stream has cell type, number of nodes, node1, node2, ...
        cellstream.append(cell_type)
        cellstream.append(node_count)
        # File node numbers are 1-based; the stream uses 0-based indices.
        cellstream.extend(node_number - 1 for node_number in node_numbers)
    return cellstream
