"""Module for reading VGrids."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"

# 1. Standard Python modules
from pathlib import Path
from typing import Sequence

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules
from xms.schism.file_io.reader_base import ReaderBase
from xms.schism.file_io.vgrid.vgrid_file import VGridFile


def read_vgrid(path: Path | str) -> VGridFile:
    """
    Load a VGrid file from disk into a `VGridFile`.

    Args:
        path: Location of the VGrid file.

    Returns:
        A `VGridFile` built from the parsed hc, theta_b, theta_f, Z levels, and S levels.
    """
    with VGridReader(path) as vgrid_reader:
        vgrid_reader.read()
        # Build the result while the reader is still open so cleanup ordering matches the reader's context manager.
        return VGridFile(
            hc=vgrid_reader.hc,
            theta_b=vgrid_reader.theta_b,
            theta_f=vgrid_reader.theta_f,
            z_levels=vgrid_reader.z_levels,
            s_levels=vgrid_reader.s_levels,
        )


class VGridReader(ReaderBase):
    """
    Class for reading a VGrid file.

    Only SZ grids (mode 2) are supported. Call `read()` to populate `hc`, `theta_b`, `theta_f`, `z_levels`, and
    `s_levels`.
    """
    def __init__(self, path: Path | str):
        """
        Initialize the reader.

        Args:
            path: Where to read from. `read_vgrid()` passes either a `Path` or a `str` through unchanged.
        """
        super().__init__(path, comment_markers=['!'])

        # Placeholder values; all of these are overwritten by `read()`.
        self.hc: float = 0.0
        self.theta_b: float = 0.0
        self.theta_f: float = 0.0
        self.z_levels: list[float] = []
        self.s_levels: list[float] = []

    def read(self):
        """
        Read the file.

        The layout consumed, in order, is:
            1. The grid mode (must be 2, i.e. SZ).
            2. Total level count, Z level count, and the last Z value.
            3. A header line (expected to say "Z levels"), followed by one `index value` line per Z level.
            4. A header line (expected to say "S levels").
            5. A line with hc, theta_b, and theta_f.
            6. One `index value` line per S level.

        Raises:
            Whatever `self._error()` produces, if the grid mode is not 2.
        """
        mode = self._parse_line(int)
        if mode != 2:  # We only support Sigma-Z right now.
            raise self._error('Only SZ grids are supported.')

        z_count, s_count = self._parse_counts()

        self._read_line()  # This line always just says "Z levels".
        self.z_levels = self._parse_levels(z_count)

        self._read_line()  # This line always just says "S levels".
        self.hc, self.theta_b, self.theta_f = self._parse_line(float, float, float)

        self.s_levels = self._parse_levels(s_count)

    def _parse_counts(self) -> tuple[int, int]:
        """
        Parse the S and Z counts.

        Returns:
            Tuple of (z_count, s_count).
        """
        all_count, z_count, _last_z_value = self._parse_line(int, int, float)

        # The last element of the Z grid overlaps with the first element of the Sigma grid. `all_count` only counts the
        # overlapping element once, which leaves `s_count` off by one.
        s_count = all_count - z_count + 1
        return z_count, s_count

    def _parse_levels(self, count: int) -> Sequence[float]:
        """
        Parse S or Z levels.

        Each level line holds an integer index (ignored) followed by a value. The value is negated before being stored.

        Args:
            count: Number of levels to parse. Zero or negative reads nothing.

        Returns:
            The parsed levels, negated (elevation -> depth), in file order.
        """
        levels: list[float] = []
        for _ in range(count):
            _index, elevation = self._parse_line(int, float)
            levels.append(-elevation)  # invert level, elevation -> depth
        return levels
