"""Code for reading a mapping table (.cmt) file."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"

# 1. Standard Python modules
from pathlib import Path
import shlex

# 2. Third party modules
import pandas as pd

# 3. Aquaveo modules
from xms.constraint import Grid
from xms.core.filesystem import filesystem
from xms.datasets.dataset_writer import DatasetWriter
from xms.gmi.data.generic_model import Group

# 4. Local modules
from xms.gssha.data import mapping_tables
from xms.gssha.file_io import io_util

# Constants
# Table rows are sliced at these fixed character offsets by the row parsing code in this module.
ID_WIDTH = 6  # The integer ID is parsed from the first ID_WIDTH characters of a row
DESC1_END = 46  # The first description is the slice from ID_WIDTH up to (not including) DESC1_END
DESC2_END = 86  # The second description is the slice from DESC1_END up to DESC2_END; float values follow it


def read(
    file_path: Path, co_grid: Grid, group: Group, datasets: list[DatasetWriter]
) -> dict[str, tuple[str, pd.DataFrame]]:
    """Reads MAPPING_TABLE (.cmt) file and returns dict of TIME_SERIES_INDEX and OVERLAND_BOUNDARY tables.

    Besides returning the two tables, reading the file sets the mapping table parameters on ``group`` and appends
    the datasets created from the index maps to ``datasets``.

    Args:
        file_path: File path.
        co_grid: The grid.
        group: The generic model group.
        datasets: List of datasets which will be appended to.

    Returns:
        Dict of table card ('TIME_SERIES_INDEX', 'OVERLAND_BOUNDARY') -> tuple of (index map name, table dataframe).
        A card that does not appear in the file has no entry in the dict.
    """
    reader = MappingTableFileReader(file_path, co_grid, group, datasets)
    return reader.read()


class MappingTableFileReader:
    """Reads a mapping table (.cmt) file.

    The file is loaded into memory and walked with a line cursor. INDEX_MAP cards are turned into datasets, mapping
    table cards set parameters on the generic model group, and the TIME_SERIES_INDEX and OVERLAND_BOUNDARY tables
    are returned to the caller as dataframes.
    """
    def __init__(self, file_path: Path, co_grid: Grid, group: Group, datasets: list[DatasetWriter]) -> None:
        """Initializer.

        Args:
            file_path: Path of the mapping table (.cmt) file.
            co_grid: The grid, passed along when index maps are read into datasets.
            group: The generic model group whose dataset and table parameters get set.
            datasets: List that the datasets created from index maps are appended to.
        """
        self._file_path: Path = file_path
        self._co_grid: Grid = co_grid
        self._group: Group = group
        self._datasets: list[DatasetWriter] = datasets

        self._line_idx: int = 0  # Current line index
        self._lines: list[str] = []  # Lines from the file

    def read(self) -> dict[str, tuple[str, pd.DataFrame]]:
        """Reads MAPPING_TABLE (.cmt) file; returns dict of TIME_SERIES_INDEX and OVERLAND_BOUNDARY tables.

        Returns:
            Dict of table card -> tuple of (index map name, table dataframe).

        Raises:
            KeyError: If a mapping table refers to an index map that no preceding INDEX_MAP card defined.
        """
        dataset_name_uuids: dict[str, str] = {}  # dataset name -> dataset uuid
        mapping_table_names = mapping_tables.names()
        recognized_table_names = {*mapping_table_names, 'TIME_SERIES_INDEX', 'OVERLAND_BOUNDARY'}
        tables_dict: dict[str, tuple[str, pd.DataFrame]] = {}
        with self._file_path.open('r') as file:
            # Read the whole file because it won't be huge and it makes things easier
            self._lines = file.read().splitlines()  # This strips away '\n' on each line
        self._line_idx = 0  # Always start at the top so that calling read() again re-reads the file

        while self._line_idx < len(self._lines):
            line = self._next_line()
            if not line:  # Skip blank lines
                continue

            card, card_value = io_util.get_card_and_value(line)
            if card == 'INDEX_MAP':
                self._read_index_map(card_value, dataset_name_uuids)
            elif card in mapping_table_names:
                index_map_name = card_value.strip('"')
                if not index_map_name:  # empty table. skip to next table
                    self._skip_to_next_table(recognized_table_names)
                    continue
                if index_map_name not in dataset_name_uuids:
                    raise KeyError(
                        f'Mapping table "{card}" uses index map "{index_map_name}" but no INDEX_MAP card with that '
                        'name was found before it.'
                    )
                self._read_mapping_table(card, dataset_name_uuids[index_map_name])
            elif card == 'TIME_SERIES_INDEX':
                index_map_name = card_value.strip('"')
                table_values = self._read_time_series_table()
                tables_dict[card] = (index_map_name, _dataframe_from_rows(card, table_values))
            elif card == 'OVERLAND_BOUNDARY':
                index_map_name = card_value.strip('"')
                table_values = self._read_table()
                tables_dict[card] = (index_map_name, _dataframe_from_rows(card, table_values))
            # Any other card is ignored
        return tables_dict

    def _skip_to_next_table(self, recognized_table_names: set[str]) -> None:
        """Skips lines until we get to the next table that we support.

        The current line index is left on the line holding the next supported card, so the next call to
        ``_next_line()`` returns that line. If there is no such line, the index ends up at the end of the file.
        Being at the end of the file already is fine (an empty table can be the last thing in the file).

        Args:
            recognized_table_names: Set of table names that we support.
        """
        while self._line_idx < len(self._lines):
            line = self._lines[self._line_idx]  # Peek, so there is no need to back up afterwards
            if line.strip():  # A blank line can't hold a card
                card, _card_value = io_util.get_card_and_value(line)
                # Also stop at INDEX_MAP so that an index map following an empty table is not skipped
                if card == 'INDEX_MAP' or card in recognized_table_names:
                    return
            self._line_idx += 1

    def _read_mapping_table(self, mapping_table_name: str, dataset_uuid: str) -> None:
        """Reads the mapping table and stores it, and its dataset, in the group parameters.

        Args:
            mapping_table_name: Mapping table name.
            dataset_uuid: uuid of the dataset created from the index map.
        """
        # Set the dataset parameter
        dataset_param_name = mapping_tables.dataset_param(mapping_table_name)
        self._group.parameter(dataset_param_name).value = dataset_uuid

        # Read the table and set the table parameter
        table_values = self._read_table()
        table_param_name = mapping_tables.table_param(mapping_table_name)
        self._group.parameter(table_param_name).value = table_values
        self._group.parameter(table_param_name).table_definition = mapping_tables.table_def(mapping_table_name)

    def _read_index_map(self, card_value: str, dataset_name_uuids: dict[str, str]) -> None:
        """Reads the index map into a dataset, and remembers the dataset uuid under the index map name.

        Args:
            card_value: Everything in the line to the right of the card. Must split (shell-like, so quoted strings
                stay together) into exactly a path and a name, or a ValueError is raised.
            dataset_name_uuids: Dict of index map (dataset) name -> dataset uuid. Gets added to.
        """
        path, name = shlex.split(card_value)
        full_path = filesystem.resolve_relative_path(self._file_path.parent, path)
        dataset = io_util.read_grass_file_to_dataset(full_path, name, self._co_grid)
        dataset_name_uuids[name] = dataset.uuid
        self._datasets.append(dataset)

    def _next_line(self) -> str:
        """Returns the next line, advancing the current line index.

        Raises:
            IndexError: If already at the end of the file.
        """
        self._line_idx += 1
        return self._lines[self._line_idx - 1]

    def _read_row_lines(self) -> list[str]:
        """Advances to the next NUM_IDS card and returns the row lines of the table that follows it.

        Returns:
            The unparsed row lines, or an empty list if the end of the file is reached without finding NUM_IDS.
        """
        while self._line_idx < len(self._lines):
            card, card_value = io_util.get_card_and_value(self._next_line())
            if card == 'NUM_IDS':
                num_ids = int(card_value)
                self._next_line()  # throw away the ID, DESCRIPTION1, DESCRIPTION2... line
                return [self._next_line() for _ in range(num_ids)]
        return []

    @staticmethod
    def _parse_table_row(line: str) -> list:
        """Returns a table row as [id, description1, description2, value1, value2...] (there may be no values)."""
        index_id = int(line[:ID_WIDTH].strip())
        desc1 = line[ID_WIDTH:DESC1_END].strip()
        desc2 = line[DESC1_END:DESC2_END].strip()
        # Everything after the descriptions is whitespace separated floats
        return [index_id, desc1, desc2, *map(float, line[DESC2_END:].split())]

    @staticmethod
    def _parse_time_series_row(line: str) -> list:
        """Returns a time series table row as [id, name], with spaces and double quotes stripped from the name."""
        series_id = int(line[:ID_WIDTH].strip())
        series_name = line[ID_WIDTH:].strip(' "')
        return [series_id, series_name]

    def _read_table(self) -> list[list]:
        """Reads a table from the file at the current position and returns the table values as a 2d list."""
        return [self._parse_table_row(line) for line in self._read_row_lines()]

    def _read_time_series_table(self) -> list[list]:
        """Reads a time series table from the file at the current position and returns the table values as a 2d list."""
        return [self._parse_time_series_row(line) for line in self._read_row_lines()]


def _dataframe_from_rows(mapping_table_name: str, table_values: list[list]) -> pd.DataFrame:
    """Builds a dataframe out of table rows, using the table definition that belongs to the named table.

    Args:
        mapping_table_name: Name of the table, used to look up its table definition.
        table_values: The table values, one list per row.

    Returns:
        The dataframe made by the table definition from the rows.
    """
    return mapping_tables.table_def(mapping_table_name).to_pandas(rows=table_values)
