"""CbcReader and CbcData classes."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import collections
from dataclasses import dataclass
from pathlib import Path
import struct

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules


def _rec_10_fmt(ndat: int, include_less_than: bool) -> str:
    """Returns the format string for record 10."""
    if include_less_than:
        return f'<ii{"d" * ndat}'
    else:
        return f'ii{"d" * ndat}'


@dataclass
class CbcHeader:
    """The header of one chunk of data, made up of records 1 and 2.

    Record 1: KSTP,KPER,TEXT,NDIM1,NDIM2,-NDIM3
    Record 2: IMETH,DELT,PERTIM,TOTIM
    """
    kstp: int = 0  # time step number
    kper: int = 0  # stress period number
    text: str = ''  # flow type (character*16 in the file)
    ndim1: int = 0  # size of first dimension
    ndim2: int = 0  # size of second dimension
    ndim3: int = 0  # size of third dimension (this one is always negative)
    imeth: int = 0  # code that specifies the form of the remaining data
    delt: float = 0.0  # length of the timestep
    pertim: float = 0.0  # time value for the current stress period
    totim: float = 0.0  # total simulation time

    @staticmethod
    def from_tuple(t):
        """Builds a CbcHeader from the unpacked values of records 1 and 2.

        Args:
            t (tuple): The ten values of records 1 and 2, in file order. The third item (TEXT) is bytes; it is
                decoded as utf-8 and stripped of surrounding whitespace.

        Returns:
            (CbcHeader): The header.
        """
        flow_type = t[2].decode('utf-8').strip()
        return CbcHeader(
            kstp=t[0],
            kper=t[1],
            text=flow_type,
            ndim1=t[3],
            ndim2=t[4],
            ndim3=t[5],
            imeth=t[6],
            delt=t[7],
            pertim=t[8],
            totim=t[9],
        )


class CbcData:
    """A record of a chunk of data read from the CBC file.

    Only the location of the values (file position and byte count) is stored at construction. The values are read
    from the file the first time they are requested and are cached after that.
    """
    def __init__(
        self,
        fp,
        header: CbcHeader,
        txt2id2: str = '',
        ndat: int = 0,
        nlist: int = 0,
        offset: int = 0,
        pos: int = 0,
        nbytes: int = 0
    ) -> None:
        """Initializes the class.

        Args:
            fp: File pointer
            header (CbcHeader): The header, consisting of records 1 and 2.
            txt2id2 (str): If IMETH=6, "TXT2ID2, the package or model name for information in ID2"
            ndat (int): Number of data values in each row. For IMETH=6, "number of columns in DATA2D, which is the
                number of auxiliary values plus 1"
            nlist (int): For IMETH=6, "the size of the list"
            offset (int): Number of id values at the start of each row. If IMETH=1, 0. If IMETH=6, 2.
            pos (int): File position.
            nbytes (int): Number of bytes we'll need to read when actually reading the data.
        """
        self._fp = fp
        self.header = header
        self.txt2id2 = txt2id2
        self.ndat = ndat
        self.nlist = nlist
        self.nbytes = nbytes
        self.offset = offset
        self.pos = pos
        self._data = None  # Flat sequence of unpacked values, or None until the data has been read

    def records_count(self):
        """Returns the number of records in the data.

        Returns:
            (int): See description
        """
        self._ensure_data()
        return len(self._data) // (self.ndat + self.offset)

    def dataset_values(self, cell_count: int):
        """Returns the values arranged by cell.

        For list data (offset is nonzero) the result has cell_count items. Cells that do not appear in the list are
        set to CbcReader.NO_DATA, and the values of cells that appear more than once are added together. For array
        data (offset is 0) the values are returned as they were read and cell_count is not used.

        Args:
            cell_count (int): Grid cell count.

        Returns:
            (list[float|int] | tuple[float|int]): The values.
        """
        self._ensure_data()
        if not self.offset:
            return self._data

        # Reading a smaller list of values but returning it as a list, size of cell_count
        no_data = CbcReader.NO_DATA
        stride = self.ndat + self.offset
        values = [no_data] * cell_count
        for i in range(0, len(self._data), stride):
            idx = self._data[i] - 1  # The first item of each row is a 1-based cell id
            value = self._data[i + self.offset]
            if values[idx] == no_data:
                values[idx] = value
            else:
                values[idx] += value
        return values

    def values(self) -> list[float | int]:
        """Returns the first data value of every row, in file order.

        For array data (offset is 0) this is all of the values as they were read.

        Returns:
            The values.
        """
        self._ensure_data()
        if not self.offset:
            return self._data
        return list(self._data[self.offset::self.ndat + self.offset])

    def cell_ids(self) -> list[int] | None:
        """Returns the cell ids if we can.

        Returns:
            The first id of every row, or None if the rows have no ids (offset is 0).
        """
        self._ensure_data()
        if not self.offset:
            return None
        return list(self._data[::self.ndat + self.offset])

    def _ensure_data(self) -> None:
        """Reads the data from disk if that has not been done yet."""
        if self._data is None:
            self._read_data()

    def _read_data(self):
        """Reads the data from disk. Nothing is read if IMETH is neither 1 nor 6."""
        if self.header.imeth == 1:
            self._read_imeth1()
        elif self.header.imeth == 6:
            self._read_imeth6()

    def _read_chunk(self) -> bytes:
        """Reads nbytes bytes at pos, leaving the file position and open/closed state as they were found."""
        was_closed = self._open_file_if_necessary()
        try:
            prev_pos = self._fp.tell()  # Save previous position
            try:
                self._fp.seek(self.pos)
                return self._fp.read(self.nbytes)
            finally:
                self._fp.seek(prev_pos)  # Restore previous position, even if the read failed
        finally:
            self._close_file_if_necessary(was_closed)

    def _read_imeth1(self) -> None:
        """Reads and stores the data stored using IMETH=1, a 1D array."""
        nvals = self.header.ndim1 * self.header.ndim2 * -self.header.ndim3
        self._data = struct.unpack(f'<{nvals}d', self._read_chunk())

    def _read_imeth6(self) -> None:
        """Reads and stores the data stored using IMETH=6, a list of data."""
        if not self.nlist:
            self._data = []
            return
        # A single leading '<' selects little-endian, unpadded layout for the whole format. That matches how nbytes
        # is computed and how IMETH=1 data is unpacked; without it struct uses the host's native layout.
        fmt = '<' + _rec_10_fmt(self.ndat, False) * self.nlist
        self._data = struct.unpack(fmt, self._read_chunk())

    def _open_file_if_necessary(self) -> bool:
        """Returns True if the file was closed and we opened it."""
        if not self._fp.closed:
            return False
        reopen = getattr(self._fp, 'open', None)
        if reopen is not None:
            reopen(mode='rb')
        else:
            # Built-in file objects cannot be reopened, so open a new handle on the same path
            self._fp = open(self._fp.name, mode='rb')
        return True

    def _close_file_if_necessary(self, was_closed: bool) -> None:
        """Closes the file if it was closed before."""
        if was_closed:
            self._fp.close()


class CbcTimeStep:
    """All the CbcData for a particular time."""
    def __init__(self, kstp: int, kper: int, pertim: float, totim: float):
        """Initializes the class.

        Args:
            kstp (int): Time step number.
            kper (int): Stress period number.
            pertim (float): Time value for the current stress period.
            totim (float): Total simulation time.
        """
        self.kstp = kstp
        self.kper = kper
        self.pertim = pertim
        self.totim = totim
        self.cbc_data: list[CbcData] = []  # In the order they were added

    def add_data(self, cbc_data: CbcData):
        """Adds data to the time step.

        Args:
            cbc_data (CbcData): The data to add.
        """
        self.cbc_data.append(cbc_data)

    def get_data(self, flow_type: str) -> CbcData | None:
        """Returns the data for the given flow type.

        Args:
            flow_type (str): The flow type, compared against the header text of each data.

        Returns:
            (CbcData | None): The first data added with that flow type, or None if there isn't one.
        """
        return next((data for data in self.cbc_data if data.header.text == flow_type), None)


class CbcReader:
    """Designed as a context manager, reads the cbc file.

    Iterating the reader yields one CbcTimeStep for each (kstp, kper) pair found in the file. The values of each
    record are skipped over while iterating; only their file position and size are stored in the CbcData.

    See https://stackoverflow.com/questions/56918142/use-with-in-iter second answer
    """

    NO_DATA = 1.0E30  # mf6io.pdf pg 167. Dry values are -1.0E30

    def __init__(self, cbc_filepath: Path | str | None = None, fp=None) -> None:
        """Initializes the class.

        If using in a context manager, pass cbc_filepath and not fp. Otherwise pass fp and not cbc_filepath.

        Args:
            cbc_filepath (Path|str): The file path.
            fp: The opened file pointer. You are responsible for closing it.

        Raises:
            RuntimeError: If both cbc_filepath and fp are given.
        """
        if cbc_filepath is not None and fp is not None:
            raise RuntimeError('Pass either the cbc_filepath (the file path), or fp (the opened file), but not both.')

        self._cbc_path = Path(cbc_filepath) if cbc_filepath and isinstance(cbc_filepath, str) else cbc_filepath
        self._fp = fp
        self.size_double = 8
        self.header_fmt = '<ii16siiiiddd'  # Records 1 and 2
        self.header_size = struct.calcsize(self.header_fmt)
        self.Rec3To7Tuple = collections.namedtuple('Rec3To7Tuple', 'txt1id1 txt2id1 txt1id2 txt2id2 ndat')
        self.rec_3_to_7_fmt = '<16s16s16s16si'
        self.rec_3_to_7_size = struct.calcsize(self.rec_3_to_7_fmt)
        self._next_header = None  # The header for the next time step

    def __enter__(self):
        """Called when entering the context manager. Opens the file if no file pointer was given."""
        if self._fp is None:
            self._fp = open(self._cbc_path, mode='rb')
        return self

    def __iter__(self):
        """Iterates."""
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Called when leaving the context manager.

        Args:
            exc_type: exception type
            exc_value: exception value
            exc_tb: traceback
        """
        # Only close it if they gave us the file path and we opened it
        if self._cbc_path and self._fp is not None:
            self._fp.close()

    def __next__(self) -> CbcTimeStep:
        """Reads the next chunk of data or raises StopIteration if no more data."""
        cbc_time_step = self.read_next_time_step()
        if not cbc_time_step:
            raise StopIteration  # We've hit the end of the file
        return cbc_time_step

    def _skim_imeth1(self, header: CbcHeader) -> CbcData:
        """Skims over data stored using IMETH=1, a 1D array, and saves the file position so it can be read later.

        Args:
            header (CbcHeader): The header.

        Returns:
            (CbcData): The data.
        """
        nvals = header.ndim1 * header.ndim2 * -header.ndim3
        nbytes = self.size_double * nvals
        cbc_data = CbcData(fp=self._fp, header=header, txt2id2='', ndat=1, offset=0, pos=self._fp.tell(), nbytes=nbytes)
        self._fp.seek(nbytes, 1)  # Skip over the actual data. '1' means 'from current file position'
        return cbc_data

    def _skim_imeth6(self, header) -> CbcData:
        """Skims over data stored using IMETH=6, a list of data, and saves the file position so it can be read later.

        Args:
            header (CbcHeader): The header.

        Returns:
            (CbcData): The data.
        """
        # Read records 3 through 7
        chunk = self._fp.read(self.rec_3_to_7_size)
        rec_3_to_7 = self.Rec3To7Tuple._make(struct.unpack(self.rec_3_to_7_fmt, chunk))
        txt2id2 = rec_3_to_7.txt2id2.decode('utf-8').strip()

        # Read records 8 through 9: (ndat - 1) 16 character names, then nlist
        rec_8_to_9_fmt = f'<{"16s" * (rec_3_to_7.ndat - 1)}i'
        size = struct.calcsize(rec_8_to_9_fmt)
        chunk = self._fp.read(size)
        rec_8_to_9 = struct.unpack(rec_8_to_9_fmt, chunk)
        nlist = rec_8_to_9[-1]
        # We're not doing anything with the names (aux) for now

        # Compute nbytes
        fmt = _rec_10_fmt(rec_3_to_7.ndat, True)
        size_row = struct.calcsize(fmt)
        nbytes = size_row * nlist

        # Create the CbcData
        pos = self._fp.tell()
        cbc_data = CbcData(
            fp=self._fp,
            header=header,
            txt2id2=txt2id2,
            ndat=rec_3_to_7.ndat,
            offset=2,
            pos=pos,
            nlist=nlist,
            nbytes=nbytes
        )
        self._fp.seek(nbytes, 1)  # Skip over the actual data. '1' means 'from current file position'
        return cbc_data

    def _read_header(self) -> CbcHeader | None:
        """Reads Record 1 and 2 and returns them in a CbcHeader.

        Returns:
            (CbcHeader | None): The header, or None if there was nothing left to read.
        """
        chunk = self._fp.read(self.header_size)
        if not chunk:
            return None
        return CbcHeader.from_tuple(struct.unpack(self.header_fmt, chunk))

    def _skim_data(self, header: CbcHeader) -> CbcData:
        """Reads over data and saves the file position so it can be read later.

        Args:
            header (CbcHeader): The header, consisting of records 1 and 2.

        Returns:
            (CbcData): The data.

        Raises:
            ValueError: If header.imeth is neither 1 nor 6.
        """
        if header.imeth == 1:
            return self._skim_imeth1(header)
        if header.imeth == 6:
            return self._skim_imeth6(header)
        # We don't know how big the data is, so we can't skip it and everything read after it would be garbage
        raise ValueError(f'Unsupported IMETH value {header.imeth} in record "{header.text}".')

    def read_next_time_step(self) -> CbcTimeStep | None:
        """Reads all the data for the next time.

        Returns:
            (CbcTimeStep | None): The data for the time step, or None if at the end of the file.
        """
        # Get the header (we may have already read it)
        header = self._next_header if self._next_header is not None else self._read_header()
        self._next_header = None  # Consumed. Set again below if we read the header of the following time step
        if header is None:
            return None

        cbc_timestep = CbcTimeStep(header.kstp, header.kper, header.pertim, header.totim)

        # Skim and store data until we get to the next time step, or the end of the file
        while True:
            cbc_timestep.add_data(self._skim_data(header))
            header = self._read_header()
            if header is None:
                break
            if cbc_timestep.kstp != header.kstp or cbc_timestep.kper != header.kper:
                self._next_header = header
                break
        return cbc_timestep
