"""A Qt model for editing the sediment grain size distribution."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
from PySide2.QtCore import QModelIndex, Qt

# 3. Aquaveo modules
from xms.guipy.models.rename_model import RenameModel

# 4. Local modules


class DistributionTableModel(RenameModel):
    """A proxy model for editing the sediment grain size distribution table.

    Only three source columns are shown, reordered to: constituent name, grain diameter, and distribution fraction.
    Only the distribution fraction column is editable.
    """
    def __init__(self, parent=None):
        """Initializes the model.

        This model assumes that the source model has a pandas DataFrame with columns: 'layer_id', 'constituent_id',
        'percent', 'ID', 'NAME', 'GRAIN_DIAMETER'. The DataFrame should be sorted by 'GRAIN_DIAMETER'.
        The distribution fraction is the column at position 2 ('percent').

        Args:
            parent (Something derived from :obj:`QObject`): The parent object.
        """
        # Positions of the displayed values in the source DataFrame.
        self._NAME_SOURCE_COLUMN = 4
        self._SIZE_SOURCE_COLUMN = 5
        self._FRACTION_SOURCE_COLUMN = 2
        # Positions of the same values in this model.
        self._NAME_COLUMN = 0
        self._SIZE_COLUMN = 1
        self._FRACTION_COLUMN = 2
        # Column lookup tables shared by filterAcceptsColumn, mapFromSource and mapToSource.
        self._proxy_to_source_column = {
            self._NAME_COLUMN: self._NAME_SOURCE_COLUMN,
            self._SIZE_COLUMN: self._SIZE_SOURCE_COLUMN,
            self._FRACTION_COLUMN: self._FRACTION_SOURCE_COLUMN,
        }
        self._source_to_proxy_column = {source: proxy for proxy, source in self._proxy_to_source_column.items()}
        # Kept before the base-class initializer (as originally ordered) so the mapping exists during construction.
        super().__init__(['Constituent\nname', 'Grain\ndiameter (m)', 'Distribution\nfraction'], parent)

    def filterAcceptsColumn(self, source_column: int, source_parent: QModelIndex) -> bool:  # noqa: N802
        """Filters out source columns the table doesn't display.

        Args:
            source_column (int): The column index in the source model.
            source_parent (QModelIndex): The parent index of the view.

        Returns:
            True if the model should keep the column, false if the column should be hidden.
        """
        return source_column in self._source_to_proxy_column

    def mapFromSource(self, source_index: QModelIndex) -> QModelIndex:  # noqa: N802
        """Maps an index of the source model to this model.

        The kept source columns are reordered into this model's order: name, grain diameter, distribution fraction.

        Args:
            source_index (QModelIndex): The index in the source model that needs to be mapped.

        Returns:
            The QModelIndex of this model for the given source index. This is an invalid QModelIndex if the
            source column is hidden by this model or the source index is invalid.
        """
        proxy_column = self._source_to_proxy_column.get(source_index.column())
        if proxy_column is None:
            # Hidden source columns have no counterpart here. Passing them through would alias, e.g., source
            # column 0 onto the proxy's name column.
            return QModelIndex()
        return self.index(source_index.row(), proxy_column, source_index.parent())

    def mapToSource(self, proxy_index: QModelIndex) -> QModelIndex:  # noqa: N802
        """Maps an index of this model to the source model.

        This is the inverse of the column reordering done by mapFromSource.

        Args:
            proxy_index (QModelIndex): The index in this model that needs to be mapped.

        Returns:
            A QModelIndex of the source model for the given index of our model.
        """
        proxy_column = proxy_index.column()
        # Columns outside the mapping (e.g. -1 for an invalid index) are passed through unchanged.
        source_column = self._proxy_to_source_column.get(proxy_column, proxy_column)
        return self.sourceModel().index(proxy_index.row(), source_column, proxy_index.parent())

    def flags(self, index: QModelIndex):
        """Decides flags for things like enabled state.

        The name and grain diameter columns are made read-only and disabled. Only the distribution fraction
        can be edited.

        Args:
            index (QModelIndex): The index in this model whose flags are requested.

        Returns:
            Item flags for the model index.
        """
        local_flags = super().flags(index)
        if index.column() in (self._NAME_COLUMN, self._SIZE_COLUMN):
            local_flags = local_flags & ~Qt.ItemIsEditable
            local_flags = local_flags & ~Qt.ItemIsEnabled
        return local_flags

    def get_distribution_fraction_sum(self):
        """Returns the sum of the distribution fraction values.

        Returns:
            Returns a float value representing the sum of the distribution fraction values.
        """
        return self.sourceModel().data_frame.iloc[:, self._FRACTION_SOURCE_COLUMN].sum(axis=0)

    def normalize_distribution_fraction(self) -> None:
        """Normalizes the distribution fraction values so they sum to one.

        The fraction column is located by position (``_FRACTION_SOURCE_COLUMN``), as in
        get_distribution_fraction_sum, so this works regardless of the column's name. The normalized values are
        rounded to 10 decimal places. Nothing changes when the fractions sum to zero (including an empty table),
        because dividing by zero would fill the column with NaN or infinity.
        """
        source_model = self.sourceModel()
        data_frame = source_model.data_frame.copy()
        fraction_column = data_frame.columns[self._FRACTION_SOURCE_COLUMN]
        total = data_frame[fraction_column].sum()
        if total == 0:
            return
        data_frame[fraction_column] = (data_frame[fraction_column] / total).round(10)
        # Replace the whole frame (as before) rather than mutating the model's frame in place.
        source_model.data_frame = data_frame
        top_index = source_model.index(0, self._FRACTION_SOURCE_COLUMN)
        bottom_index = source_model.index(source_model.rowCount() - 1, self._FRACTION_SOURCE_COLUMN)
        source_model.dataChanged.emit(top_index, bottom_index)
