"""MappingTableFileWriter class."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import json
import logging
from pathlib import Path

# 2. Third party modules
import numpy as np  # noqa: F401 - Used in type hinting
import pandas as pd

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util
from xms.constraint import Grid
from xms.coverage.xy.xy_util import XySeriesDict
# from xms.datasets.dataset_reader import DatasetReader
from xms.gmi.data.generic_model import GenericModel, Group
from xms.grid.ugrid import UGrid

# 4. Local modules
from xms.gssha.data import bc_generic_model, data_util, mapping_tables
from xms.gssha.data.bc_util import BcData
from xms.gssha.data.sim_generic_model import InfilType
from xms.gssha.file_io import io_util
from xms.gssha.file_io.block import Block
from xms.gssha.file_io.io_util import INT_WIDTH
from xms.gssha.mapping import map_util
from xms.gssha.misc.type_aliases import IntArray

# Constants
MIN_FLOAT_WIDTH = 14  # Minimum column width used when writing the value columns of a mapping table

# Type aliases
MappingTable = str  # A mapping table name, e.g. 'ROUGH_EXP' or 'OVERLAND_BOUNDARY'
# Either the uuid of an index map dataset (tables defined in the model) or an array of index values built here
# (the overland flow bc mask).
DatasetUuidOrValues = 'str | np.ndarray'
IndexMapName = str  # Name written for the index map in the .cmt file and used for its .idx file name
TableDf = pd.DataFrame  # The table rows; columns 0-2 are ID/description columns, the rest are values
# Mapping table name -> (index map data, index map name, table dataframe)
MappingTableData = dict[MappingTable, tuple[DatasetUuidOrValues, IndexMapName, TableDf]]
# (index map data, index map name) pairs to be written as .idx files.
# NOTE(review): 'DatasetReader' is only a forward-reference string; its import is commented out above, so type
# checkers cannot resolve it. It is presumably the type returned by Query.item_with_uuid - confirm.
IndexMaps = list[tuple['DatasetReader | np.ndarray', str]]


def write(
    gssha_file_path: Path, generic_model: GenericModel, co_grid: Grid, ugrid: UGrid, olf_data: BcData,
    xy_dict: XySeriesDict, xy_old_to_new_id: dict[int, int], query: Query
) -> Path | None:
    """Writes the MAPPING_TABLE (.cmt) file that accompanies a .gssha file.

    Thin convenience wrapper: builds a MappingTableFileWriter from the arguments and runs it.

    Args:
        gssha_file_path: Path of the .gssha file. The .cmt file is written next to it under the same stem.
        generic_model: The generic model holding the mapping table and model control parameters.
        co_grid: The constrained grid.
        ugrid: The UGrid of the grid.
        olf_data: Overland flow boundary condition data.
        xy_dict: The xy series dict.
        xy_old_to_new_id: Maps an xy series id stored in a bc to its new id in xy_dict.
        query: Object for communicating with XMS.

    Returns:
        The path of the .cmt file, as returned by MappingTableFileWriter.write().
    """
    return MappingTableFileWriter(
        gssha_file_path=gssha_file_path,
        generic_model=generic_model,
        co_grid=co_grid,
        ugrid=ugrid,
        olf_data=olf_data,
        xy_dict=xy_dict,
        xy_old_to_new_id=xy_old_to_new_id,
        query=query,
    ).write()


class MappingTableFileWriter:
    """Writes the MAPPING_TABLE (.cmt) file."""
    def __init__(
        self, gssha_file_path: Path, generic_model: GenericModel, co_grid: Grid, ugrid: UGrid, olf_data: BcData,
        xy_dict: XySeriesDict, xy_old_to_new_id: dict[int, int], query: Query
    ) -> None:
        """Initializes the class.

        Args:
            gssha_file_path: .gssha file path. The .cmt file uses the same path with a .cmt suffix.
            generic_model: The generic model.
            co_grid: The grid.
            ugrid: The UGrid.
            olf_data: Overland flow bc data.
            xy_dict: xy series dict.
            xy_old_to_new_id: xy series id in bc -> new xy series id in xy_dict
            query (Query): Object for communicating with XMS
        """
        super().__init__()
        self._cmt_file_path: Path = gssha_file_path.with_suffix('.cmt')
        self._gm = generic_model
        self._co_grid = co_grid
        self._ugrid = ugrid
        self._olf_data = olf_data  # overland flow bc data
        self._xy_dict = xy_dict
        self._xy_old_to_new_id = xy_old_to_new_id
        self._query = query

        self._log = logging.getLogger('xms.gssha')
        self._file = None  # The open .cmt file while write() is running
        self._group = self._gm.model_parameters.group('mapping_tables')
        self._mapping_table_data: MappingTableData = {}
        self._index_maps: IndexMaps = []
        self._time_series_index_table: dict[int, str] = {}  # for TIME_SERIES_INDEX table
        self._olf_type_params = bc_generic_model.olf_type_params
        self._variable_olf_types = bc_generic_model.get_olf_types('variable')

    def write(self) -> Path | None:
        """Writes the MAPPING_TABLE (.cmt) file and the index map files.

        Returns:
            The .cmt file path.
        """
        self._log.info('Writing .cmt file...')
        self._collect_mapping_table_data()
        self._collect_index_maps()
        with open(self._cmt_file_path, 'w') as self._file:
            self._file.write('GSSHA_INDEX_MAP_TABLES\n')
            self._write_index_maps()
            self._write_time_series_index_table()
            self._write_tables()

        return self._cmt_file_path

    def _write_index_maps(self) -> None:
        """Writes the list of INDEX_MAP lines and the index map (.idx) files next to the .cmt file."""
        with Block(self._file, self._log, '') as block:
            for data, index_map_name in self._index_maps:
                path = self._cmt_file_path.parent / f'{index_map_name}.idx'
                block.write(group=None, name='INDEX_MAP', value=f'"{path.name}" "{index_map_name}"')
                io_util.write_grass_file(self._co_grid, self._ugrid, data, path, ints=True)

    def _write_time_series_index_table(self) -> None:
        """Writes the TIME_SERIES_INDEX table, if any bc uses an xy series.

        See https://www.gsshawiki.com/Surface_Water_Routing:Overland_Boundary_Conditions#Time_Series_Index_Mapping_Table
        """
        if not self._time_series_index_table:
            return

        with Block(self._file, self._log, '\n') as block:
            block.write(None, 'TIME_SERIES_INDEX')
            block.write(None, f'NUM_IDS {len(self._time_series_index_table)}')
            block.write(None, 'ID    Time series name')
            for bc_id, series_name in self._time_series_index_table.items():
                block.write(None, '', f'{bc_id:<{INT_WIDTH}}"{series_name}"')

    def _write_tables(self) -> None:
        """Writes the mapping tables collected in self._mapping_table_data."""
        variable_olf_param_names = _get_variable_olf_param_names()
        for mapping_table_name, (_data, index_map_name, df) in self._mapping_table_data.items():
            with Block(self._file, self._log, '\n') as block:
                block.write(group=None, name=mapping_table_name, value=f'"{index_map_name}"')
                block.write(group=None, name='NUM_IDS', value=f'{str(len(df))}')
                self._write_column_headers(mapping_table_name, df, block)
                for _index, row in df.iterrows():
                    # Overland boundary rows whose value is an xy series id are written as integers
                    int_format = (
                        mapping_table_name == 'OVERLAND_BOUNDARY'
                        and row[df.columns[1]] in variable_olf_param_names
                    )
                    self._write_table_line(row, df, int_format, block)

    def _write_column_headers(self, mapping_table_name: str, df: pd.DataFrame, block: Block) -> None:
        """Writes the column headers line, starting with 'ID'.

        Args:
            mapping_table_name: Mapping table name.
            df: The dataframe. Columns 3 and beyond are written as value column headers.
            block: The block being written to.
        """
        if mapping_table_name == 'TIME_SERIES_INDEX':
            line = f'{"ID":<6}"Time series name"'
        else:
            line = f'{"ID":<6}{"DESCRIPTION1":<40}{"DESCRIPTION2":<40}'

        for column in df.columns[3:]:
            width = max(len(column), MIN_FLOAT_WIDTH)
            line += f'{column:<{width}}'
        block.write(group=None, name='', value=line)

    def _write_table_line(self, row, df: pd.DataFrame, int_format: bool, block: Block) -> None:
        """Writes a line of the table.

        Args:
            row: A row from the dataframe.
            df: The dataframe.
            int_format: True to format the value columns (columns 3 and beyond) as integers.
            block: The block being written to.
        """
        line = f'{row[df.columns[0]]:<6}{row[df.columns[1]]:<40}{row[df.columns[2]]:<40}'
        for column in df.columns[3:]:
            width = max(len(column), MIN_FLOAT_WIDTH)
            if int_format:
                line += f'{int(row[column]):<{width}}'
            else:
                line += f'{row[column]:<{width}}'
        block.write(group=None, name='', value=line)

    def _model_control_option_on(self, mapping_table_name: str) -> bool:
        """Returns True if the model control option corresponding to the mapping table is on.

        Some tables should only be written if the corresponding option is set in model control. Tables with no
        corresponding option are always written.

        Args:
            mapping_table_name: Mapping table name.

        Returns:
            See description.
        """
        section = self._gm.global_parameters
        if mapping_table_name in {'ROUGH_EXP', 'INTERCEPTION', 'RETENTION', 'AREA_REDUCTION'}:
            # These overland flow options share their name with the mapping table
            return section.group('overland_flow').parameter(mapping_table_name).value
        if mapping_table_name in {'GREEN_AMPT_INFILTRATION', 'GREEN_AMPT_INITIAL_SOIL_MOISTURE'}:
            infil_type = section.group('infiltration').parameter('infil_type').value
            return infil_type in {InfilType.INF_REDIST, InfilType.INF_LAYERED_SOIL}
        if mapping_table_name in {'RICHARDS_EQN_INFILTRATION_BROOKS', 'RICHARDS_EQN_INFILTRATION_HAVERCAMP'}:
            return section.group('infiltration').parameter('infil_type').value == InfilType.INF_RICHARDS
        return True

    def _collect_mapping_table_data(self) -> None:
        """Populates self._mapping_table_data.

        Tables are included only if their model control option is on and an index map dataset is assigned.
        Tables whose dataset can't be found in the project tree are skipped with an error. Overland flow bcs,
        if any, are added as the OVERLAND_BOUNDARY table.
        """
        for mapping_table_name in mapping_tables.names():
            if not self._model_control_option_on(mapping_table_name):
                continue

            dataset_param = mapping_tables.dataset_param(mapping_table_name)
            dataset_uuid = self._group.parameter(dataset_param).value
            if not dataset_uuid:
                continue

            # Get dataset name
            dataset_node = tree_util.find_tree_node_by_uuid(self._query.project_tree, dataset_uuid)
            if dataset_node is None:
                self._log.error(f'Index map dataset not found for "{mapping_table_name}" table. Skipping.')
                continue
            index_map_name = dataset_node.name

            # Get table data into a dataframe
            table_parameter = self._group.parameter(mapping_tables.table_param(mapping_table_name))
            df = table_parameter.table_definition.to_pandas(rows=table_parameter.value)

            self._mapping_table_data[mapping_table_name] = (dataset_uuid, index_map_name, df)

        self._add_overland_flow_bcs()

    def _collect_index_maps(self) -> IndexMaps:
        """Populates self._index_maps with one entry per distinct index map.

        Several mapping tables can reference the same index map dataset. Each dataset is retrieved and listed only
        once, so the INDEX_MAP lines and .idx files are not duplicated.

        Returns:
            self._index_maps.
        """
        seen_uuids: set[str] = set()
        for data, index_map_name, _df in self._mapping_table_data.values():
            if isinstance(data, str):  # data is a dataset uuid
                if data in seen_uuids:
                    continue
                seen_uuids.add(data)
                self._index_maps.append((self._query.item_with_uuid(data), index_map_name))
            else:  # data is a np.ndarray
                self._index_maps.append((data, index_map_name))
        return self._index_maps

    def _add_overland_flow_bcs(self) -> None:
        """Adds the data for overland flow bcs, if any, as the OVERLAND_BOUNDARY table."""
        if not self._olf_data or not self._olf_data.feature_bc:
            return

        bc_ids = self._get_bc_ids()
        mask = self._build_bc_mask(bc_ids)
        df = self._build_bc_dataframe(bc_ids)
        self._mapping_table_data['OVERLAND_BOUNDARY'] = (mask, f'{self._cmt_file_path.stem}_olf_bc', df)

    def _get_bc_ids(self) -> dict[Group, int]:
        """Assigns id numbers to bcs. Bcs with identical values share an id number.

        Ids start at 2 because id 1 is the 'normal' row of the OVERLAND_BOUNDARY table.

        Returns:
            Dict of bc group -> id number.
        """
        bc_ids: dict[Group, int] = {}
        id_for_values: dict[str, int] = {}  # bc values as a JSON string -> id assigned to those values
        for group in self._olf_data.feature_bc.values():
            values_key = self._bc_values_key(group)
            if values_key not in id_for_values:
                id_for_values[values_key] = len(id_for_values) + 2  # + 2 because 'normal' is 1
            bc_ids[group] = id_for_values[values_key]
        return bc_ids

    def _bc_values_key(self, group: Group) -> str:
        """Returns a string that is identical for bcs with identical values.

        The key is built from the overland flow type and its value. For xy series types, the new xy series id is used.

        Args:
            group: The GMI group.

        Returns:
            See description.
        """
        olf_type = group.parameter('overland_flow_type').value
        value_param_name = self._olf_type_params[olf_type]
        value = group.parameter(value_param_name).value
        if olf_type in self._variable_olf_types:  # If it's an xy series, use the new xy series id
            value = self._xy_old_to_new_id[value]
        values_dict = {'overland_flow_type': olf_type, value_param_name: value}
        return json.dumps(values_dict, sort_keys=True)

    def _build_bc_mask(self, bc_ids: dict[Group, int]) -> IntArray:
        """Builds and returns the bc mask (index map) from the on/off cells and the bc arcs and points."""
        mask = data_util.get_on_off_cells(self._co_grid, self._ugrid)

        # Add any arcs
        arc_ix = map_util.intersect_arcs_with_grid(self._ugrid, mask, self._olf_data)
        for cell_idx, arc_info_list in arc_ix.items():
            group = arc_info_list[-1][0]  # -1 = last one if multiple bcs for the same cell_idx
            mask[cell_idx] = bc_ids[group]

        # Add any points
        point_ix = map_util.intersect_points_with_grid(self._co_grid, mask, self._olf_data)
        for cell_idx, group_list in point_ix.items():
            group = group_list[-1]  # -1 = last one if multiple bcs for the same cell_idx
            mask[cell_idx] = bc_ids[group]

        return mask

    def _build_bc_dataframe(self, bc_ids: dict[Group, int]) -> pd.DataFrame:
        """Builds and returns the OVERLAND_BOUNDARY table dataframe, with one row per distinct bc id.

        Also fills self._time_series_index_table for bcs that use an xy series.

        Args:
            bc_ids: Dict of bc group -> id number.

        Returns:
            See description.
        """
        bdy_type = {
            'constant_slope': 1,
            'constant_stage': 2,
            'variable_stage': 3,
            'variable_flow_cms': 4,
            'variable_flow_cfs': 5
        }
        df_dict = {
            'Index value': [1],
            'Description1': ['normal'],
            'Description2': [''],
            'BDY_TYPE': [0],
            'BDY_PARAM': [0.0]
        }
        written_ids: set[int] = set()
        for group, bc_id in bc_ids.items():
            if bc_id in written_ids:  # bcs with identical values share an id and need only one row
                continue
            written_ids.add(bc_id)

            olf_type = group.parameter('overland_flow_type').value
            value_param_name = self._olf_type_params[olf_type]
            df_dict['Index value'].append(bc_id)
            df_dict['Description1'].append(value_param_name)
            df_dict['Description2'].append('')
            df_dict['BDY_TYPE'].append(bdy_type[value_param_name])

            value = group.parameter(value_param_name).value
            if olf_type in self._variable_olf_types:
                new_xy_series_id = self._xy_old_to_new_id[value]
                df_dict['BDY_PARAM'].append(new_xy_series_id)
                self._time_series_index_table[bc_id] = self._xy_dict[new_xy_series_id].name
            else:
                df_dict['BDY_PARAM'].append(value)
        return pd.DataFrame(df_dict)


def _get_variable_olf_param_names() -> list[str]:
    """Returns the parameter names of the 'variable' overland flow types.

    The values of these parameters are xy series ids rather than plain numbers.
    """
    param_for_type = bc_generic_model.olf_type_params
    return [param_for_type[olf_type] for olf_type in bc_generic_model.get_olf_types('variable')]
