"""QSqlDatabase queries."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from contextlib import contextmanager
from typing import Any

# 2. Third party modules
from PySide2.QtSql import QSqlDatabase, QSqlError, QSqlQuery
from PySide2.QtWidgets import QWidget

# 3. Aquaveo modules
from xms.guipy.dialogs import message_box

# 4. Local modules
from xms.mf6.data.grid_info import GridInfo
from xms.mf6.file_io import database


class QueryError(Exception):
    """Custom exception wrapping a QSqlError."""
    def __init__(self, qsql_error: 'QSqlError'):
        """Initializer.

        Args:
            qsql_error: The error. Its text() is used as the exception message so str(exception) is informative.
        """
        super().__init__(qsql_error.text())
        self.qsql_error = qsql_error


@contextmanager
def _q(db: QSqlDatabase):
    """Yield a new QSqlQuery on db, calling finish() on it when the with-block exits.

    Args:
        db: The database the query runs against.

    Yields:
        The query.
    """
    query = QSqlQuery(db)
    try:
        yield query
    finally:
        # Runs both on normal exit and when the with-body raises.
        query.finish()


class QxQueries:
    """A place to put queries that work on a QSqlDatabase.

    Methods catch QueryError internally and report it via _handle_error(): the message is appended to `errors` and,
    if show_error is True, shown in a message box. Callers should therefore inspect `errors` rather than expect a
    QueryError to propagate.
    """
    def __init__(self, db: QSqlDatabase, win_cont: QWidget | None, show_error: bool = True):
        """Initializer.

        Args:
            db: The database.
            win_cont: The window container (used as the parent of error message boxes).
            show_error: If True, a message box is shown immediately if there's an error.
        """
        self._win_cont = win_cont
        self._db = db
        self._show_error = show_error  # Show a message box with the error. If False, only adds error to self._errors.
        self._errors: list[str] = []

    @property
    def errors(self) -> list[str]:
        """Return the list of error strings."""
        return self._errors

    def clear_errors(self) -> None:
        """Clears all error messages."""
        self._errors = []

    def _handle_error(self, error: QueryError, msg: str) -> None:
        """Add msg to self._errors and, if self._show_error is True, show it in a message box.

        Args:
            error: The error; its details are shown in the message box's details section.
            msg: The user-facing message.
        """
        self._errors.append(msg)
        if self._show_error:
            message_box.message_with_ok(self._win_cont, message=msg, icon='Critical', details=get_error_str(error))

    def table_exists(self, table: str) -> bool:
        """Return True if the table exists.

        Args:
            table: Name of the table.

        Returns:
            See description.
        """
        return table in self._db.tables()

    def delete_period(self, period: int) -> None:
        """Deletes the stress period rows from the 'data' table.

        Args:
            period: The period.
        """
        with _q(self._db) as q:
            try:
                q.prepare('DELETE FROM data WHERE PERIOD=?')
                q.bindValue(0, period)
                exec_raises(q)
            except QueryError as e:
                self._handle_error(e, 'Unexpected error deleting period.')

    def delete_rows(self, ids: list[int]) -> None:
        """Delete rows with the given ids from the 'data' table.

        Args:
            ids: List of id values. Nothing is done if it is empty.
        """
        id_list = ', '.join(map(str, ids))
        if not id_list:
            return  # Nothing to delete; avoids issuing 'IN ()', which not every SQL dialect accepts.
        with _q(self._db) as q:
            try:
                q.prepare(f'DELETE FROM data WHERE id IN ({id_list})')
                exec_raises(q)
            except QueryError as e:
                self._handle_error(e, 'Unexpected error deleting rows.')

    def delete_all_table_rows(self, table: str) -> None:
        """Deletes all the rows from a table (the table itself remains).

        Args:
            table: table name.
        """
        with _q(self._db) as q:
            try:
                q.prepare(f'DELETE FROM {table}')
                exec_raises(q)
            except QueryError as e:
                self._handle_error(e, 'Unexpected error deleting table rows.')

    def update_column_values(self, column_name: str, ids: list[int], new_values: list[Any]) -> None:
        """Update values in a column of the 'data' table as a single batch.

        Args:
            column_name: Name of the column
            ids: List of ids for 'id' column
            new_values: List of new values; new_values[i] is written to the row whose id is ids[i].
        """
        update_query_str = f'UPDATE data SET {column_name} = ? WHERE id=?'
        with _q(self._db) as q:
            try:
                q.prepare(update_query_str)
                # Each bound list supplies one placeholder's values across all executions of the batch.
                q.addBindValue(new_values)
                q.addBindValue(ids)
                if not q.execBatch():
                    raise QueryError(q.lastError())
            except QueryError as e:
                self._handle_error(e, 'Unexpected error updating column values.')

    def drop_table(self, table: str) -> None:
        """Drop the table if it exists; do nothing otherwise.

        Args:
            table: Name of the table.
        """
        if not self.table_exists(table):
            return

        with _q(self._db) as q:
            try:
                exec_raises(q, f'DROP TABLE {table}')
            except QueryError as e:
                self._handle_error(e, 'Unexpected error dropping the table.')

    def replace_data_table(self, create_string: str, old_columns_in_new_str: str) -> None:
        """Replace the 'data' table with a new table, copying the given columns over.

        Creates 'data_new' (after dropping any leftover 'data_new'), copies old_columns_in_new_str from 'data' into
        it, drops 'data', and renames 'data_new' to 'data'.

        Args:
            create_string: Column definitions for the new table (placed inside CREATE TABLE data_new(...)).
            old_columns_in_new_str: Comma-separated names of the old columns to copy into the new table.
        """
        self.drop_table('data_new')
        with _q(self._db) as q:
            try:
                # NOTE: No explicit transaction is used here, so if a later step fails the earlier ones stay applied.
                exec_raises(q, f'CREATE TABLE data_new({create_string})')
                exec_raises(
                    q, f'INSERT INTO data_new ({old_columns_in_new_str}) SELECT {old_columns_in_new_str} FROM data'
                )
                exec_raises(q, 'DROP TABLE data')
                exec_raises(q, 'ALTER TABLE data_new RENAME TO data')
            except QueryError as e:
                self._handle_error(e, 'Unexpected error replacing the data table.')

    def create_indexes(self, ftype: str, table: str) -> None:
        """Create indexes for speed.

        Stops at the first index statement that fails.

        Args:
            ftype: The file type.
            table: Name of the table.
        """
        indexes_strings = database.get_create_index_strings(ftype, table)
        with _q(self._db) as q:
            try:
                for s in indexes_strings:
                    exec_raises(q, s)
            except QueryError as e:
                self._handle_error(e, 'Unexpected error creating indexes.')

    def copy_period(self, column_str: str, from_sp: int, new_column_str: str) -> None:
        """Copy rows from one stress period to another within the 'data' table.

        Args:
            column_str: The columns string (target columns of the INSERT).
            from_sp: From stress period.
            new_column_str: The new columns string (expressions SELECTed from the from_sp rows).
        """
        stmt = f'INSERT INTO data ({column_str}) SELECT {new_column_str} FROM data WHERE PERIOD = ?'
        with _q(self._db) as q:
            try:
                q.prepare(stmt)
                q.bindValue(0, from_sp)
                exec_raises(q)
            except QueryError as e:
                self._handle_error(e, 'Unexpected error copying period.')

    def count_rows_in_period(self, period: int) -> int:
        """Return the number of rows in period.

        Args:
            period: The period.

        Returns:
            See description. 0 if there was an error.
        """
        period_rows = 0
        with _q(self._db) as q:
            try:
                q.prepare('SELECT COUNT(*) FROM data WHERE PERIOD = ?')
                q.bindValue(0, period)
                exec_raises(q)
                if q.next():  # noqa B305 (`.next()` is not a thing on Python 3. Use the `next()` builtin.)
                    period_rows = q.value(0)
            except QueryError as e:
                self._handle_error(e, 'Unexpected error counting rows.')
        return period_rows

    def defined_periods(self) -> list[int]:
        """Return the list of unique periods from PERIOD column.

        Returns:
            See description. Empty if there was an error.
        """
        periods = []
        with _q(self._db) as q:
            try:
                q.prepare('SELECT DISTINCT PERIOD FROM data')
                exec_raises(q)
                while q.next():  # noqa B305 (`.next()` is not a thing on Python 3. Use the `next()` builtin.)
                    periods.append(q.value(0))
            except QueryError as e:
                self._handle_error(e, 'Unexpected error determining defined periods.')
        return periods

    def field_count(self, table: str) -> int:
        """Return the number of columns in a table.

        Args:
            table: The table.

        Returns:
            See description. 0 if there was an error.
        """
        count = 0
        with _q(self._db) as q:
            try:
                # LIMIT 0: only the column metadata is needed, not the rows themselves.
                q.prepare(f'SELECT * FROM {table} LIMIT 0')
                exec_raises(q)
                count = q.record().count()
            except QueryError as e:
                self._handle_error(e, 'Unexpected error counting fields.')
        return count

    def write_rows_to_csv(self, fields: list[str], writer) -> None:
        """Writes rows from the 'data' table to a .csv file, ordered by PERIOD then PERIOD_ROW.

        Args:
            fields: The fields to include.
            writer: CSV writer (anything with a writerow() method).
        """
        fields_str = ', '.join(fields)
        # ORDER BY on the outer query itself so the output order is guaranteed.
        stmt = f'SELECT {fields_str} FROM data ORDER BY PERIOD, PERIOD_ROW'
        with _q(self._db) as q:
            try:
                q.prepare(stmt)
                exec_raises(q)
                column_count = q.record().count()  # Same for every row, so read it once.
                while q.next():  # noqa B305 (`.next()` is not a thing on Python 3. Use the `next()` builtin.)
                    writer.writerow([q.value(i) for i in range(column_count)])
            except QueryError as e:
                self._handle_error(e, 'Unexpected error writing to CSV.')

    def renumber_period_rows(self, table: str) -> None:
        """Renumber the PERIOD_ROW column.

        Within each PERIOD, rows ordered by id get PERIOD_ROW = 0, 1, 2, ...

        Args:
            table: Name of the table.
        """
        try:
            # Select data in order
            with _q(self._db) as select_q:
                exec_raises(select_q, f'SELECT PERIOD, id FROM {table} ORDER BY PERIOD, id')

                # Prepare UPDATE query
                with _q(self._db) as update_q:
                    update_q.prepare(f'UPDATE {table} SET PERIOD_ROW = :new_period_row WHERE id = :id')

                    # Iterate through results and update each row
                    current_period = -1
                    row_counter = 0
                    while select_q.next():  # noqa B305 (`.next()` is not a thing on Python 3. Use the `next()` builtin.)
                        period = select_q.value(0)
                        id_ = select_q.value(1)
                        if period != current_period:
                            current_period = period
                            row_counter = 0

                        # Bind by name to match the named placeholders in the prepared statement.
                        update_q.bindValue(':new_period_row', row_counter)
                        update_q.bindValue(':id', id_)
                        exec_raises(update_q)
                        row_counter += 1

        except QueryError as e:
            self._handle_error(e, 'Unexpected error renumbering period rows.')

    def insert_rows_from_csv(self, ftype: str, grid_info: GridInfo, reader, table: str) -> None:
        """Insert rows from a .csv file exported by write_rows_to_csv(), then renumber PERIOD_ROW.

        A row that fails to insert is reported (with its file line number) and skipped; later rows are still tried.

        Args:
            ftype: The ftype.
            grid_info: Information about the grid.
            reader: CSV reader, positioned after the header line.
            table: Name of the table.
        """
        field_count = self.field_count(table)
        stmt = database.setup_insert_statement(field_count, table)
        with _q(self._db) as q:
            if not q.prepare(stmt):
                # Otherwise every row below would fail and be reported separately (one message box per CSV line).
                self._handle_error(QueryError(q.lastError()), 'Unexpected error preparing to import file.')
                return
            period_row = 0  # New rows will all have period_row = 0 initially but we renumber at the end
            for row_idx, row in enumerate(reader):
                period = row[-1]  # PERIOD column is always the last

                # Build full database row (minus last 'id' field) in the correct order. Example:
                # From CSV:    LAY, ROW, COL, Q, PERIOD
                # To db row:   LAY, ROW, COL, Q, CELLIDX, PERIOD, PERIOD_ROW
                new_row = row[:-1]
                try:
                    database.append_cell_idx(ftype, grid_info, new_row)
                    new_row.extend([period, period_row])

                    # Bind the values
                    for i, value in enumerate(new_row):
                        q.bindValue(i, value)
                    exec_raises(q)
                except QueryError as e:
                    # +2: one for the header line and one because row_idx is 0-based.
                    self._handle_error(e, f'Error importing file on line {row_idx + 2}')
        self.renumber_period_rows(table)

    def get_nth_instance_of_stress(
        self, row_id: int, stress_ids: list[int], stress_id_columns: list[str], boundname: str, period: int
    ) -> int:
        """Returns which instance (1-based) of this stress_id (cellidx) this is in this period.

        Matching rows are counted in PERIOD_ROW order up to and including the row whose id is row_id. If row_id is
        not among the matching rows, the total number of matching rows is returned.

        Args:
            row_id: The id in the model row (1-based).
            stress_ids: The stress ids, e.g. [CELLIDX] or [CELLIDX1, CELLIDX2].
            stress_id_columns: e.g. ['CELLIDX'] or ['CELLIDX1', 'CELLIDX2'].
            boundname: The boundname in the row (may be ''). If not '', it must also match.
            period: The period being displayed.

        Returns:
            See description. 0 if there was an error.
        """
        nth = 0
        stress_id_str = ' = ? AND '.join(stress_id_columns)
        boundname_str = ' AND BOUNDNAME = ?' if boundname else ''
        stmt = f'SELECT id FROM data WHERE {stress_id_str} = ? AND PERIOD = ?{boundname_str} ORDER BY PERIOD_ROW'
        # Values in the same order as their '?' placeholders appear in stmt.
        bind_values = [*stress_ids, period]
        if boundname:
            bind_values.append(boundname)
        with _q(self._db) as q:
            try:
                q.prepare(stmt)
                for i, value in enumerate(bind_values):
                    q.bindValue(i, value)
                exec_raises(q)
                while q.next():  # noqa B305 (`.next()` is not a thing on Python 3. Use the `next()` builtin.)
                    nth += 1
                    if row_id == q.value(0):
                        break
            except QueryError as e:
                self._handle_error(e, 'Unexpected error determining stress index.')
        return nth

    def find_stress(
        self, stress_ids: list[int], stress_id_columns: list[str], boundname: str, column_name: str, nth: int
    ) -> tuple[list[int], list[int], list[Any]]:
        """Returns list of ids, periods, and values from rows w/ nth occurrence of stress_id (CELLIDX) over all periods.

        If a period has less than nth occurrences of stress_id (CELLIDX), nothing is returned for that period.

        Args:
            stress_ids (list of int): The stress_ids (typically [CELLIDX] or [CELLIDX1, CELLIDX2]).
            stress_id_columns: e.g. ['CELLIDX'] or ['CELLIDX1', 'CELLIDX2'].
            boundname: The boundname in the row (may be ''). If not '', it must also match.
            column_name: If not '', column to include in query results.
            nth: Which instance of the stress to return.

        Returns:
            (tuple): tuple containing:
                - ids (list): Primary key ids.
                - periods (list): Periods.
                - values (list): Values from column.
            All lists are empty if there was an error.
        """
        ids, periods, values = [], [], []
        select_col_str = '' if not column_name else f', {column_name}'
        stress_id_str = ' = ? AND '.join(stress_id_columns)
        boundname_str = ' AND BOUNDNAME = ?' if boundname else ''
        stmt = (
            f'SELECT id, PERIOD{select_col_str} FROM data WHERE {stress_id_str} = ?{boundname_str}'
            ' ORDER BY PERIOD, PERIOD_ROW'
        )
        # Values in the same order as their '?' placeholders appear in stmt.
        bind_values = list(stress_ids)
        if boundname:
            bind_values.append(boundname)

        with _q(self._db) as q:
            try:
                q.prepare(stmt)
                for i, value in enumerate(bind_values):
                    q.bindValue(i, value)
                exec_raises(q)
                ids, periods, values = _nth_query_results(nth, q)
            except QueryError as e:
                self._handle_error(e, 'Unexpected error finding stress.')
        return ids, periods, values


def exec_raises(q: QSqlQuery, s: str = '') -> None:
    """Execute a query and raise QueryError if exec_() reports failure.

    Args:
        q: The query object. When s is empty, the statement previously prepared on q is executed.
        s: The SQL string to execute directly. If empty, q.exec_() is called with no arguments.

    Raises:
        QueryError: If exec_() returns False; it wraps q.lastError().
    """
    succeeded = q.exec_(s) if s else q.exec_()
    if succeeded:
        return
    # NOTE: An earlier version only raised when q.lastError().type() != QSqlError.NoError, with a comment that
    # exec_() sometimes returns False even when there was no error. Currently every False return raises.
    raise QueryError(q.lastError())


def get_error_str(query_error: 'QueryError') -> str:
    """Return a multi-line string describing the error, one detail per line.

    Args:
        query_error: The query error.

    Returns:
        See description.
    """
    error = query_error.qsql_error
    lines = [
        'Query Error:',
        f'  Type: {error.type()}',
        f'  Driver Text: {error.driverText()}',
        f'  Database Text: {error.databaseText()}',
        f'  Native Code: {error.nativeErrorCode()}',
    ]
    return '\n'.join(lines)


def _nth_query_results(nth, q: 'QSqlQuery') -> tuple[list[int], list[int], list[Any]]:
    """Given an executed query, return the nth result row from each period.

    The query's columns are expected to be (id, PERIOD) optionally followed by a value column. If a period has
    fewer than nth rows, nothing is returned for that period.

    Args:
        nth (int): 1-based instance of query result we are interested in.
        q: A query that has already been executed.

    Returns:
        (tuple): tuple containing:
            - ids (list): Primary key ids.
            - periods (list): Periods.
            - values (list): Values from the third column, or None for each row when the query has no third column.
    """
    ids, periods, values = [], [], []
    # The value column is optional (omitted when no column name is requested), so don't read a missing column.
    has_value_column = q.record().count() > 2
    rows_seen_per_period = {}
    while q.next():  # noqa B305 (`.next()` is not a thing on Python 3. Use the `next()` builtin.)
        period = int(q.value(1))
        rows_seen_per_period[period] = rows_seen_per_period.get(period, 0) + 1
        if rows_seen_per_period[period] != nth:
            continue
        ids.append(int(q.value(0)))
        periods.append(period)
        values.append(q.value(2) if has_value_column else None)
    return ids, periods, values
