"""GrbReader class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from pathlib import Path
import struct

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules

# 4. Local modules
from xms.mf6.misc import log_util


def cell_count_from_grb_data(grb_data) -> int:
    """Returns the cell count stored in data read from a binary grid (.grb) file.

    Args:
        grb_data (dict): Data read from the .grb file, keyed by data name.

    Returns:
        The 'NCELLS' value (DIS, DISV) if there is one, otherwise the 'NODES' value (DISU), which is None if
        that is missing too.
    """
    cell_count = grb_data.get('NCELLS')
    return grb_data.get('NODES') if cell_count is None else cell_count


class GrbReader:
    """Reads the cbc file and calculates the flow budget of the selected cells."""

    HEADERS_COUNT = 4  # Number of header lines in the grb file
    NXT_LINE = 2  # Headers line containing thenumber of definitions
    LENTXT_LINE = 3  # Headers line containing thenumber of definitions
    HEADER_VALUE_INDEX = 1  # Index of the word in the header line that has the value
    NAME_WORD = 0  # Index of word in definition line defining the data name
    DATA_TYPE_WORD = 1  # Index of word in definition line defining the data type

    def __init__(self, grb_file: Path, requested_data: set | list) -> None:
        """Initializes the class.

        Args:
            grb_file (Path): binary grid file path.
            requested_data (list[str]): List of data to return (e.g. ['NJA', 'IA', 'JA'])
        """
        self._grb_file = grb_file
        self._requested_data_set = set(requested_data)
        self._fp = None  # File pointer
        self._headers = []  # Header lines
        self._definitions = []  # Definition lines

    def read(self):
        """Returns data from the binary grid (.grb) file.

        Returns:
            (dict): Dict of the data.
        """
        data = None
        try:
            with self._grb_file.open('rb') as self._fp:
                self._read_headers()
                self._read_definitions()
                data = self._read_data()
        except RuntimeError as e:
            log = log_util.get_logger()
            log.error(f'Error reading .grb file: {str(e)}')
        return data

    def _ntxt(self) -> int:
        """Returns the value of NTXT in the headers."""
        words = self._headers[GrbReader.NXT_LINE].split()
        return int(words[GrbReader.HEADER_VALUE_INDEX])

    def _lentxt(self) -> int:
        """Returns the value of LENTXT in the headers."""
        words = self._headers[GrbReader.LENTXT_LINE].split()
        return int(words[GrbReader.HEADER_VALUE_INDEX])

    def _read_headers(self) -> None:
        fmt = '<50s'
        size = struct.calcsize(fmt)
        for _i in range(GrbReader.HEADERS_COUNT):
            data = self._fp.read(size)
            line = struct.unpack(fmt, data)[0].decode('utf-8').strip()
            self._headers.append(line)

    def _read_definitions(self) -> None:
        """Reads the definition lines and returns them as a dict."""
        ntxt = self._ntxt()
        lentxt = self._lentxt()
        fmt = f'<{lentxt}s'
        size = struct.calcsize(fmt)
        i = 0
        while i < ntxt:
            data = self._fp.read(size)
            line = struct.unpack(fmt, data)[0].decode('utf-8').strip()
            if line.startswith('#'):
                continue
            self._definitions.append(line)
            i += 1

    def _read_data(self):
        """Returns some arrays from the binary grid (.grb) file."""
        data_dict = {}
        found_data_set = set()
        for definition in self._definitions:
            data_name, dtype, size = self._name_dtype_size(definition)
            record = self._read_record(dtype, size)
            if data_name in self._requested_data_set:
                data_dict[data_name] = record
                found_data_set.add(data_name)
                if found_data_set == self._requested_data_set:
                    break
        return data_dict

    def _name_dtype_size(self, definition) -> tuple[str, str, int]:
        """Returns the data name, data type, and number of records to read given the definition.

        Args:
            definition (str): A definition line.

        Returns:
            tuple(str, str, int): See description.
        """
        words = definition.split()
        data_name = words[GrbReader.NAME_WORD]
        dtype = self._record_dtype(words[GrbReader.DATA_TYPE_WORD])
        size = self._record_size(words)
        return data_name, dtype, size

    def _record_size(self, words) -> int:
        """Returns the size of the record.

        Args:
            words (list[str]): The definition line broken into individual words.
        """
        ndim_idx = words.index('NDIM')
        if ndim_idx < 0:
            raise RuntimeError('Error reading definition in .grb file.')
        ndim = int(words[ndim_idx + 1])
        if ndim == 0:
            return 1

        dims = [int(word) for word in words[ndim_idx + 2:ndim_idx + 2 + ndim]]
        size = 1
        for dim in dims:
            size *= dim
        return size

    def _record_dtype(self, type_word):
        """Returns the size of the record.

        Args:
            type_word (str): The type word from the definition line.
        """
        match type_word:
            case 'INTEGER':
                dtype = 'int'
            case 'DOUBLE':
                dtype = 'float'
            case _:
                raise RuntimeError(f'Data type not recognized {type_word}')
        return dtype

    def _read_record(self, dtype, size):
        """Reads and returns the record data."""
        data = np.fromfile(self._fp, dtype=dtype, count=size)
        if size == 1:
            data = data[0]
        return data
