"""Miscellaneous functions dealing with file i/o."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import Sequence

# 2. Third party modules
from geopandas import GeoDataFrame
import numpy as np

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util, TreeNode
from xms.constraint import Grid, Orientation
from xms.core.filesystem import filesystem
from xms.datasets.dataset_reader import DatasetReader
from xms.datasets.dataset_writer import DatasetWriter
from xms.grid.ugrid import UGrid

# 4. Local modules
from xms.gssha.data import data_util
from xms.gssha.data.bc_util import get_stream_data
from xms.gssha.mapping import map_util
from xms.gssha.mapping.map_util import ArcIx
from xms.gssha.misc.type_aliases import IntArray

# Constants
INT_WIDTH = 6  # Width (in characters) of integer number fields

# Type aliases
DictIntXy = dict[int, tuple[Sequence[float], Sequence[float]]]  # int -> xy values, presumably (x sequence, y sequence)


def get_coverage_pointers(sim_node: TreeNode, coverage_type: str) -> list[TreeNode]:
    """Returns all GSSHA coverage pointer nodes ('TI_COVER_PTR') found beneath the simulation node.

    Args:
        sim_node: The simulation tree node to search under.
        coverage_type: Coverage type to filter on. A falsy value (e.g. '') is passed on as None,
            presumably meaning "any coverage type".

    Returns:
        (list[TreeNode]): The matching coverage pointer nodes.
    """
    type_filter = coverage_type or None
    return tree_util.descendants_of_type(
        tree_root=sim_node,
        xms_types=['TI_COVER_PTR'],
        model_name='GSSHA',
        coverage_type=type_filter,
        allow_pointers=True,
        only_first=False,
    )


def create_masked_array(dataset: DatasetReader, on_off_cells: IntArray) -> 'np.ndarray':
    """Wraps the dataset's first time step in a numpy masked array that hides inactive cells.

    A cell is masked when its ``on_off_cells`` entry is falsy, i.e. the on/off array is inverted to form the mask.
    See https://stackoverflow.com/questions/16724669/numpy-inverse-mask

    Args:
        dataset: The dataset; only ``dataset.values[0]`` is read.
        on_off_cells: Model on/off cells (truthy = active).

    Returns:
        A ``numpy.ma`` masked array of the first time step's values.
    """
    inactive = np.logical_not(on_off_cells)
    first_step_values = dataset.values[0]
    return np.ma.array(first_step_values, mask=inactive)


def write_grass_file(co_grid: Grid, ugrid: UGrid, data, path: Path, ints: bool) -> None:
    """Writes a GRASS grid file.

    See https://www.gsshawiki.com/File_Formats:GRASS_files

    The file contains six header lines (north, south, east, west, rows, cols) followed by one text line per grid
    row. Every value is followed by a single space. Inactive cells (per ``data_util.get_on_off_cells``) are
    written as ``0`` (``ints``) or ``0.0``.

    Args:
        co_grid: The grid. Its orientation must be (y_decrease, x_increase).
        ugrid: The UGrid.
        data: The dataset values (indexable by cell index), or a DatasetReader whose first time step is written.
        path: Path of file to write (a str is also accepted).
        ints: If True, values are rounded to integers.

    Raises:
        ValueError: If the grid orientation is not supported.
    """
    if co_grid.orientation[0] != Orientation.y_decrease or co_grid.orientation[1] != Orientation.x_increase:
        raise ValueError('Unsupported grid orientation')

    on_off_cells = data_util.get_on_off_cells(co_grid, ugrid)
    values = data.values[0] if isinstance(data, DatasetReader) else data

    origin = co_grid.origin
    north = origin[1] + co_grid.locations_y[-1]
    south = origin[1]
    east = origin[0] + co_grid.locations_x[-1]
    west = origin[0]
    rows = len(co_grid.locations_y) - 1
    cols = len(co_grid.locations_x) - 1
    inactive_token = '0 ' if ints else '0.0 '
    with Path(path).open('w') as file:
        file.write(f'north: {north}\n')
        file.write(f'south: {south}\n')
        file.write(f'east: {east}\n')
        file.write(f'west: {west}\n')
        file.write(f'rows: {rows}\n')
        file.write(f'cols: {cols}\n')
        for ugrid_i in range(1, rows + 1):
            # Build the whole row in memory and write it once, instead of one write call per cell.
            tokens = []
            for ugrid_j in range(1, cols + 1):
                # (i, j) appear to be 1-based here - NOTE(review): confirm against Grid.get_cell_index_from_ij
                cell_idx = co_grid.get_cell_index_from_ij(ugrid_i, ugrid_j)
                if not on_off_cells[cell_idx]:
                    tokens.append(inactive_token)
                elif ints:
                    tokens.append(f'{int(round(values[cell_idx]))} ')
                else:
                    tokens.append(f'{values[cell_idx]} ')
            tokens.append('\n')
            file.write(''.join(tokens))


@dataclass
class GrassData:
    """Header fields and cell values read from a GRASS grid file.

    Attributes:
        north: Value of the 'north:' header line.
        south: Value of the 'south:' header line.
        east: Value of the 'east:' header line.
        west: Value of the 'west:' header line.
        rows: Number of grid rows ('rows:' header line).
        cols: Number of grid columns ('cols:' header line).
        values: Flat float array of rows * cols cell values as produced by ``read_grass_file``,
            or None when no values have been read.
    """
    north: float = 0.0
    south: float = 0.0
    east: float = 0.0
    west: float = 0.0
    rows: int = 0
    cols: int = 0
    values: 'np.ndarray | None' = None  # Default is None, so the annotation must allow it


def read_grass_file(file_path: str | Path) -> GrassData:
    """Reads a GRASS grid file and returns the data.

    The file must start with six "key: value" header lines in the order north, south, east, west, rows, cols.
    Only the second token of each header line is used. These are followed by rows * cols whitespace-separated
    values; line breaks between values are not significant.

    Args:
        file_path: Path to the GRASS file.

    Returns:
        The header fields plus a flat float array of the rows * cols values.

    Raises:
        FileNotFoundError: If the file does not exist or is empty.
        RuntimeError: If the file cannot be parsed (the underlying error is chained as the cause).
    """
    data = GrassData()
    file_path = Path(file_path)
    if not file_path.is_file() or file_path.stat().st_size == 0:
        raise FileNotFoundError(f'Error: file "{file_path}" does not exist or is empty.')

    try:
        with file_path.open('r') as file:
            # Read the header (fixed order; the key token is not checked)
            data.north = float(file.readline().split()[1])
            data.south = float(file.readline().split()[1])
            data.east = float(file.readline().split()[1])
            data.west = float(file.readline().split()[1])
            data.rows = int(file.readline().split()[1])
            data.cols = int(file.readline().split()[1])

            # Read the values as one flat token stream. Splitting per line would give a ragged array (and fail)
            # on blank lines or rows wrapped across several lines.
            values = np.array(file.read().split(), dtype=float)
            n_values = data.rows * data.cols
            if values.size != n_values:
                raise ValueError(f'Expected {n_values} values but found {values.size}')
            data.values = values
        return data
    except Exception as err:
        raise RuntimeError(f'Error reading file "{str(file_path)}"') from err


def write_mask_file(co_grid: Grid, ugrid: UGrid, file_path: Path) -> Path:
    """Writes the model on/off cells as an integer GRASS grid (the mask file).

    Args:
        co_grid: The grid.
        ugrid: The UGrid.
        file_path: Where to write the mask file.

    Returns:
        The same ``file_path`` that was passed in.
    """
    write_grass_file(
        co_grid=co_grid,
        ugrid=ugrid,
        data=data_util.get_on_off_cells(co_grid, ugrid),
        path=file_path,
        ints=True,
    )
    return file_path


def get_stream_mask(query: Query, co_grid: Grid, ugrid: UGrid, coverage: GeoDataFrame) -> IntArray:
    """Builds a per-cell mask with 1 for every active-grid cell crossed by a stream arc of the coverage.

    Args:
        query: The XMS query, passed to ``get_stream_data``.
        co_grid: The grid.
        ugrid: The UGrid.
        coverage: The stream coverage.

    Returns:
        An int array with 1 for each intersected cell and 0 elsewhere.
    """
    streams = get_stream_data(query, coverage)
    intersections = map_util.intersect_arcs_with_grid(ugrid, data_util.get_on_off_cells(co_grid, ugrid), streams)
    return get_arcs_mask(ugrid, intersections)


def get_arcs_mask(ugrid: UGrid, arc_ix: ArcIx) -> IntArray:
    """Returns a mask with 1 for every cell intersected by an arc and 0 for all other cells.

    Args:
        ugrid: The UGrid; only its cell count is used, to size the mask.
        arc_ix: Arc/grid intersections keyed by cell index. Only the keys are used.

    Returns:
        An int array of length ``ugrid.cell_count``.
    """
    mask = np.zeros(ugrid.cell_count, dtype=int)
    for cell_idx in arc_ix:  # The intersection details (values) are not needed, only the cell indices
        mask[cell_idx] = 1
    return mask


def get_time_string(start_date_time: datetime, value: float) -> str:
    """Formats start_date_time shifted by a relative offset in minutes as 'YYYY MM DD HH MM'.

    Seconds and microseconds of the shifted time are not part of the output.

    Args:
        start_date_time: The starting date/time.
        value: The relative time offset, in minutes (may be fractional or negative).

    Returns:
        The shifted date/time as a string.
    """
    shifted = start_date_time + timedelta(minutes=value)
    return shifted.strftime('%Y %m %d %H %M')


def datetime_from_string(string: str) -> datetime | None:
    """Reads the date/time from the string and returns a datetime, or None if it couldn't.

    Args:
        string: The string.

    Returns:
        See description.
    """
    try:
        return datetime.strptime(string, '%Y %m %d %H %M')
    except ValueError:
        return None


def datetime_from_cards(start_date: str, start_time: str) -> datetime | None:
    """Reads START_DATE and/or START_TIME and returns a datetime object.

    Args:
        start_date: START_DATE card value from .gssha file (or '' if not present).
        start_time: START_TIME card value from .gssha file (or '' if not present).
    """
    dt_start_date_time: datetime | None = None
    if start_date:
        dt_start_date_time = datetime.strptime(start_date, '%Y %m %d')
        if start_time:
            dt_start_time = datetime.strptime(start_time, '%H %M')
            dt_start_date_time = datetime.combine(dt_start_date_time, dt_start_time.time())
    return dt_start_date_time


def relative_time_from_datetime(date_time: datetime, ref_time: datetime) -> float:
    """Returns how many minutes date_time is after ref_time (negative if it is before).

    Args:
        date_time: The date/time.
        ref_time: Reference time.

    Returns:
        The elapsed time in minutes, as a float.
    """
    elapsed = date_time - ref_time
    return elapsed.total_seconds() / 60.0


def dataset_from_grass_data(
    name: str, geom_uuid: str, grass_data: GrassData, on_off_cells: IntArray
) -> DatasetWriter:
    """Creates a single-time-step (time 0.0) cell dataset from GRASS file data.

    Args:
        name: Name to give to the dataset.
        geom_uuid: UUID of the geometry the dataset belongs to.
        grass_data: Data read from a GRASS file; its ``values`` become the time step's values.
        on_off_cells: Cell activity, passed as the time step's activity (``use_activity_as_null=True``).

    Returns:
        The finished DatasetWriter.
    """
    writer = DatasetWriter(name=name, geom_uuid=geom_uuid, location='cells', use_activity_as_null=True)
    writer.append_timestep(0.0, grass_data.values, on_off_cells)
    writer.appending_finished()
    return writer


def read_grass_file_to_dataset(file_path: str | Path, name: str, co_grid: Grid) -> DatasetWriter:
    """Reads a GRASS file and converts its values into a cell dataset on co_grid.

    The flat value order of the file is used directly as the cell order.
    NOTE(review): assumes this matches co_grid's cell indexing - confirm.

    Args:
        file_path: Path to the grass file containing the dataset values.
        name: Name to give to the dataset.
        co_grid: The grid; supplies the geometry UUID and the model on/off cells.

    Returns:
        The dataset writer holding the values.
    """
    grass_data = read_grass_file(file_path)
    activity = co_grid.model_on_off_cells
    return dataset_from_grass_data(name, co_grid.uuid, grass_data, activity)


def get_card_and_value(line: str) -> tuple[str, str]:
    """Given a line from the file, returns a tuple of: card - the first word, value - everything else.

    Only the card is made upper case. The value keeps its original case (it may be a file path) and has
    surrounding whitespace removed. The card and value may be separated by any whitespace (spaces or tabs).
    Leading whitespace and a trailing carriage return on the line are ignored.

    Args:
        line: A line from a file.

    Returns:
        (card, card_value). Both are '' for a blank line; card_value is '' when the line holds only a card.
    """
    parts = line.split(maxsplit=1)  # Split on the first run of whitespace: card, everything else
    if not parts:  # Blank or whitespace-only line
        return '', ''
    card = parts[0].upper()
    card_value = parts[1].strip() if len(parts) > 1 else ''
    return card, card_value


def read_gssha_file_to_dict(gssha_file_path: Path) -> dict[str, str]:
    """Reads the .gssha file into a dict mapping card -> card value, and returns the dict.

    Every line is split with ``get_card_and_value``. When a card appears more than once, the last occurrence wins.
    A line whose card comes back empty (e.g. a blank line) is stored under the key ''.

    Args:
        gssha_file_path: Path to the .gssha project file.

    Returns:
        See description.
    """
    with gssha_file_path.open('r') as file:
        return dict(map(get_card_and_value, file))


def get_full_path(gssha_file_path: Path, project_path: str | None, file_path: str) -> Path:
    """Resolves a file path from the .gssha file using ``filesystem.resolve_relative_path``.

    Surrounding double quotes are stripped from ``file_path``, and from ``project_path`` when it is used. If both
    ``project_path`` and ``file_path`` are non-empty, ``file_path`` is resolved against ``project_path``.
    Otherwise it is resolved against the directory containing the .gssha file.

    Args:
        gssha_file_path: File path to .gssha file.
        project_path: PROJECT_PATH card value from .gssha file, or '' (or None) if it isn't present.
        file_path: A file path from the .gssha file, which may be a full path or a relative path.

    Returns:
        The resolved path to file_path.
    """
    unquoted_file = file_path.strip('"')
    use_project_dir = bool(project_path and file_path)
    base_dir = project_path.strip('"') if use_project_dir else gssha_file_path.parent
    return Path(filesystem.resolve_relative_path(base_dir, unquoted_file))
