"""CbcVelocityVectorDatasetCreator class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from dataclasses import dataclass
import math
from pathlib import Path

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util, TreeNode
from xms.constraint.modflow import DisuCalculator

# 4. Local modules
from xms.mf6.components import dis_builder
from xms.mf6.components.cbc_dataset_creator_base import CbcDatasetCreatorBase
from xms.mf6.components.xms_data import XmsData
from xms.mf6.data.flow_budget_calculator import FlowBudgetCalculator
from xms.mf6.file_io.cbc_reader import CbcReader, CbcTimeStep
from xms.mf6.file_io.grb_reader import GrbReader
from xms.mf6.geom import geom
from xms.mf6.misc import log_util


@dataclass
class VectorInputs:
    """Information needed to create velocity vectors.

    Attributes:
        cbc_filepath: File path of the cell-by-cell budget (.cbc) file.
        model_node: Tree node of the model.
        porosity: The porosity - either a single value used for every cell, or one value per cell.
        grb_filepath: File path of the binary grid (.grb) file.
    """
    cbc_filepath: Path | str = ''
    model_node: TreeNode | None = None
    porosity: float | list[float] = 0.0
    grb_filepath: Path | str = ''


def create(inputs: VectorInputs, query: Query) -> tuple[str, str] | None:
    """Creates velocity vector datasets from the cbc file.

    Convenience wrapper that builds a CbcVelocityVectorDatasetCreator and runs it.

    Args:
        inputs: Everything needed to create the vector datasets.
        query: Object for communicating with GMS.

    Returns:
        An ('ERROR', message) tuple if a RuntimeError occurred while creating the datasets, otherwise None.
    """
    return CbcVelocityVectorDatasetCreator(inputs, query).create()


class CbcVelocityVectorDatasetCreator(CbcDatasetCreatorBase):
    """Creates velocity vector and velocity magnitude datasets from the cbc file."""
    def __init__(self, inputs: VectorInputs, query: Query):
        """Initializes the class.

        We use both the DisuCalculator and the binary grid (.grb) file. The DisuCalculator computes areas between
        cells, which we need to compute the velocity. The .grb file doesn't have that, but it does have the IA and JA
        arrays that correspond to the FLOW-JA-FACE data. The DisuCalculator requires the UGrid, which we get
        from the model because it is saved with the model, thus we are dependent on there being a model.

        Args:
            inputs: Everything needed to create the vector datasets.
            query: Object for communicating with GMS.
        """
        super().__init__(inputs.cbc_filepath, inputs.model_node, query)
        self._porosity = inputs.porosity  # Scalar or per-cell list; turned into a per-cell list by _init_porosity()
        self._grb_filepath = inputs.grb_filepath

        self._log = log_util.get_logger()
        self._error = None
        self._dataset_writer_vector = None
        self._dataset_writer_magnitude = None
        self._head_dset_reader = None
        self._disu_calculator = None
        self._iac = None  # Number of connections plus 1 for each cell (size cell count). NOT the same as IA from .grb
        self._ja = None  # Adjacencies (size nja), 1-based cell IDs
        self._grb_ia = None  # IA from the .grb file; used as 1-based offsets into _grb_ja by _index_flows()
        self._grb_ja = None  # JA from the .grb file
        self._hwva = None  # Interface areas
        self._tops = None
        self._bottoms = None
        self._heads = None
        self._activity = None
        self._idomain = None
        self._flows = None
        self._cell_centers = None
        self._vectors = None
        self._magnitudes = None
        self._unit_vectors = None  # length = NJA
        self._unit_vectors_absolute_sum = None  # length = cell count

        # Debugging stuff
        self._debug_first = True  # So that we only write the first time step to the debug file
        self._write_debug_flows = False
        self._write_debug_vectors = False

    def create(self) -> tuple[str, str] | None:
        """Creates velocity (vector) and magnitude (scalar) datasets from data in the cbc file.

        I've tried to limit self._query to this function and the methods it calls (_setup() and _add_datasets()).

        Returns:
            An ('ERROR', message) tuple if a RuntimeError occurred, otherwise None.
        """
        self._error = None
        # Debug output is switched on by the presence of these marker files
        self._write_debug_flows = Path('C:/temp/debug_mf6_flows.dbg').is_file()
        self._write_debug_vectors = Path('C:/temp/debug_mf6_vectors.dbg').is_file()

        try:  # Exceptions are raised and caught here if there are problems
            self._setup()
            self._create_datasets()
            self._add_datasets()
        except RuntimeError as e:
            self._error = ('ERROR', str(e))
            self._log.error(self._error)
        return self._error

    def _setup(self):
        """Do several setup tasks; problems detected here are reported by raising RuntimeError.

        Finds the head dataset, reads the .grb file, runs the DisuCalculator on the grid, and initializes the porosity
        and unit vector caches.
        """
        self._base_setup()
        self._setup_heads()
        self._find_grb_data()
        self._init_disu_calculator()
        self._init_porosity()
        self._init_unit_vectors()

    def _create_datasets(self):
        """Create the datasets, appending one time step to each for every time step in the cbc file."""
        self._dataset_writer_vector = self._init_dataset_writer('Velocity', num_components=3)
        self._dataset_writer_magnitude = self._init_dataset_writer('Velocity_Mag', num_components=1)

        with CbcReader(self._cbc_filepath) as reader:
            # Head time steps line up with cbc time steps (verified by _ensure_timesteps_match())
            for time_idx, cbc_time_step in enumerate(reader):
                self._log.info(f'Time step: {time_idx + 1}')
                self._get_heads_at_time_step(time_idx)
                self._get_flows_at_time_step(cbc_time_step)
                self._adjust_flows_by_source_sinks(cbc_time_step)
                self._compute_vectors_for_timestep()
                self._compute_magnitudes_for_timestep()
                self._dataset_writer_vector.append_timestep(
                    time=cbc_time_step.totim, data=self._vectors, activity=self._activity
                )
                self._dataset_writer_magnitude.append_timestep(
                    time=cbc_time_step.totim, data=self._magnitudes, activity=self._activity
                )

        self._dataset_writer_vector.appending_finished()
        self._dataset_writer_magnitude.appending_finished()

    def _adjust_flows_by_source_sinks(self, cbc_time_step: CbcTimeStep) -> None:
        """Adjusts the flows by the sources/sinks.

        If the net flow for a cell due to the sources/sinks BCs is negative (water is being removed from the cell) then
        we need to reduce the flow through the faces by the amount that exits the cell through the BCs. If the net flow
        for a cell due to the sources/sinks BCs is positive (water is being added to the cell) then we don't do
        anything to the face flows.

        See GMS MfCcfToVectors::AdjustFlowBySourceSinksUsg().

        Args:
            cbc_time_step: The current time step from the cbc file.
        """
        if self._debug_first:
            self._debug_write_flows('mf6_flows_before')

        # Use FlowBudgetCalculator to determine net flows due to sources/sinks
        ia = dis_builder.ia_from_iac(self._iac)  # IA is cumulative, len(IAC) + 1, and allows indexing into JA directly
        calculator = FlowBudgetCalculator(
            cbc_filepath=self._cbc_filepath,
            selected_cells=[],
            cbc_time_step=cbc_time_step,
            ia=ia,
            ja=self._ja,
            cell_count=self._cell_count
        )
        flow_budget = calculator.calculate()
        net_flows = flow_budget['meta']['net_source_sink']

        ja_idx = 0  # Index of the current cell's own entry in JA
        for cell_idx, net_flow in enumerate(net_flows):
            adjacent_count = self._iac[cell_idx] - 1
            if net_flow < 0.0:
                # GMS additionally zeroes all face flows (see _set_flows_to_zero()) when net_flow + the sum of the face
                # flows is <= 0. That step is intentionally omitted here because it doesn't appear to be necessary.
                _, sum_absolute = self._sum_cell_flows(adjacent_count, ja_idx)
                if sum_absolute != 0.0:
                    self._proportionally_reduce_flows(adjacent_count, ja_idx, net_flow, sum_absolute)
            ja_idx += adjacent_count + 1  # Skip this cell's own entry plus its adjacent cells

        if self._debug_first:
            self._debug_write_flows('mf6_flows_adjusted')

    def _compute_vectors_for_timestep(self) -> None:
        """Creates the vectors at every cell and stores them in self._vectors.

        Cells that fail _cell_data_is_good() keep a [0.0, 0.0, 0.0] vector.

        See GMS MfCcfToVectors::ComputeVectorsForTsUsg().
        """
        self._vectors = [[0.0, 0.0, 0.0] for _ in range(self._cell_count)]

        # Loop through the cells
        ja_idx = 0  # Index of the current cell's own entry in JA
        for cell_idx in range(self._cell_count):
            adjacent_count = self._iac[cell_idx] - 1
            if not self._cell_data_is_good(cell_idx):
                ja_idx += adjacent_count + 1
                continue

            # Compute all vectors from this cell to its adjacent cells (cached after the first time they're computed)
            self._compute_unit_vectors_if_necessary(cell_idx, ja_idx, adjacent_count)

            # Accumulate the face velocities (flow / (area * porosity)) projected onto each face's unit vector
            vector = self._vectors[cell_idx]
            porosity = self._porosity[cell_idx]
            for ja_idx2 in range(ja_idx + 1, ja_idx + 1 + adjacent_count):
                divisor = self._compute_area(ja_idx2, cell_idx) * porosity
                if divisor > 0.0:
                    factor = self._flows[ja_idx2] / divisor
                    vector_2_to_1 = self._unit_vectors[ja_idx2]
                    vector[0] += vector_2_to_1[0] * factor
                    vector[1] += vector_2_to_1[1] * factor
                    vector[2] += vector_2_to_1[2] * factor

            # Get one vector for this cell by dividing each summed component by the sum of the absolute values of that
            # component over all of the cell's unit vectors
            unit_vector_absolute_sum = self._unit_vectors_absolute_sum[cell_idx]
            for i in range(3):
                if unit_vector_absolute_sum[i] != 0.0:
                    vector[i] /= unit_vector_absolute_sum[i]

            ja_idx += adjacent_count + 1

        if self._debug_first:
            self._debug_write_vectors('mf6_vectors')
            self._debug_first = False

    def _compute_unit_vectors_if_necessary(self, cell_idx: int, ja_idx: int, adjacent_count: int) -> None:
        """Compute the unit vectors on the first time step and save them for later time steps since they won't change.

        Fills self._unit_vectors for each of the cell's connections and self._unit_vectors_absolute_sum[cell_idx].

        Args:
            cell_idx (int): The cell index.
            ja_idx (int): Index of the cell's own entry in the ja array.
            adjacent_count (int): Number of adjacent cells.
        """
        if self._unit_vectors_absolute_sum[cell_idx] is not None:
            return  # Already computed on an earlier time step

        unit_vector_absolute_sum = [0.0, 0.0, 0.0]  # Sum of absolute values of unit vector components
        cell_center_1 = self._cell_centers[cell_idx]

        # Start at ja_idx + 1 to skip 1st adjacent cell in JA array because it's this cell's ID
        for ja_idx2 in range(ja_idx + 1, ja_idx + 1 + adjacent_count):
            # Get the unit vector in the direction of this face
            adjacent_cell_idx = self._ja[ja_idx2] - 1  # JA holds 1-based cell IDs
            cell_center_2 = self._cell_centers[adjacent_cell_idx]
            vector_2_to_1 = np.subtract(cell_center_1, cell_center_2)  # (cell_center_1 - cell_center_2)
            geom.normalize_xyz(vector_2_to_1)  # Return value unused, so presumably normalizes in place
            self._unit_vectors[ja_idx2] = vector_2_to_1
            unit_vector_absolute_sum[0] += abs(vector_2_to_1[0])
            unit_vector_absolute_sum[1] += abs(vector_2_to_1[1])
            unit_vector_absolute_sum[2] += abs(vector_2_to_1[2])
        self._unit_vectors_absolute_sum[cell_idx] = unit_vector_absolute_sum

    def _set_flows_to_zero(self, adjacent_count, ja_idx):
        """Sets the flows through all of a cell's faces to zero.

        Mirrors a step of the GMS algorithm that the current adjustment intentionally skips.

        Args:
            adjacent_count (int): Number of adjacent cells.
            ja_idx (int): Index of the cell's own entry in the ja array.
        """
        for ja_idx2 in range(ja_idx + 1, ja_idx + 1 + adjacent_count):
            self._flows[ja_idx2] = 0.0

    def _proportionally_reduce_flows(self, adjacent_count, ja_idx, net_flow, sum_absolute):
        """Adds net_flow to the cell's face flows in proportion to each flow's share of sum_absolute.

        Args:
            adjacent_count (int): Number of adjacent cells.
            ja_idx (int): Index of the cell's own entry in the ja array.
            net_flow (float): Net source/sink flow for the cell (negative when called).
            sum_absolute (float): Sum of the absolute values of the cell's face flows (non-zero when called).
        """
        for ja_idx2 in range(ja_idx + 1, ja_idx + 1 + adjacent_count):
            percent = self._flows[ja_idx2] / sum_absolute
            part = percent * net_flow
            self._flows[ja_idx2] += part

    def _sum_cell_flows(self, adjacent_count, ja_idx) -> tuple[float, float]:
        """Returns the sum of the flows, and the sum of the absolute values of the flows.

        Args:
            adjacent_count (int): Number of adjacent cells.
            ja_idx (int): Index of the cell's own entry in the ja array.

        Returns:
            (tuple[float, float]): See description.
        """
        sum_ = 0.0  # "sum" is a python built-in name, hence the trailing underscore
        sum_absolute = 0.0
        # Start at ja_idx + 1 to skip 1st adjacent cell in JA array because it's this cell's ID
        for ja_idx2 in range(ja_idx + 1, ja_idx + 1 + adjacent_count):
            sum_absolute += abs(self._flows[ja_idx2])
            sum_ += self._flows[ja_idx2]
        return sum_, sum_absolute

    def _is_flow_type_source_sink(self, flow_type: str) -> bool:
        """Returns True if the flow_type represents a source or sink.

        "The “DATA” prefix on the text identifier can be used by post-processors to recognize that the record does not
        contain a cell flow budget term." (mf6io.pdf page 283, table 35)

        Args:
            flow_type (str): The flow type from the budget file.

        Returns:
            (bool): See description.
        """
        return flow_type not in ('meta', 'FLOW-JA-FACE') and not flow_type.startswith('DATA')

    def _get_heads_at_time_step(self, time_idx: int) -> None:
        """Get the head values and activity at the given time step index from the head dataset."""
        self._heads = self._head_dset_reader.values[time_idx]
        self._activity = self._head_dset_reader.activity[time_idx]

    def _get_flows_at_time_step(self, cbc_time_step: CbcTimeStep) -> None:
        """Get the FLOW-JA-FACE flow data at the current time step, expanded to the size of JA if necessary."""
        self._flows = list(cbc_time_step.get_data('FLOW-JA-FACE').values())
        self._add_missing_flows()

    def _add_missing_flows(self) -> None:
        """Makes flows size of nja by adding missing flow data.

        When the FLOW-JA-FACE data isn't the same size as the DisuCalculator JA array, the flows are re-mapped from
        .grb (IA/JA) order into DisuCalculator JA order. Connections with no flow data get 0.0.
        """
        if len(self._flows) == len(self._ja):
            return

        indexed_flows = self._index_flows()
        self._flows = [0.0] * len(self._ja)  # Also leaves 0.0 as each cell's flow to itself

        # Loop through the cells
        ja_idx = 0  # Index of the current cell's own entry in JA
        for cell_idx in range(self._cell_count):
            adjacent_count = self._iac[cell_idx] - 1
            cell_flows = indexed_flows.get(cell_idx + 1, {})
            # Start at ja_idx + 1 to skip 1st adjacent cell in JA array because it's this cell's ID
            for ja_idx2 in range(ja_idx + 1, ja_idx + 1 + adjacent_count):
                self._flows[ja_idx2] = cell_flows.get(self._ja[ja_idx2], 0.0)
            ja_idx += adjacent_count + 1

    def _index_flows(self) -> dict[int, dict[int, float]]:
        """Uses the IA/JA loop from mf6io.pdf to index flows by cell and adjacent cell (both 1-based IDs).

        The mf6io.pdf loop runs ipos from ia(n) + 1 through ia(n + 1) - 1 INCLUSIVE (Fortran DO semantics), skipping
        ipos = ia(n), which is the cell itself. Python's range() excludes its stop value, so the stop must be
        ia(n + 1); using ia(n + 1) - 1 would drop each cell's last connection.

        Returns:
            (dict[int, dict[int, float]]): {cell_id: {adjacent_cell_id: flow}}. Each flow q from n to m is also stored
            as -q from m to n.
        """
        indexed_flows: dict[int, dict[int, float]] = {}
        for n in range(1, self._cell_count + 1):
            # self._grb_ia is a 0-based list of 1-based Fortran offsets: ia(n) is self._grb_ia[n - 1]
            for ipos in range(self._grb_ia[n - 1] + 1, self._grb_ia[n]):
                m = self._grb_ja[ipos - 1]
                q = self._flows[ipos - 1]
                indexed_flows.setdefault(n, {})[m] = q
                indexed_flows.setdefault(m, {})[n] = -q
        return indexed_flows

    def _debug_write_flows(self, file_stem: str) -> None:
        """Writes the flows to a CSV file (cell id, adjacent cell id, flow) for debugging."""
        if not self._write_debug_flows:
            return
        with open(f'C:/temp/debug_{file_stem}.csv', mode='w') as fp:
            # Walk JA cell by cell using IAC so the owning cell is always known, rather than relying on JA marking a
            # cell's own entry with a negative ID
            ja_idx = 0
            for cell_idx in range(self._cell_count):
                cell_id = cell_idx + 1
                next_ja_idx = ja_idx + self._iac[cell_idx]
                for i in range(ja_idx, next_ja_idx):
                    adjacent_cell_id = abs(self._ja[i])
                    fp.write(f'{cell_id},{adjacent_cell_id},{self._flows[i]}\n')
                ja_idx = next_ja_idx

    def _debug_write_vectors(self, file_stem: str) -> None:
        """Writes the vectors to a CSV file (x, y, z per cell) for debugging."""
        if self._write_debug_vectors:
            with open(f'C:/temp/debug_{file_stem}.csv', mode='w') as fp:
                for vector in self._vectors:
                    fp.write(f'{vector[0]},{vector[1]},{vector[2]}\n')

    def _cell_data_is_good(self, cell_idx) -> bool:
        """Returns True if the cell is active, not dry, not inverted, and the porosity is valid."""
        return not (
            self._inactive_cell(cell_idx)
            or self._dry_cell(cell_idx)  # noqa: W503 (line break)
            or self._top_below_bottom(cell_idx)  # noqa: W503 (line break)
            or self._zero_or_negative_porosity(cell_idx)  # noqa: W503 (line break)
        )

    def _inactive_cell(self, cell_idx: int) -> bool:
        """Returns true if the cell is inactive based on the head dataset activity."""
        return not bool(self._activity[cell_idx])

    def _dry_cell(self, cell_idx: int) -> bool:
        """Returns true if the head in the cell is below the bottom elevation."""
        return self._heads[cell_idx] < self._bottoms[cell_idx]

    def _top_below_bottom(self, cell_idx: int) -> bool:
        """Returns true if the cell top is below the bottom."""
        return self._tops[cell_idx] < self._bottoms[cell_idx]

    def _zero_or_negative_porosity(self, cell_idx: int) -> bool:
        """Returns true if the cell porosity is zero or negative."""
        return self._porosity[cell_idx] <= 0.0

    def _compute_magnitudes_for_timestep(self) -> None:
        """Computes the vector magnitudes and stores them in self._magnitudes."""
        self._magnitudes = [math.sqrt(v[0]**2 + v[1]**2 + v[2]**2) for v in self._vectors]

    def _compute_area(self, ja_idx: int, cell_idx: int) -> float:
        """Computes the area between adjacent cells.

        Args:
            ja_idx (int): Index into JA array which gives the adjacent cell.
            cell_idx (int): Cell index.

        Returns:
            (float): Area.
        """
        cell_idx_adj = self._ja[ja_idx] - 1  # Adjacent cell index
        top = self._tops[cell_idx]
        bottom = self._bottoms[cell_idx]
        head = self._heads[cell_idx]
        # Top or bottom face (this cell's bottom meets the adjacent top, or vice versa), or fully saturated. Use the
        # pre-computed area. "head >= top" (not ">") gives the same area when head == top (saturated fraction is 1)
        # and avoids dividing by zero below for a zero-thickness cell.
        if (
            math.isclose(bottom, self._tops[cell_idx_adj])
            or math.isclose(top, self._bottoms[cell_idx_adj])  # noqa: W503 (line break)
            or head >= top  # noqa: W503 (line break)
        ):
            return self._hwva[ja_idx]

        # Scale the area based on the saturated thickness
        # GMS does this differently and uses the width here, but we don't know the width and I think this method
        # should work
        thickness = top - bottom
        if thickness <= 0.0:
            return 0.0
        factor = (head - bottom) / thickness
        return factor * self._hwva[ja_idx] if factor >= 0.0 else 0.0

    def _find_grb_data(self):
        """Reads the binary grid file and saves the IA and JA arrays to member variables."""
        grb_data = self._read_grb_file()
        self._grb_ia = grb_data['IA'].tolist()
        self._grb_ja = grb_data['JA'].tolist()

    def _read_grb_file(self):
        """Returns data we need (IA and JA) from the binary grid (.grb) file."""
        reader = GrbReader(self._grb_filepath, {'IA', 'JA'})
        return reader.read()

    def _get_cogrid(self):
        """Returns the cogrid."""
        xms_data = XmsData(self._query)
        return xms_data.get_cogrid(self._model_node.uuid)

    def _init_disu_calculator(self) -> None:
        """Initializes the disu calculator, runs it, and saves some arrays to member variables."""
        cogrid = self._get_cogrid()
        self._tops = cogrid.get_cell_tops()
        self._bottoms = cogrid.get_cell_bottoms()
        self._cell_centers = dis_builder.get_cell_centers3d(cogrid)
        self._disu_calculator = DisuCalculator(cogrid.ugrid)
        angle = getattr(cogrid, 'angle', 0.0)

        self._disu_calculator.calculate(self._tops, self._bottoms, self._cell_centers, angle)
        self._iac = self._disu_calculator.get_adjacent_counts()  # IAC
        self._ja = self._disu_calculator.get_adjacent_cells()  # JA
        self._hwva = self._disu_calculator.get_interface_area()  # HWVA

    def _init_porosity(self) -> None:
        """Initializes the porosity by turning a single value into a per-cell list if necessary.

        Any scalar (int, float, or NumPy scalar) is accepted; previously only a Python float was expanded, so an int
        porosity stayed a scalar and failed when indexed by cell.
        """
        if np.ndim(self._porosity) == 0:
            self._porosity = [self._porosity] * self._cell_count

    def _init_unit_vectors(self) -> None:
        """Initializes the unit vector lists.

        We cache the unit vectors because they don't change between time steps. None means "not yet computed".
        """
        self._unit_vectors = [None] * len(self._ja)
        self._unit_vectors_absolute_sum = [None] * self._cell_count

    def _setup_heads(self) -> None:
        """Finds the head dataset and ensures that the head times match the budget times."""
        self._find_head_dset()
        self._ensure_timesteps_match()

    def _find_head_dset(self) -> None:
        """Finds the head dset and creates a dataset reader for it.

        Raises:
            RuntimeError: If the solution folder or the Head dataset can't be found.
        """
        # Get solution folder: the nearest 'TI_SOLUTION_FOLDER' ancestor of the cbc node
        solution_node = self._cbc_node.parent
        while solution_node is not None and solution_node.item_typename != 'TI_SOLUTION_FOLDER':
            solution_node = solution_node.parent
        if solution_node is None:
            raise RuntimeError('Could not find solution folder.')

        # Find the dataset named Head
        head_ptr_node = tree_util.first_descendant_with_name(solution_node, 'Head')
        head_node = None
        if head_ptr_node:
            head_node = tree_util.find_tree_node_by_uuid(tree_node=self._query.project_tree, uuid=head_ptr_node.uuid)
        if head_node is None:
            raise RuntimeError('Could not find Head dataset.')

        self._head_dset_reader = self._query.item_with_uuid(head_node.uuid)
        if self._head_dset_reader is None:
            raise RuntimeError('Could not find Head dataset.')

    def _get_head_times(self):
        """Returns the times of the head dset."""
        return self._head_dset_reader.times[()]

    def _ensure_timesteps_match(self) -> None:
        """Ensures that the head dset timesteps match the budget file timesteps.

        Raises:
            RuntimeError: If the number of times or the time values differ.
        """
        cbc_times = np.array(self._get_cbc_times())
        head_times = self._get_head_times()
        if len(cbc_times) != len(head_times) or not np.allclose(cbc_times, head_times):
            raise RuntimeError('Times of head dataset do not match times in budget file.')

    def _get_cbc_times(self) -> list[float]:
        """Returns the times (totim) from the .cbc file."""
        with CbcReader(self._cbc_filepath) as reader:
            return [cbc_time_step.totim for cbc_time_step in reader]

    def get_error(self) -> tuple[str, str] | None:
        """Returns the error, if any."""
        return self._error

    def _add_datasets(self):
        """Uses Query to tell GMS to add the datasets to the tree."""
        self._query.add_dataset(self._dataset_writer_vector, folder_path=self._tree_folder)
        self._query.add_dataset(self._dataset_writer_magnitude, folder_path=self._tree_folder)