"""FlowBudgetCalculator class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math
from pathlib import Path

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules
from xms.mf6.file_io import grb_reader
from xms.mf6.file_io.cbc_reader import CbcData, CbcReader, CbcTimeStep
from xms.mf6.file_io.grb_reader import GrbReader


class FlowBudgetCalculator:
    """Reads the cbc file and calculates the flow budget of the selected cells, including totals and summary."""
    def __init__(
        self,
        cbc_filepath: Path | str,
        selected_cells: list[int],
        time: float | None = None,
        cbc_time_step: CbcTimeStep | None = None,
        grb_filepath: Path | str | None = None,
        ia=None,
        ja=None,
        cell_count: int | None = None
    ) -> None:
        """Initializes the class.

        Pass grb_file, or ia, ja, and cell_count, but not both. Pass time, or cbc_time_step, but not both.

        We use the IA array (as found in the .grb file) and not the IAC array (from DISU or DisuCalculator) because
        it allows us to index into the JA array directly.

        Args:
            cbc_filepath: cbc file path.
            selected_cells (list[int]): The selected cells indices (0-based). Passing [] is the same as all the cells.
            time (float): Time of the active timestep from GMS.
            cbc_time_step (CbcTimeStep): Time step of the .cbc file.
            grb_filepath (Path | str): binary grid file path.
            ia: IA array (different from DISU IAC) has cumulative number of connections plus 1 for each cell
            ja: JA array which has adjacencies (size nja)
            cell_count (int): Number of grid cells.

        Raises:
            RuntimeError: If the arguments are inconsistent (see _raise_exception_if_invalid_arguments).
        """
        self._raise_exception_if_invalid_arguments(grb_filepath, ia, ja, cell_count, time, cbc_time_step)

        self._cbc_filepath = Path(cbc_filepath)
        self._selected_cells = set(selected_cells)
        self._time = time
        self._cbc_time_step = cbc_time_step
        # Normalise to Path so _read_grb_file() can call Path methods even when a str was passed
        self._grb_filepath = Path(grb_filepath) if grb_filepath is not None else None
        self._ia = ia
        self._ja = ja
        self._cell_count = cell_count

        self._closest_time = 0
        self._total_ss_in = 0.0
        self._total_ss_out = 0.0
        self._total_zone_in = 0.0
        self._total_zone_out = 0.0
        self._net_source_sink = None  # Net in + out for every cell for sources/sinks

    def _raise_exception_if_invalid_arguments(self, grb_file, ia, ja, cell_count, time, cbc_time_step):
        """Raises RuntimeError if arguments don't make sense.

        Args:
            grb_file: Binary grid file path, or None.
            ia: IA array, or None.
            ja: JA array, or None.
            cell_count: Number of cells, or None.
            time: Time of the active timestep, or None.
            cbc_time_step: Time step of the .cbc file, or None.

        Raises:
            RuntimeError: If only some of ia/ja/cell_count are given, if both grb_file and ia/ja/cell_count are
                given, or if both time and cbc_time_step are given.
        """
        arg_count = int(ia is not None) + int(ja is not None) + int(cell_count is not None)
        if arg_count != 0 and arg_count != 3:
            raise RuntimeError('ia, ja, and cell_count, must all be passed if any of them are.')

        if grb_file is not None and arg_count > 0:
            raise RuntimeError('Pass either grb_file or ia, ja, and cell_count, but not both.')

        if time is not None and cbc_time_step is not None:
            raise RuntimeError('Pass either time or cbc_time_step, but not both.')

    def calculate(self):
        """Reads the cbc file and calculates the flow budget of the selected cells.

        May be called more than once; the running totals are reset at the start of every call.

        Returns:
            (dict | None): The flow budget, or None if the grid data could not be read from the .grb file or the
            .cbc file contains no time steps.
        """
        # Get data from the binary grid (.grb) file if it wasn't supplied
        if self._ia is None and not self._data_from_grb_file():
            return None

        # Counted before defaulting to all cells, so this is 0 when the caller passed an empty selection
        flow_budget = {'meta': {'selected_cells_count': len(self._selected_cells)}}
        self._select_all_cells_if_none_selected()
        self._reset_totals()

        cbc_time_step, close_file, fp = self._get_cbc_time_step()
        try:
            if cbc_time_step is None:  # No time steps in the .cbc file
                return None

            flow_budget['meta']['time'] = cbc_time_step.totim
            for cbc_data in cbc_time_step.cbc_data:
                # Reset per record so a record we skip doesn't get stored with the previous record's data
                flow_data = None
                if cbc_data.header.text == 'FLOW-JA-FACE':
                    flow_data = self._do_intercell_flows(cbc_data)
                elif self._is_source_sink(cbc_data):
                    flow_data = self._do_sources_sinks(cbc_data)

                if flow_data:
                    flow_budget[cbc_data.header.text] = flow_data
        finally:
            # Only close after processing, since the records were read through this handle
            if close_file:
                fp.close()

        self._summarize(flow_budget)
        return flow_budget

    def _reset_totals(self) -> None:
        """Zeroes the running totals and the per-cell net source/sink array before a calculation."""
        self._total_ss_in = 0.0
        self._total_ss_out = 0.0
        self._total_zone_in = 0.0
        self._total_zone_out = 0.0
        self._net_source_sink = [0.0] * self._cell_count

    def _get_cbc_time_step(self):
        """Returns the time step to use, opening the .cbc file to find it if we weren't given one.

        Returns:
            (tuple): (time step or None, True if the caller must close the file, the open file or None).
        """
        if self._cbc_time_step:
            return self._cbc_time_step, False, None
        fp = open(self._cbc_filepath, 'rb')
        try:
            return self._find_closest_time_step(fp), True, fp
        except BaseException:
            # On success the caller closes fp; on failure we must not leak it
            fp.close()
            raise

    def _is_source_sink(self, cbc_data: CbcData) -> bool:
        """Returns True if the data is a cell flow budget term that we should include in the budget.

        "The “DATA” prefix on the text identifier can be used by post-processors to recognize that the record does not
        contain a cell flow budget term." (mf6io.pdf page 283, table 35)

        Args:
            cbc_data (CbcData): Data from the cbc file.

        Returns:
            (bool): See description.
        """
        return cbc_data.records_count() <= self._cell_count and not cbc_data.header.text.startswith('DATA')

    def _find_closest_time_step(self, fp) -> CbcTimeStep | None:
        """Finds the time step in the file closest to the given time.

        If no time was given, the first time step is returned.

        Args:
            fp: Open binary file object of the .cbc file.

        Returns:
            (CbcTimeStep | None): The closest time step, or None if the file has no time steps.
        """
        ts0 = None  # previous
        ts1 = None  # current
        # Don't use the context manager because we don't want to close the file until later
        reader = CbcReader(fp=fp)
        cbc_time_step = reader.read_next_time_step()
        while cbc_time_step:
            ts0 = ts1
            ts1 = cbc_time_step
            if self._time is None:  # Hack to return the first ts. Needed until we can get the time from GMS
                break
            if math.isclose(ts1.totim, self._time) or ts1.totim > self._time:
                break
            cbc_time_step = reader.read_next_time_step()

        # Return the time step closest to the active timestep time
        if ts0 is None:
            return ts1
        if abs(ts0.totim - self._time) < abs(ts1.totim - self._time):
            return ts0
        return ts1

    def _is_closest_time(self, cbc_data: CbcData, closest_time: float) -> bool:
        """Returns true if the data is for the active timestep.

        Args:
            cbc_data (CbcData): Data from the cbc file.
            closest_time (float): Time in the file closest to the active timestep time.

        Returns:
            (bool): See description.
        """
        return math.isclose(cbc_data.header.totim, closest_time)

    def _select_all_cells_if_none_selected(self) -> None:
        """Sets the selected cells to all cells if none are selected."""
        if not self._selected_cells:
            self._selected_cells = set(range(self._cell_count))

    def _summarize(self, flow_budget: dict) -> None:
        """Add summary data.

        Outflow totals are stored as negative numbers, so "in + out" below is the net (in - |out|) and
        "in - out" is the gross (in + |out|).

        Args:
            flow_budget (dict): Dict of flow budget data. Its 'meta' entry is updated in place.
        """
        diff_ss = self._total_ss_in + self._total_ss_out
        diff_zone = self._total_zone_in + self._total_zone_out
        total_in = self._total_ss_in + self._total_zone_in
        total_out = self._total_ss_out + self._total_zone_out
        total = total_in + total_out  # Net: in - |out|
        sum_ = total_in - total_out  # Gross: in + |out|
        if math.isclose(sum_, 0.0):
            percent_difference = float('inf')
        else:
            # Percent discrepancy: net divided by the average of in and |out|
            percent_difference = 100 * total / (sum_ / 2.0)

        totals = {
            'Sources/Sinks': {
                'total_in': self._total_ss_in,
                'total_out': self._total_ss_out
            },
            'Selected Zone': {
                'total_in': self._total_zone_in,
                'total_out': self._total_zone_out
            },
            'total_in': total_in,
            'total_out': total_out,
        }
        summary = {
            'In - Out, Sources/Sinks': diff_ss,
            'In - Out, Selected Zone': diff_zone,
            'Total': total,
            '% difference': percent_difference
        }
        flow_budget['meta']['totals'] = totals
        flow_budget['meta']['summary'] = summary
        flow_budget['meta']['net_source_sink'] = self._net_source_sink

    def _do_sources_sinks(self, cbc_data):
        """Sum flows into and out of sources and sinks for the selected cells.

        Also adds each selected cell's value to self._net_source_sink and to the source/sink running totals.

        Args:
            cbc_data (CbcData): The Cbc data

        Returns:
            dict of flows in/out ({'in': positive sum, 'out': negative sum}), or None if both are zero.
        """
        into = 0.0
        out = 0.0
        values = cbc_data.values()
        cell_ids = cbc_data.cell_ids()
        if not cell_ids:
            # No explicit cell ids: values are assumed to be in cell order (1-based ids)
            cell_ids = [i + 1 for i in range(len(values))]
        if values and cell_ids and len(values) == len(cell_ids):
            for value, cell_id in zip(values, cell_ids):
                cell_idx = cell_id - 1
                if cell_idx in self._selected_cells and not math.isclose(value, CbcReader.NO_DATA):
                    self._net_source_sink[cell_idx] += value
                    if value > 0:
                        into += value
                    elif value < 0:
                        out += value

        if into == 0.0 and out == 0.0:
            return None
        package_data = {'in': into, 'out': out}
        self._total_ss_in += into
        self._total_ss_out += out
        return package_data

    def _do_intercell_flows(self, cbc_data):
        """Handle intercell flows between selected and adjacent cells.

        Only connections from a selected cell to a cell outside the selection contribute.

        Args:
            cbc_data (CbcData): Cbc data

        Returns:
            dict of flows in/out of the selected cells
        """
        ia = self._ia
        ja = self._ja
        flow_ja_face = cbc_data.values()
        for cell_idx in self._selected_cells:
            # IA values are 1-based, so using ia[cell_idx] as a 0-based index points one past the cell's first JA
            # entry. The first entry is the cell itself (the diagonal), so this skips it.
            start = ia[cell_idx]
            stop = start + self._connections_count(cell_idx, ia)
            for ipos in range(start, stop):
                m = ja[ipos]  # Adjacent cell id (1-based)
                if m - 1 not in self._selected_cells:
                    q = flow_ja_face[ipos]
                    # Positive flow is coming in, negative is going out
                    if q > 0.0:
                        self._total_zone_in += abs(q)
                    elif q < 0.0:
                        self._total_zone_out += q
        package_data = {'in': self._total_zone_in, 'out': self._total_zone_out}
        return package_data

    def _connections_count(self, cell_idx, ia) -> int:
        """Returns the number of connections to cell, excluding the cell's connection to itself.

        Args:
            cell_idx (int): Cell index.
            ia: The IA array.

        Returns:
            (int): See description.
        """
        return max(0, ia[cell_idx + 1] - ia[cell_idx] - 1)  # If negative, make it 0

    def _read_grb_file(self):
        """Returns data we need from the binary grid (.grb) file, or None if there is no such file."""
        if self._grb_filepath is None:
            return None
        grb_filepath = Path(self._grb_filepath)
        if not grb_filepath.is_file():
            return None
        reader = GrbReader(grb_filepath, {'NCELLS', 'NODES', 'NJA', 'IA', 'JA'})
        grb_data = reader.read()
        return grb_data

    def _data_from_grb_file(self) -> bool:
        """Gets data we need from the .grb file and returns True if successful."""
        grb_data = self._read_grb_file()
        if not grb_data:
            return False
        self._ia = grb_data['IA']  # 'IA' from .grb file is cumulative
        self._ja = grb_data['JA']
        self._cell_count = grb_reader.cell_count_from_grb_data(grb_data)
        return True
