"""Database methods."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import csv
import os
from pathlib import Path
import sqlite3
import sys
import traceback
from typing import Sequence

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules

# 4. Local modules
from xms.mf6.data.grid_info import DisEnum, GridInfo
from xms.mf6.file_io import io_util
from xms.mf6.misc import log_util

# Number of fields we add at the end of each inserted row: (period, period_row)
XTRA_FIELD_COUNT = 2
# File types whose 'data' table gets a single CELLIDX column
STANDARD_DATABASE_FTYPES = [
    'CHD6', 'CNC6', 'CTP6', 'DRN6', 'ESL6', 'EVT6', 'GHB6', 'RCH6', 'RIV6', 'SRC6', 'WEL6',
]


def database_filepath(package_filepath: str | Path) -> str:
    """Returns the database filepath."""
    db_filename = os.path.join(os.path.dirname(package_filepath), 'periods.db')
    return db_filename


def filepath_from_package(package) -> str:
    """Returns package.periods_db if it names an existing, non-empty file; otherwise returns ''.

    Args:
        package: Something derived from ListPackageData.

    Returns:
        See description.
    """
    db_filename = package.periods_db
    if not db_filename:
        return ''
    db_path = Path(db_filename)
    if not db_path.exists() or db_path.stat().st_size == 0:
        return ''
    return db_filename


def set_periods_db_filename(package):
    """Sets package.periods_db to the package's periods.db file, if not already set and the file exists.

    Args:
        package: Something derived from ListPackageData.
    """
    if package.periods_db:
        return  # Already set; leave it alone
    candidate = database_filepath(package.filename)
    if os.path.isfile(candidate):
        package.periods_db = candidate


def build(package, fill: bool = True, db_filepath: str | Path | None = None) -> str:
    """Builds the database, unless a non-empty file already exists at the path, and returns the db_filepath.

    Args:
        package: Something derived from ListPackageData.
        fill: If True, the database is built and filled with existing data.
        db_filepath: The database file path. If not given, database_filepath(package.filename) is used and
            package.periods_db is set to it.

    Returns:
        Database filepath.
    """
    # Support other int types (flopy started returning these). (https://stackoverflow.com/a/39106289/5666265)
    sqlite3.register_adapter(np.int64, int)
    sqlite3.register_adapter(np.int32, int)

    if not db_filepath:
        db_filepath = database_filepath(package.filename)
        package.periods_db = db_filepath

    # A non-empty file at the path is returned as-is; an empty or missing one gets built
    if os.path.isfile(db_filepath) and os.stat(db_filepath).st_size != 0:
        return db_filepath

    _build(package, db_filepath, fill)
    return db_filepath


def _build(package, db_filename, fill=True):
    """Builds the database file: creates the 'data' table, optionally fills it, and adds indexes.

    SQLite errors are logged rather than raised. Any other exception is re-raised as a RuntimeError.

    Args:
        package: Something derived from ListPackageData.
        db_filename (str): Path of the database file.
        fill (bool): If True, the database is also filled with existing data.

    Raises:
        RuntimeError: If a non-SQLite error occurs.
    """
    try:
        # Create the database and the data table
        cxn = sqlite3.connect(db_filename)
        try:
            # The connection's context manager commits on success and rolls back on error, but does NOT close
            # the connection, hence the explicit close() in 'finally'.
            with cxn:
                cur = cxn.cursor()
                create_strings, orig_column_count = _create_data_table(package, cur)
                if fill:
                    _insert_existing_data(package, len(create_strings), cur, orig_column_count)

                # Add indexes if CELLIDX column is in the table (e.g. OC doesn't have it)
                if any('CELLIDX' in create_string for create_string in create_strings):
                    indexes_strings = get_create_index_strings(package.ftype, 'data')
                    _add_indexes(indexes_strings, cxn)
                cxn.commit()
        finally:
            cxn.close()
    except sqlite3.Error as er:  # pragma no cover
        log_sqlite_error(er)
    except Exception as error:
        raise RuntimeError(str(error)) from error


def add_tops_table(package, cxn: sqlite3.Connection | None = None) -> None:
    """Adds a 'tops' table to the database containing the top elevations, and indexes its CELLIDX column.

    Args:
        package: Something derived from ListPackageData.
        cxn: If provided, an existing connection; it is not committed here. Otherwise, a local connection to
            the package's database is opened, committed and closed. If the package has no non-empty database
            file, nothing is done.

    Raises:
        RuntimeError: If a non-SQLite error occurs while adding the table. (SQLite errors are logged.)
    """
    # Build the create string
    create_strings = _get_cellid_create_strings(package)
    the_rest = ['TOP REAL NOT NULL DEFAULT 0.0', 'CELLIDX INTEGER NOT NULL DEFAULT 0', 'id INTEGER PRIMARY KEY']
    create_strings.extend(the_rest)
    create_string = ', '.join(create_strings)
    field_count = len(create_strings)

    def _add_to(connection: sqlite3.Connection) -> None:
        """Creates and fills the 'tops' table, then adds its indexes, using the given connection."""
        _add_tops_table_sql(create_string, field_count, package, connection)
        _add_indexes(_get_cellidx_create_index_strings(package.ftype, 'tops'), connection)

    try:
        if cxn is not None:
            _add_to(cxn)
            return

        if not (db_filename := filepath_from_package(package)):
            return
        local_cxn = sqlite3.connect(db_filename)
        try:
            with local_cxn:  # Commits on success, rolls back on error; does not close
                _add_to(local_cxn)
                local_cxn.commit()
        finally:
            local_cxn.close()
    except sqlite3.Error as er:  # pragma no cover
        log_sqlite_error(er)
    except Exception as error:
        raise RuntimeError(str(error)) from error


def _add_tops_table_sql(create_string: str, field_count: int, package, cxn) -> None:
    """Runs the SQL that creates the 'tops' table and inserts the top elevations into it.

    Args:
        create_string: Column definitions placed inside CREATE TABLE tops(...).
        field_count: Number of fields (columns) in the table.
        package: Something derived from ListPackageData.
        cxn: Sqlite connection.
    """
    cxn.execute(f'CREATE TABLE tops({create_string})')
    tops_cursor = cxn.cursor()
    _insert_tops(package, field_count, tops_cursor)


def get_create_strings(package, use_aux=True) -> tuple[list[str], list[str], int]:
    """Returns the column strings used in the SQL CREATE TABLE statement.

    Also returns the original number of columns in the file based on the package class. The package's own
    columns are followed by CELLIDX (file types in STANDARD_DATABASE_FTYPES) or CELLIDX1/CELLIDX2 (HFB6),
    then PERIOD, PERIOD_ROW and the 'id' primary key.

    Args:
        package: Something derived from ListPackageData.
        use_aux: True to include the auxiliary columns.

    Returns:
        (tuple): tuple containing:
            - create_strings (list(str)): Create strings, like ['"LAY" INTEGER NOT NULL DEFAULT 0', ...]
            - column_names (list(str)): Names of the table columns
            - orig_column_count (int): Number of columns for the package before we add the extras.

    Raises:
        AssertionError: If a column type is not np.int32, np.float64 or object.
    """
    column_names, column_types, defaults = package.get_column_info('', use_aux=use_aux)
    orig_column_count = len(column_names)
    creation_tuples = []
    for name in column_names:
        column_type = column_types[name]
        if column_type == np.int32:
            creation_tuples.append((name, f'INTEGER NOT NULL DEFAULT {defaults[name]}'))
        elif column_type == np.float64:
            creation_tuples.append((name, f'REAL NOT NULL DEFAULT {defaults[name]}'))
        elif column_type == object:
            # Double any embedded single quotes so the default stays a valid SQL string literal
            text_default = f'{defaults[name]}'.replace("'", "''")
            creation_tuples.append((name, f"TEXT NOT NULL DEFAULT '{text_default}'"))
        else:
            raise AssertionError(f'Unsupported type {column_type!r} for column "{name}".')

    # Add extra columns
    if package.ftype in STANDARD_DATABASE_FTYPES:
        creation_tuples.append(('CELLIDX', 'INTEGER NOT NULL DEFAULT 0'))
    elif package.ftype == 'HFB6':
        creation_tuples.append(('CELLIDX1', 'INTEGER NOT NULL DEFAULT 0'))
        creation_tuples.append(('CELLIDX2', 'INTEGER NOT NULL DEFAULT 0'))
    creation_tuples.extend(
        [
            ('PERIOD', 'INTEGER NOT NULL DEFAULT 1'), ('PERIOD_ROW', 'INTEGER NOT NULL DEFAULT 0'),
            ('id', 'INTEGER PRIMARY KEY')
        ]
    )

    column_names = [t[0] for t in creation_tuples]
    create_strings = [f'"{t[0]}" {t[1]}' for t in creation_tuples]

    return create_strings, column_names, orig_column_count


def get_create_index_strings(ftype: str, table: str) -> list[str]:
    """Returns all CREATE INDEX statements for the table.

    The CELLIDX index statements come first, followed by the (PERIOD, CELLIDX) ones.

    Args:
        ftype: The file type.
        table: Name of the table

    Returns:
        See description.
    """
    cellidx_strings = _get_cellidx_create_index_strings(ftype, table)
    period_strings = _get_period_cellidx_create_index_strings(ftype, table)
    return cellidx_strings + period_strings


def append_period_data(package, period: int, rows):
    """Appends the period data to the database, building the database first if needed.

    New rows get PERIOD_ROW values that continue after the rows already stored for the period.

    Args:
        package: Something derived from ListPackageData.
        period: Stress period.
        rows: List of lists consisting of rows of data.

    Raises:
        RuntimeError: If a non-SQLite error occurs. (SQLite errors are logged.)
    """
    try:
        build(package, fill=False)  # Builds one only if one doesn't already exist
        cxn = sqlite3.connect(package.periods_db)
        try:
            with cxn:  # Commits on success, rolls back on error; does not close
                cur = cxn.cursor()

                # Insert data
                create_strings, _, orig_column_count = get_create_strings(package)
                stmt = setup_insert_statement(len(create_strings), 'data')
                row_count = _period_row_count(period, cur)
                _insert_rows(package, cur, stmt, rows, orig_column_count, row_count, period)

                cxn.commit()
        finally:
            cxn.close()
    except sqlite3.Error as er:  # pragma no cover
        log_sqlite_error(er)
    except Exception as error:
        raise RuntimeError(str(error)) from error


def cell_idx_from_mfcellid(grid_info: GridInfo, id_values: Sequence, hfb: bool):
    """Converts the MODFLOW CELLID fields of a row into a CELLIDX.

    When hfb is True, the second set of CELLID fields is used. That set starts right after the first set, so
    the starting position is 3 for DIS, 2 for DISV and 1 otherwise.

    Args:
        grid_info: Number of rows, cols etc.
        id_values: List or tuple of values containing the MODFLOW cell ID info.
        hfb: True to use the second set of cell ID fields (HFB package).

    Returns:
        The CELLIDX.
    """
    dis_enum = grid_info.dis_enum
    if dis_enum == DisEnum.DIS:
        start = 3 if hfb else 0
        return grid_info.cell_index_from_lay_row_col(
            int(id_values[start]), int(id_values[start + 1]), int(id_values[start + 2])
        )
    if dis_enum == DisEnum.DISV:
        start = 2 if hfb else 0
        return grid_info.cell_index_from_lay_cell2d(int(id_values[start]), int(id_values[start + 1]))
    # Any other discretization: a single value identifies the cell; subtract 1 to make it 0-based
    start = 1 if hfb else 0
    return int(id_values[start]) - 1


def log_sqlite_error(er: sqlite3.Error) -> None:
    """Logs an sqlite error: its message, its exception class and its formatted traceback.

    Args:
        er: The error
    """
    # See https://stackoverflow.com/questions/25371636/how-to-get-sqlite-result-error-codes-in-python
    # str() each arg so a non-string arg can't make join() raise TypeError while we're reporting an error
    error = ' '.join(str(arg) for arg in er.args)
    exception_class = er.__class__
    # Format from the exception object itself (not sys.exc_info()) so this also works outside an 'except' block
    traceback_str = ''.join(traceback.format_exception(type(er), er, er.__traceback__))
    log = log_util.get_logger()
    log.error(f'SQLite error: {error}')
    log.error(f'Exception class is: {exception_class}')
    log.error(f'SQLite traceback: {traceback_str}')


def _add_indexes(create_index_strings: list[str], cxn: sqlite3.Connection) -> None:
    """Adds the indexes to the database for fast querying by executing each statement, in order.

    Does not commit; that is left to the caller.

    Args:
        create_index_strings: CREATE INDEX statements, e.g. from get_create_index_strings().
        cxn: Sqlite connection.
    """
    for statement in create_index_strings:
        cxn.execute(statement)


def setup_insert_statement(field_count: int, table: str):
    """Sets up the SQL insert statement for a table whose last column is the 'id' INTEGER PRIMARY KEY.

    Every column except the last gets a '?' placeholder. The last gets NULL so SQLite assigns the id.

    Args:
        field_count: number of fields, including the id field
        table: Name of the table

    Returns:
        (str): The statement, e.g. 'INSERT INTO data VALUES (?, ?, NULL);' for field_count 3.
    """
    # Build a list (rather than joining the characters of '???') so field_count == 1 still yields valid SQL
    values = ['?'] * max(field_count - 1, 0)  # - 1 to not include id field
    values.append('NULL')
    value_list = ', '.join(values)
    return f'INSERT INTO {table} VALUES ({value_list});'


def _get_cellid_create_strings(package) -> list[str]:
    """Builds the CREATE TABLE column definitions for the package's cell id columns.

    Args:
        package: Something derived from ListPackageData.

    Returns:
        One definition per id column, e.g. 'LAY INTEGER NOT NULL DEFAULT 1'.

    Raises:
        AssertionError: If an id column is not of type np.int32.
    """
    names, types_by_name, defaults = package.get_id_column_info()
    strings = []
    for column in names:
        if not (types_by_name[column] == np.int32):
            raise AssertionError()
        strings.append(f'{column} INTEGER NOT NULL DEFAULT {defaults[column]}')
    return strings


def get_cellidx_index_info(ftype: str, table: str) -> tuple[list[str], list[str]]:
    """Returns the index names and the columns for the CELLIDX indexes.

    HFB6 has two cell index columns (CELLIDX1, CELLIDX2). Every other file type has one (CELLIDX).

    Example 1: (['index_data_CELLIDX'], ['CELLIDX'])
    Example 2: (['index_data_CELLIDX1', 'index_data_CELLIDX2'], ['CELLIDX1', 'CELLIDX2'])

    Args:
        ftype: The file type.
        table: Name of the table

    Returns:
        See description.
    """
    if ftype == 'HFB6':
        columns = ['CELLIDX1', 'CELLIDX2']
    else:
        columns = ['CELLIDX']
    return [f'index_{table}_{name}' for name in columns], columns


def _get_cellidx_create_index_strings(ftype: str, table: str) -> list[str]:
    """Returns one CREATE INDEX statement per CELLIDX column of the table.

    Example 1: ['CREATE INDEX index_data_CELLIDX ON data(CELLIDX)']
    Example 2: ['CREATE INDEX index_data_CELLIDX1 ON data(CELLIDX1)',
                'CREATE INDEX index_data_CELLIDX2 ON data(CELLIDX2)']

    Args:
        ftype: The file type.
        table: Name of the table

    Returns:
        See description.
    """
    # get_cellidx_index_info returns (index_names, columns); zip(*...) pairs them up
    return [
        f'CREATE INDEX {name} ON {table}({col})' for name, col in zip(*get_cellidx_index_info(ftype, table))
    ]


def _get_period_cellidx_create_index_strings(ftype: str, table: str) -> list[str]:
    """Returns the CREATE INDEX statements for the composite (PERIOD, CELLIDX) indexes.

    HFB6 gets one index per CELLIDX1/CELLIDX2 column. File types in STANDARD_DATABASE_FTYPES get one on
    CELLIDX. Any other file type gets none.

    Args:
        ftype: The file type.
        table: Name of the table

    Returns:
        See description.
    """
    if ftype == 'HFB6':
        columns = ['CELLIDX1', 'CELLIDX2']
    elif ftype in STANDARD_DATABASE_FTYPES:
        columns = ['CELLIDX']
    else:
        columns = []
    return [f'CREATE INDEX index_{table}_PERIOD_{col} ON {table}(PERIOD, {col})' for col in columns]


def _pad_fields(fields: list[str], question_count: int):
    """Adds empty strings, in place, if the number of fields is short.

    SFR PERIOD blocks can have a variable number of columns. XTRA_FIELD_COUNT placeholders are left for the
    extra fields appended afterwards, so fields is padded to question_count - XTRA_FIELD_COUNT items.

    Args:
        fields: List of strings supplied to the stmt.
        question_count: Number of '?' in the stmt.
    """
    missing = question_count - len(fields) - XTRA_FIELD_COUNT
    if missing > 0:
        fields.extend([''] * missing)


def _period_row_count(period, cur):
    """Returns the number of rows already in the 'data' table for the given stress period.

    append_period_data uses this as the PERIOD_ROW value for the first newly appended row.

    Args:
        period (int): Stress period.
        cur: An open cursor.

    Returns:
        (int): See description.
    """
    cur.execute('SELECT COUNT(*) FROM data WHERE PERIOD = ?', (period, ))
    (count, ) = cur.fetchone()
    return count


def append_cell_idx(ftype: str, grid_info: GridInfo, fields) -> None:
    """Appends the cell index (CELLIDX) computed from the row's MODFLOW CELLID fields to the list of fields.

    File types in STANDARD_DATABASE_FTYPES get one value. HFB6 gets two: one from the first set of CELLID
    fields and one from the second. Other file types get nothing appended.

    Args:
        ftype: The package file type, e.g. 'WEL6' or 'HFB6'.
        grid_info: The GridInfo object.
        fields (list): List of fields. Modified in place.
    """
    if ftype in STANDARD_DATABASE_FTYPES:
        fields.append(cell_idx_from_mfcellid(grid_info, fields, False))
    elif ftype == 'HFB6':
        fields.append(cell_idx_from_mfcellid(grid_info, fields, False))
        fields.append(cell_idx_from_mfcellid(grid_info, fields, True))


def _insert_rows(package, cur, stmt, rows, orig_column_count, period_row_start, period):
    """Inserts the rows.

    Each row is cut down to the package's own columns. Then the CELLIDX value(s) are appended, the row is
    padded with '' if short, and PERIOD and PERIOD_ROW are appended. If inserting a row raises an SQLite
    error, the error is logged and the remaining rows are still inserted.

    Args:
        package: Something derived from ListPackageData.
        cur (sqlite3.Cursor): Database cursor.
        stmt (str): SQL insert statement
        rows (iterable): The rows (lists or tuples of field values)
        orig_column_count (int): Number of columns for the package before we add the extras.
        period_row_start (int): PERIOD_ROW value given to the first row.
        period (int): Stress period
    """
    question_count = stmt.count('?')
    grid_info = None  # Fetched once, on the first row, instead of once per row
    for row_idx, row in enumerate(rows):
        if grid_info is None:
            grid_info = package.grid_info()
        # Don't allow extra spaces at end to become extra fields. list() gives a mutable copy even for tuples.
        fields = list(row[:orig_column_count])
        append_cell_idx(package.ftype, grid_info, fields)
        _pad_fields(fields, question_count)
        period_row = row_idx + period_row_start
        fields.extend([period, period_row])  # If number of items in the list changes, change XTRA_FIELD_COUNT
        try:
            cur.execute(stmt, tuple(fields))
        except sqlite3.Error as er:  # pragma no cover
            log_sqlite_error(er)


def _insert_tops_rows(package, cur: sqlite3.Cursor, stmt: str, tops: list[float]):
    """Inserts one row per cell: the cell's MODFLOW CELLID values, its top elevation, then its cell index.

    Args:
        package: Something derived from ListPackageData.
        cur: Database cursor.
        stmt: SQL insert statement
        tops: The top elevations, one per cell index.
    """
    cellid_of = package.grid_info().modflow_cellid_from_cell_index
    for idx, top_elevation in enumerate(tops):
        row_values = (*cellid_of(idx), top_elevation, idx)
        cur.execute(stmt, row_values)


def _insert_tops(package, field_count: int, cur: sqlite3.Cursor) -> None:
    """Inserts the model's top elevations, taken from its DIS package, into the 'tops' table.

    Args:
        package: Something derived from ListPackageData.
        field_count: Number of fields in the 'tops' table.
        cur: sqlite3 Cursor.
    """
    insert_stmt = setup_insert_statement(field_count, 'tops')
    _insert_tops_rows(package, cur, insert_stmt, package.model.get_dis().get_tops())


def _create_data_table(package, cur: sqlite3.Cursor):
    """Creates the 'data' table using the column definitions from get_create_strings().

    Args:
        package: Something derived from ListPackageData.
        cur: Cursor

    Returns:
        (tuple): tuple containing:
            - create_strings (list(str)): Create strings, like ['"LAY" INTEGER NOT NULL DEFAULT 0', ...]
            - orig_column_count (int): Number of columns for the package before we add the extras.
    """
    create_strings, _, orig_column_count = get_create_strings(package)
    create_string = ', '.join(create_strings)
    cur.execute(f'CREATE TABLE data ({create_string})')
    return create_strings, orig_column_count


def _insert_existing_data(package, field_count: int, cur, orig_column_count):
    """Inserts the data in external files into the 'data' table.

    Each entry of package.period_files maps a stress period to its file. Entries without a filename are
    skipped. Each file's rows start at PERIOD_ROW 0.

    Args:
        package: Something derived from ListPackageData.
        field_count: number of fields
        cur: Cursor
        orig_column_count (int): Number of columns for the package before we add the extras.
    """
    stmt = setup_insert_statement(field_count, 'data')
    for period, filename in package.period_files.items():
        if not filename:
            continue

        # Add this period file to the database. The csv module requires files to be opened with newline=''
        # so that it, rather than the text layer, handles line endings.
        with open(filename, newline='') as file:
            reader = csv.reader(file, delimiter=io_util.mfsep, quotechar="'")
            _insert_rows(package, cur, stmt, reader, orig_column_count, 0, period)
