"""MappingTableCreator class."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import csv
from dataclasses import dataclass
import logging
from pathlib import Path
from typing import Sequence

# 2. Third party modules
import numpy as np
import pandas as pd

# 3. Aquaveo modules
from xms.api.tree import TreeNode
from xms.constraint import Grid
from xms.datasets.dataset_writer import DatasetWriter
from xms.gdal.rasters import RasterInput
from xms.gdal.utilities import gdal_utils as gu
from xms.gdal.vectors import VectorInput
from xms.grid.ugrid import UGrid
from xms.tool.algorithms.geometry import geometry
from xms.tool_core.table_definition import TableDefinition

# 4. Local modules
from xms.gssha.data import data_util
from xms.gssha.data import mapping_tables
from xms.gssha.gui import gui_util
from xms.gssha.mapping.map_util import Pt3d
from xms.gssha.misc.type_aliases import IntArray, Pt3dArray
from xms.gssha.tools import tool_util

# Constants
# Headers of the three leading mapping table columns. _build_table() uses them as keys when it fills in the
# index value and the two description columns.
COL1_IDX = 'Index value'  # First table column
COL2_DESC1 = 'Description1'  # Second table column
COL3_DESC2 = 'Description2'  # Third table column

# Soil table column headers
# Column of ToolInputs.soil_table holding the soil shapefile paths.
SOIL_COL_FILE = 'Shapefile'

# Land use table column headers
# Columns of ToolInputs.land_use_table; each row describes one land use file.
LU_COL_TYPE = 'File Type'  # Compared against SHAPEFILE; any other value is handled as a raster
LU_COL_FILE = 'File'  # Path of the shapefile or raster
LU_COL_CSV_FILE = 'CSV File'  # Path of the CSV file relating land use codes to land use names
LU_COL_LU_CODE = 'LUCODE Field'  # Name of the shapefile field with the land use codes (read for shapefile rows only)
LU_COL_CHOICES = 'Field Choices'  # Not referenced by the code in this module; presumably used by the tool GUI

# File types
# Values expected in the LU_COL_TYPE column.
SHAPEFILE = 'shapefile'
RASTER = 'raster'  # Not compared against in this module; non-shapefile rows fall through to the raster path

# Type aliases
PolyPts = Sequence[Pt3d]  # One loop
PolygonPts = Sequence[PolyPts]  # Outer polygon followed by inner polygons (if any)
ComboIds = dict[tuple[int, int], int]  # (soil id, land use id) -> combined id
IdLuCodes = dict[int, tuple[int, str]]  # index map id -> (lu code, lu code name)


@dataclass
class ToolInputs:
    """Tool input values.

    Plain container consumed by ``run()`` / ``MappingTableCreator``. Every field has a default, so an instance can
    be created empty and filled in afterwards.
    """
    # Name handed to mapping_tables.table_def() to look up the table definition
    mapping_table_name: str = ''
    # Grid passed (with ugrid) to data_util.get_on_off_cells(); its uuid becomes the dataset's geom_uuid
    co_grid: 'Grid | None' = None
    # UGrid the cell centers are computed from
    ugrid: 'UGrid | None' = None
    # Tree node used to find a unique child name for the new index map dataset
    ugrid_node: 'TreeNode | None' = None
    # True to build (part of) the index map from the soil shapefiles in soil_table
    use_soil_data: bool = False
    # Table with a 'Shapefile' column (SOIL_COL_FILE) of soil shapefile paths
    soil_table: 'pd.DataFrame | None' = None
    # True to build (part of) the index map from the files in land_use_table
    use_land_use_data: bool = False
    # Table with one row per land use file (see the LU_COL_* column headers)
    land_use_table: 'pd.DataFrame | None' = None
    # Requested name of the index map dataset (made unique before use)
    index_map_name: str = ''
    # Display projection well known text; shapefile points are transformed to it when both projections are known
    wkt: str = ''
    # Vertical units from tool; passed through to tool_util.raster_values_at_cells()
    vertical_units: str = ''


@dataclass
class RunResults:
    """The results.

    Filled in step by step by ``MappingTableCreator.run()``; every field is None until its step has run.
    """
    # Dataset writer holding the index map as a single time step at time 0.0
    index_map_dataset: 'DatasetWriter | None' = None
    # Table definition looked up from ToolInputs.mapping_table_name
    table_definition: 'TableDefinition | None' = None
    # Table values, one row per index map id, with a 1-based row index
    df: 'pd.DataFrame | None' = None


def run(inputs: ToolInputs, logger: logging.Logger) -> RunResults:
    """Creates an index map and mapping table from a UGrid, shapefile, and optionally a raster and csv file.

    Convenience wrapper: builds a MappingTableCreator and runs it.

    Args:
        inputs: The tool input values.
        logger: The logger that progress messages are written to.

    Returns:
        The results produced by MappingTableCreator.run().
    """
    return MappingTableCreator(inputs, logger).run()


class MappingTableCreator:
    """Creates an index map and mapping table from a UGrid, shapefile, and optionally a raster and csv file."""
    def __init__(self, inputs: ToolInputs, logger: logging.Logger) -> None:
        """Initializes the class.

        Only stores the inputs, looks up the table definition, and sets up empty working state. All processing
        happens in ``run()``.

        Args:
            inputs: The tool input values.
            logger: The logger.
        """
        self._inputs = inputs
        self._logger = logger

        self._cell_centers: 'Pt3dArray | None' = None
        self._on_off_cells: 'IntArray | None' = None
        self._table_def = mapping_tables.table_def(self._inputs.mapping_table_name)
        self._column_shape_att = {
            'HYDR_COND': 'ksat',
            'HYD_COND': 'ksat',
            'FIELD_CAPACITY': 'fieldcap',
            'WILT_POINT': 'wiltingpt',
            'SOIL_MOIST': 'moisture',
            'SOIL_MOISTURE': 'moisture'
        }  # Column header -> shapefile attribute

        # Values for table columns past the descriptions (4 to n): column header -> {index map id -> value}.
        # Assigned here (not just annotated) so the attribute always exists; _init_columns_4n() adds the headers.
        self._columns_4n: dict[str, dict[int, float]] = {}
        self._index_map_ids: 'list[int] | ComboIds | None' = None
        self._index_map: 'IntArray | None' = None

        # Soil data
        self._texture_ids: dict[str, int] = {}  # texture name -> index map id
        self._next_texture_id: int = 1

        # Land use data
        self._lu_code_ids: dict[int, int] = {}  # land use code -> index map id
        self._next_lu_code_id: int = 1
        self._id_lu_codes: IdLuCodes = {}  # index map id -> (lu code, lu code name)

        self._results: RunResults = RunResults()

    def run(self) -> RunResults:
        """Creates an index map and mapping table from a UGrid, shapefile, and optionally a raster and csv file.

        Returns:
            The results object, filled in by the individual steps.
        """
        # Order matters: each step uses state prepared by the ones before it.
        steps = (
            self._init_columns_4n,
            self._compute_cell_centers,
            self._get_on_off_cells,
            self._build_index_map,
            self._make_dataset_from_index_map,
            self._build_table,
        )
        for step in steps:
            step()
        self._logger.info('Finished.')
        return self._results

    def _init_columns_4n(self) -> None:
        """Initializes the data structure for the columns after the description columns."""
        self._columns_4n = {column.header: {} for column in self._table_def.column_types[3:]}

    def _compute_cell_centers(self) -> None:
        """Computes the cell centers of the input UGrid and stores them for the later steps."""
        self._logger.info('Computing cell centers...')
        ugrid = self._inputs.ugrid
        self._cell_centers = tool_util.compute_cell_centers(ugrid)

    def _get_on_off_cells(self) -> None:
        """Gets the on/off cells array for the input grid and stores it for the later steps."""
        self._logger.info('Getting model on/off cells array...')
        inputs = self._inputs
        self._on_off_cells = data_util.get_on_off_cells(inputs.co_grid, inputs.ugrid)

    def _build_index_map(self) -> None:
        """Builds the index map."""
        self._logger.info('Building the index map...')
        soil_index_map = None
        land_use_index_map = None
        if self._inputs.use_soil_data:
            soil_index_map = self._index_map_from_soil_shapefiles()
        if self._inputs.use_land_use_data:
            land_use_index_map = self._index_map_from_land_use_data()

        if self._inputs.use_soil_data and self._inputs.use_land_use_data:
            self._index_map, self._index_map_ids = self._combine_index_maps(soil_index_map, land_use_index_map)
        elif soil_index_map is not None:
            self._index_map = soil_index_map
            self._index_map_ids = _non_zero_ids_sorted(self._index_map)
        elif land_use_index_map is not None:
            self._index_map = land_use_index_map
            self._index_map_ids = _non_zero_ids_sorted(self._index_map)

    def _make_dataset_from_index_map(self) -> None:
        """Writes the index map to a new cell dataset and stores the writer in the results.

        The dataset gets a name that is unique among the children of the UGrid tree node and a single time step at
        time 0.0.
        """
        self._logger.info(f'Creating dataset "{self._inputs.index_map_name}"...')
        inputs = self._inputs
        unique_name = gui_util.find_unique_child_name(inputs.ugrid_node, inputs.index_map_name)
        writer = DatasetWriter(name=unique_name, geom_uuid=inputs.co_grid.uuid, location='cells')
        writer.append_timestep(0., self._index_map)
        writer.appending_finished()
        self._results.index_map_dataset = writer

    def _build_table(self) -> None:
        """Builds the table and stores it, with the table definition, in the results.

        The table has one row per index map id. Every column starts out with the column default; the index value,
        the description columns (when the matching data is in use) and the value columns are then filled in.
        """
        self._logger.info('Building the table...')

        ids = self._index_map_ids
        is_combo = isinstance(ids, dict)  # Combined soil / land use ids are a dict, otherwise a list
        column_types = self._table_def.column_types

        # Start with the default value in every row of every column
        row_count = len(ids)
        table: dict = {}
        for column_type in column_types:
            table[column_type.header] = [column_type.default] * row_count

        # 1st column - index value
        if is_combo:
            table[COL1_IDX] = list(ids.values())
        else:
            table[COL1_IDX] = ids

        # 2nd column - description 1 (soil)
        if self._inputs.use_soil_data:
            table[COL2_DESC1] = _description1_list(ids, self._texture_ids)

        # 3rd column - description 2 (land use)
        if self._inputs.use_land_use_data:
            table[COL3_DESC2] = _description2_list(ids, self._id_lu_codes)

        # The rest of the columns. Saved values are looked up by soil id for combined ids, else by index id.
        lookup_ids = [soil_id for soil_id, _ in ids] if is_combo else ids
        for column_type in column_types[3:]:
            saved_values = self._columns_4n.get(column_type.header)
            if saved_values is None:
                continue
            table[column_type.header] = [saved_values.get(lookup_id, 0.0) for lookup_id in lookup_ids]

        frame = pd.DataFrame(table)
        frame.index += 1  # Rows are numbered from 1
        self._results.table_definition = self._table_def
        self._results.df = frame

    def _index_map_from_soil_shapefiles(self) -> IntArray:
        """Returns the index map array created from the polygons in the soil shapefiles.

        Shapefiles are processed in table order. Where they overlap, the later file wins. Files that could not be
        used are skipped.
        """
        self._logger.info('Creating index map from shapefiles...')

        merged = np.zeros(len(self._cell_centers), dtype=int)
        for path in self._inputs.soil_table[SOIL_COL_FILE]:
            file_map = self._index_map_from_shapefile(path)
            if file_map is None:
                continue  # Problem with this file; already logged
            merged = np.where(file_map > 0, file_map, merged)
        return merged

    def _index_map_from_shapefile(
        self, shapefile_path: str, lu_code_field: str = '', csv_file_path: str = ''
    ) -> 'IntArray | None':
        """Returns the index map array created from the polygons in the shapefile.

        Handles both soil shapefiles (``lu_code_field`` empty, ids come from the TEXTURE field) and land use
        shapefiles (ids come from ``lu_code_field``). Where polygons overlap, the later polygon wins.

        Values for table columns 4 to n are saved per index map id. Soil texture ids and land use ids are numbered
        independently from 1, so when soil data is in use only the soil shapefiles may save values. Otherwise land
        use polygons would overwrite the soil values stored under the same id.

        Args:
            shapefile_path: File path.
            lu_code_field: If doing land use, name of the field with the land use codes.
            csv_file_path: If doing land use, file path of CSV file relating land use codes to land use names

        Returns:
            Index map (IntArray) or, if problems, None
        """
        self._logger.info(f'Reading {shapefile_path} ...')

        doing_soil = lu_code_field == ''  # True if doing soil data, False if doing land use

        # Open file
        try:
            vi = VectorInput(shapefile_path)
        except ValueError:
            self._logger.warning(f'Could not open {shapefile_path} as a shapefile. Skipped.')
            return None

        # Make sure it's a polygon shapefile
        if not tool_util.is_polygon_shapefile(vi):
            self._logger.warning(f'"{shapefile_path}" is not a polygon shapefile. Skipped.')
            return None

        # Make sure it has the field we need
        field_names = {field.name for field in vi.get_fields()}
        field = 'TEXTURE' if doing_soil else lu_code_field
        if field not in field_names:
            self._logger.warning(f'"{shapefile_path}" does not have "{field}" field. Skipped.')
            return None

        # Read the CSV file if doing land use
        csv_lu_code_names: dict[int, str] = {}
        if not doing_soil:
            csv_lu_code_names = _read_csv_file_or_log_error(csv_file_path, self._logger)
            if not csv_lu_code_names:
                return None

        # See the docstring: land use polygons must not overwrite values saved for soil ids
        save_column_values = doing_soil or not self._inputs.use_soil_data

        # Loop through polygons
        attributes = vi.get_attributes()
        self._logger.info('Finding cells in polygons...')
        index_map = np.zeros(len(self._cell_centers), dtype=int)
        polygons = vi.get_poly_features()
        coord_trans = None
        if vi.wkt and self._inputs.wkt:
            coord_trans = gu.get_coordinate_transformation(vi.wkt, self._inputs.wkt)
        for i, polygon in enumerate(polygons):

            # Get the cells in the polygon
            in_poly = _get_cells_in_polygon(self._cell_centers, self._on_off_cells, polygon, coord_trans)
            if not np.any(in_poly):  # No cells in polygon. We can skip it
                continue

            # Get the index id for this polygon
            if doing_soil:
                texture = attributes['texture'][i]
                index_id = self._index_id_from_texture(texture)
            else:
                lu_code = attributes[lu_code_field.lower()][i]
                index_id = self._index_id_from_lu_code(lu_code)
                self._id_lu_codes[index_id] = lu_code, csv_lu_code_names.get(lu_code, '')

            # Get values for other table columns, if available
            if save_column_values and index_id != 0:
                self._save_columns_4n_values(attributes, i, index_id)

            cell_poly_ids = index_id * in_poly  # Create array of ints from array of bool
            index_map = np.where(cell_poly_ids > 0, cell_poly_ids, index_map)
        return index_map

    def _save_columns_4n_values(self, attributes: dict, shape_idx: int, index_id: int) -> None:
        """Collect values for columns 4 to n if the shapefile includes the data.

        Args:
            attributes: The shapefile attributes.
            shape_idx: Index of the polygon.
            index_id: The index map id.
        """
        for column_header in self._columns_4n.keys():
            field_name = self._column_shape_att.get(column_header)
            if field_name and field_name in attributes:
                self._columns_4n[column_header][index_id] = attributes[field_name][shape_idx]  # last one wins
            else:
                self._columns_4n[column_header][index_id] = 0.0

    def _index_id_from_texture(self, texture: str):
        """Returns the index map ID for the polygon, based on the texture.

        A texture seen before keeps its id; a new texture is given the next unused texture id.

        Args:
            texture: The texture value from the shapefile.

        Returns:
            The index map id of the texture.
        """
        assigned_id, following_id = _index_id_from_value(texture, self._texture_ids, self._next_texture_id)
        self._next_texture_id = following_id
        return assigned_id

    def _index_id_from_lu_code(self, lu_code: int) -> int:
        """Returns the index map ID for the polygon, based on the land use code.

        A land use code seen before keeps its id; a new code is given the next unused land use id.

        Args:
            lu_code: The land use code.

        Returns:
            The index map id of the land use code.
        """
        assigned_id, following_id = _index_id_from_value(lu_code, self._lu_code_ids, self._next_lu_code_id)
        self._next_lu_code_id = following_id
        return assigned_id

    def _index_map_from_land_use_data(self) -> IntArray:
        """Creates and returns an index map based on the land use inputs.

        Each row of the land use table is one file: rows whose file type is the shapefile type are read as
        shapefiles, all other rows as rasters. Files that could not be used are skipped. Last file wins if the
        files overlap. The ids of the finished map are renumbered so they have no gaps.
        """
        merged = np.zeros(len(self._cell_centers), dtype=int)
        for _, row in self._inputs.land_use_table.iterrows():
            if row[LU_COL_TYPE] == SHAPEFILE:
                lu_code_field = row[LU_COL_LU_CODE]
                file_map = self._index_map_from_shapefile(row[LU_COL_FILE], lu_code_field, row[LU_COL_CSV_FILE])
            else:
                file_map = self._index_map_from_raster(row[LU_COL_FILE], row[LU_COL_CSV_FILE])
            if file_map is None:
                continue  # Problem with this file; already logged
            merged = np.where(file_map > 0, file_map, merged)
        self._renumber_land_use(merged)
        return merged

    def _renumber_land_use(self, index_map: IntArray) -> None:
        """Renumbers the land use index map so index ids start from 1 and don't have gaps.

        Gaps in numbering can occur if there are multiple land use files which overlap. The last file governs in the
        overlapped area.
        """
        # Fix index map
        index_map_ids = _non_zero_ids_sorted(index_map)
        old_to_new_ids = {old: i + 1 for i, old in enumerate(index_map_ids)}
        for i in range(len(index_map)):
            if index_map[i] != 0:
                index_map[i] = old_to_new_ids[index_map[i]]

        # Fix self._lu_code_ids
        new_lu_code_ids: dict[int, int] = {}
        for lu_code, index_id in self._lu_code_ids.items():
            if index_id in old_to_new_ids:
                new_lu_code_ids[lu_code] = old_to_new_ids[index_id]
        self._lu_code_ids = new_lu_code_ids

        # Fix self._id_lu_codes
        new_id_lu_codes: IdLuCodes = {}
        for index_id, lu_code_tuple in self._id_lu_codes.items():
            if index_id in old_to_new_ids:
                new_id_lu_codes[old_to_new_ids[index_id]] = lu_code_tuple
        self._id_lu_codes = new_id_lu_codes

        self._next_lu_code_id = len(old_to_new_ids) + 1  # Probably unneeded since we're done with this but just in case

    def _index_map_from_raster(self, raster_path: str, csv_file_path: str) -> 'IntArray | None':
        """Creates and returns an index map by getting the pixel that each cell center is in.

        The raster values are treated as land use codes. Only codes greater than 0 get an index map id; all other
        cells stay 0. The CSV file is validated the same way as for land use shapefiles: if it is missing or holds
        no usable data, an error is logged and the raster is skipped.

        Args:
            raster_path: File path of the raster.
            csv_file_path: File path of CSV file relating land use codes to land use names.

        Returns:
            Index map (IntArray) or, if problems, None
        """
        self._logger.info(f'Creating index map from raster "{raster_path}" ...')

        try:
            raster = RasterInput(raster_path)
        except ValueError:
            self._logger.error(f'Could not open {raster_path} as a raster. Skipped.')
            return None

        # Check the CSV file before sampling the raster so a bad CSV path fails fast instead of raising
        csv_lu_code_names = _read_csv_file_or_log_error(csv_file_path, self._logger)
        if not csv_lu_code_names:
            return None

        lu_code_map = tool_util.raster_values_at_cells(
            self._cell_centers, self._on_off_cells, raster, self._inputs.vertical_units
        )

        # Create index map from lu code map, and save index id -> lu code, lu name.
        # Ids are handed out in cell order, so this stays a loop over the cells.
        index_map = np.zeros(len(lu_code_map), dtype=int)
        for cell_idx, lu_code in enumerate(lu_code_map):
            if lu_code > 0:
                index_id = self._index_id_from_lu_code(lu_code)
                index_map[cell_idx] = index_id
                self._id_lu_codes[index_id] = lu_code, csv_lu_code_names.get(lu_code, '')

        return index_map

    def _combine_index_maps(self, soil_index_map: IntArray, land_use_index_map: IntArray) -> tuple[IntArray, ComboIds]:
        """Combines the shapefile and raster index maps, and returns a new index map."""
        self._logger.info('Combining shapefile and raster index maps...')

        index_map = np.zeros(len(soil_index_map), dtype=int)
        combo_ids: ComboIds = {}
        next_id = 1  # The current index map id
        for cell_idx, (soil_id, land_use_id) in enumerate(zip(soil_index_map, land_use_index_map)):
            if self._on_off_cells[cell_idx]:
                index_id = combo_ids.get((soil_id, land_use_id), None)
                if index_id is None:
                    index_id, combo_ids[(soil_id, land_use_id)] = next_id, next_id
                    next_id += 1
                index_map[cell_idx] = index_id
            else:
                index_map[cell_idx] = 0
        return index_map, combo_ids


def _get_cells_in_polygon(
    cell_centers: Pt3dArray, on_off_cells: IntArray, polygon: PolygonPts, coord_trans
) -> IntArray:
    """Returns an array of flags indicating which cells are in the polygon.

    A cell is in the polygon if its center is in the outer loop and not in a hole.

    Args:
        cell_centers: The cell centers.
        on_off_cells: The on/off cells array, passed through to the point-in-polygon test.
        polygon: Outer loop followed by the inner loops (holes), if any.
        coord_trans: Coordinate transformation applied to the polygon points, or None to use them as they are.

    Returns:
        See description.
    """
    # Do outer polygon
    outer_polygon = gu.transform_points(polygon[0], coord_trans) if coord_trans else polygon[0]
    in_poly = geometry.run_parallel_points_in_polygon(cell_centers, np.asarray(outer_polygon), on_off_cells)

    # np.any() checks the flags without a Python-level loop over every cell, and copes with an empty array
    if not np.any(in_poly):  # Nothing in the polygon. We can skip checking holes
        return in_poly

    # Do holes
    for inner_poly in polygon[1:]:
        if coord_trans:
            inner_poly = gu.transform_points(inner_poly, coord_trans)
        in_hole = geometry.run_parallel_points_in_polygon(cell_centers, np.asarray(inner_poly), on_off_cells)
        in_poly = np.logical_xor(in_poly, in_hole)  # Cells in a hole are flipped back to "not in polygon"
    return in_poly


def _non_zero_ids_sorted(index_map: IntArray) -> list[int]:
    """Returns a list of the unique, non-zero ids, sorted.

    Args:
        index_map: The index map (any array-like of integer ids).

    Returns:
        The ids in ascending order, as plain Python ints. Empty if the index map has no non-zero ids.
    """
    unique_ids = np.unique(np.asarray(index_map))  # Sorted and unique
    return unique_ids[unique_ids != 0].tolist()


def _description1_list(index_map_ids: list[int] | ComboIds, texture_ids: dict[str, int]) -> list[str]:
    """Returns the list of strings for the Description1 column, which is for soil data.

    Args:
        index_map_ids: Index map ids, or the dict of (soil id, land use id) -> combined id.
        texture_ids: Texture name -> index map id.

    Returns:
        The texture name for each id (for combined ids, for the soil id), or '' if the id has no texture.
    """
    # Invert the lookup: index map id -> texture name
    texture_by_id = {}
    for texture, index_id in texture_ids.items():
        texture_by_id[index_id] = texture

    if isinstance(index_map_ids, dict):
        soil_ids = [soil_id for soil_id, _land_use_id in index_map_ids]
    else:
        soil_ids = index_map_ids
    return [texture_by_id.get(soil_id, '') for soil_id in soil_ids]


def _description2_list(index_map_ids: list[int] | ComboIds, id_lu_codes: dict[int, tuple[int, str]]) -> list[str]:
    """Returns the list of strings for the Description2 column, which is for land use.

    Args:
        index_map_ids: Index map ids, or the dict of (soil id, land use id) -> combined id.
        id_lu_codes: Index map id -> (lu code, lu code name).

    Returns:
        One description per id (for combined ids, for the land use id): the land use code, followed by the land
        use name when there is one. Ids without a land use code are described with code -1.
    """
    if isinstance(index_map_ids, dict):
        land_use_ids = [land_use_id for _, land_use_id in index_map_ids]
    else:
        land_use_ids = index_map_ids

    descriptions = []
    for land_use_id in land_use_ids:
        lu_code, lu_name = id_lu_codes.get(land_use_id, (-1, ''))
        if lu_name:
            descriptions.append(f'Land ID #{lu_code}, {lu_name}')
        else:
            descriptions.append(f'Land ID #{lu_code}')
    return descriptions


def _read_csv_file(csv_file_path: str) -> dict[int, str]:
    """Reads the csv file and returns a dict of lu code -> lu name.

    Each usable row has an integer land use code in the first column and the land use name in the second. Rows
    that do not fit (blank lines, rows with fewer than two columns, and rows such as a header whose first column is
    not an integer) are skipped instead of aborting the whole read. If a code occurs more than once, the last
    row wins.

    Args:
        csv_file_path: Path to .csv file.

    Returns:
        See description. Empty if the file has no usable rows.

    Raises:
        OSError: If the file cannot be opened.
    """
    lu_codes: dict[int, str] = {}
    with open(csv_file_path, newline='') as csv_file:
        for words in csv.reader(csv_file, skipinitialspace=True):
            if len(words) < 2:
                continue  # Blank line or no name column
            try:
                lu_code = int(words[0])
            except ValueError:
                continue  # Header row or other non-integer code
            lu_codes[lu_code] = words[1]
    return lu_codes


def _read_csv_file_or_log_error(csv_file_path: str, logger: logging.Logger) -> dict[int, str]:
    """Reads the CSV file and returns the dict of lu codes -> names, or {} if it fails.

    Every failure is reported through the logger instead of being raised: no path given, file does not exist, file
    cannot be read or parsed, or file holds no land use data.

    Args:
        csv_file_path: Path to .csv file.
        logger: The logger.

    Returns:
        See description.
    """
    if not csv_file_path:
        logger.error('CSV path not specified.')
        return {}
    if not Path(csv_file_path).exists():
        logger.error(f'CSV file "{csv_file_path}" does not exist.')
        return {}

    try:
        csv_lu_code_names = _read_csv_file(csv_file_path)
    except (OSError, ValueError, IndexError, csv.Error) as error:
        # Unreadable file (OSError), undecodable text or non-integer code (ValueError), short row (IndexError)
        logger.error(f'Could not read CSV file "{csv_file_path}": {error}')
        return {}

    if not csv_lu_code_names:
        logger.error(f'No land use data could be read from "{csv_file_path}"')
        return {}
    return csv_lu_code_names


def _index_id_from_value(value: str | int, value_ids: dict[str, int] | dict[int, int], next_id: int) -> tuple[int, int]:
    """Given the value (texture or lu code), returns the index map id and the next index map id.

    Args:
        value: value from the shapefile or raster.

    Returns:
        See description.
    """
    index_id = value_ids.setdefault(value, next_id)
    if index_id == next_id:
        next_id += 1
    return index_id, next_id
