"""ShapefileExporter class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
from pathlib import Path
from typing import Any
import warnings

# 2. Third party modules
import numpy as np
import pandas as pd
import shapefile  # From pyshp

# 3. Aquaveo modules
from xms.coverage.polygons import polygon_orienteer
from xms.data_objects.parameters import FilterLocation

# 4. Local modules
from xms.mf6.mapping import map_util
from xms.mf6.misc import log_util


class ShapefileExporter:
    """Exports the features of a coverage, with their attributes, to shapefiles.

    One shapefile is written per feature type ('points', 'arcs', 'polygons') for which the coverage has
    features AND both the attribute table file and its '.def' table definition file exist.
    """
    def __init__(self, ftype, coverage, coverage_att_files):
        """Initializes the class.

        Args:
            ftype (str): The file type used in the GWF name file (e.g. 'WEL6')
            coverage (xms.data_objects.parameters.Coverage): The coverage
            coverage_att_files (dict): dict with the attribute table file paths, keyed by feature type
                ('points', 'arcs', or 'polygons').
        """
        self._ftype = ftype
        self._coverage = coverage
        self._coverage_att_files = coverage_att_files

        self._float_decimals = 16  # Number of decimal places used for float fields (pyshp 'decimal')
        self._fields_dict = {}  # Dict of bad field names and their good counterparts
        self._shapefile_names = {}  # Shapefile base names (att file path minus extension), keyed by feature type

    def export(self):
        """Writes a shapefile for each feature type that has features and attribute files.

        Returns:
            (dict) Dict of shapefile base names (no extension), keyed by feature type.
        """
        logger = log_util.get_logger()
        logger.info('Exporting shapefile')
        self._init_field_dict()
        self._shapefiles_from_coverage()
        return self._shapefile_names

    def _init_field_dict(self):
        """Initializes the dict that contains bad field names and their good counterparts.

        Shapefile field names must be 10 characters max, and they should match standard MODFLOW names.
        """
        self._fields_dict = map_util.standard_fields()

    def _create_fields(self, table_def, shapefile_writer):
        """Creates the fields in the shapefile based on the table definition.

        Column types map to dBase field types: 'int' -> 'N' (0 decimals), 'double'/'float' -> 'N'
        (self._float_decimals decimals), 'string' -> 'C', 'bool' -> 'L'.

        Args:
            table_def (dict): Table definition dict.
            shapefile_writer (shapefile.Writer): Shapefile writer.

        Raises:
            RuntimeError: If a column has a type not listed above.
        """
        for column in table_def['columns']:
            name = column['name']
            column_type = column['type']
            if column_type == 'int':
                shapefile_writer.field(name, 'N', decimal=0)
            elif column_type in ['double', 'float']:
                shapefile_writer.field(name, 'N', decimal=self._float_decimals)
            elif column_type == 'string':
                shapefile_writer.field(name, 'C')
            elif column_type == 'bool':
                shapefile_writer.field(name, 'L')
            else:
                raise RuntimeError('Unknown column type in table definition file.')

    def _add_arc_node_fields(self, shapefile_writer):
        """Adds 'Node_1_ID' and 'Node_2_ID' fields in the shapefile.

        Args:
            shapefile_writer (shapefile.Writer): Shapefile writer.
        """
        shapefile_writer.field('Node_1_ID', 'N', decimal=0)
        shapefile_writer.field('Node_2_ID', 'N', decimal=0)

    def _add_arc_node_ids(self, arc, row):
        """Appends the arc's start and end node IDs to the row (in place).

        Args:
            arc (xms.data_objects.parameters.Arc): An arc.
            row (list): List of values for the row. Modified in place.
        """
        n1 = arc.start_node
        row.append(n1.id)
        n2 = arc.end_node
        row.append(n2.id)

    def _atts_exist(self, feature_type):
        """Returns True if the attribute table file and table definition file both exist.

        The table definition file is the attribute table file path with '.def' appended.

        Args:
            feature_type (str): 'points', 'arcs', or 'polygons'.

        Returns:
            See description.
        """
        if feature_type not in self._coverage_att_files:
            return False
        att_table_file = self._coverage_att_files[feature_type]
        return os.path.isfile(att_table_file) and os.path.isfile(f'{att_table_file}.def')

    def _coverage_data_exists(self, feature_type):
        """Returns True if the feature objects of the given type exist and attribute data exists.

        Args:
            feature_type (str): 'points', 'arcs', or 'polygons'. Any other value returns False.

        Returns:
            (bool) See description.
        """
        if feature_type == 'points':
            features = self._coverage.get_points(FilterLocation.PT_LOC_DISJOINT)
        elif feature_type == 'arcs':
            features = self._coverage.arcs
        elif feature_type == 'polygons':
            features = self._coverage.polygons
        else:
            return False
        # bool() so callers get a real True/False rather than the (possibly empty) feature list.
        return bool(features) and self._atts_exist(feature_type)

    def _shapefiles_from_coverage(self):
        """Writes a point, arc, and/or polygon shapefile for each feature type that has data and atts."""
        if self._coverage_data_exists('points'):
            self._create_point_shapefile()
        if self._coverage_data_exists('arcs'):
            self._create_arc_shapefile()
        if self._coverage_data_exists('polygons'):
            self._create_polygon_shapefile()

    def _get_atts_and_fix_field_names(self, feature_type) -> tuple[str, dict, dict]:
        """Returns the attribute table file name and its table definitions (as read, and with field names fixed).

        Args:
            feature_type (str): 'points', 'arcs', or 'polygons'.

        Returns:
            (tuple): tuple containing:
                - Attribute table file name.
                - Table definition dict as returned by map_util.read_table_definition_file.
                - Table definition dict returned by map_util.fix_field_names (shapefile-safe field names).
        """
        att_filename = self._coverage_att_files[feature_type]
        table_def = map_util.read_table_definition_file(att_filename)
        shape_table_def = map_util.fix_field_names(table_def, self._fields_dict)
        return att_filename, table_def, shape_table_def

    def _fix_field_names_and_set_shapefile_name(self, feature_type) -> tuple[str, dict, dict]:
        """Fixes field names for the shapefile and records the shapefile name for the feature type.

        The shapefile name is the attribute table file path without its extension.

        Args:
            feature_type (str): 'points', 'arcs', or 'polygons'.

        Returns:
            (tuple): tuple containing:
                - Attribute table file name.
                - Table definition dict as read from the table definition file.
                - Table definition dict with field names fixed for the shapefile.
        """
        att_filename, table_def, shape_table_def = self._get_atts_and_fix_field_names(feature_type)
        self._shapefile_names[feature_type] = os.path.splitext(att_filename)[0]
        return att_filename, table_def, shape_table_def

    def _create_point_shapefile(self):
        """Creates a point shapefile from the coverage and coverage atts.

        Shapefile will match order of coverage, which is not necessarily order of att table.
        """
        feature_type = 'points'
        att_filename, table_def, shape_table_def = self._fix_field_names_and_set_shapefile_name(feature_type)
        df = _read_att_table_to_dataframe(att_filename, table_def)

        with shapefile.Writer(self._shapefile_names[feature_type], shapeType=shapefile.POINTZ) as w:
            self._create_fields(shape_table_def, w)
            for point in self._coverage.get_points(FilterLocation.PT_LOC_DISJOINT):
                w.pointz(point.x, point.y, point.z)  # Creates the shape
                row = _row_with_feature_id(df, point.id)
                w.record(*row)  # Creates the record

    def _create_arc_shapefile(self):
        """Creates an arc shapefile from the coverage and coverage atts.

        Shapefile will match order of coverage, which is not necessarily order of att table.

        Note:
            Two extra columns are added: 'Node_1_ID', 'Node_2_ID' so connectivity can be determined.
        """
        feature_type = 'arcs'
        att_filename, table_def, shape_table_def = self._fix_field_names_and_set_shapefile_name(feature_type)
        df = _read_att_table_to_dataframe(att_filename, table_def)

        with shapefile.Writer(self._shapefile_names[feature_type], shapeType=shapefile.POLYLINEZ) as w:
            self._create_fields(shape_table_def, w)
            self._add_arc_node_fields(w)
            for arc in self._coverage.arcs:
                arc_points = arc.get_points(FilterLocation.PT_LOC_ALL)
                line = [[p.x, p.y, p.z] for p in arc_points]
                w.linez([line])  # One part per arc
                row = _row_with_feature_id(df, arc.id)
                self._add_arc_node_ids(arc, row)  # Must match the field order from _add_arc_node_fields
                w.record(*row)

    def _create_polygon_shapefile(self):
        """Creates a polygon shapefile from the coverage and coverage atts.

        Shapefile will match order of coverage, which is not necessarily order of att table.
        """
        feature_type = 'polygons'
        att_filename, table_def, shape_table_def = self._fix_field_names_and_set_shapefile_name(feature_type)
        df = _read_att_table_to_dataframe(att_filename, table_def)

        with shapefile.Writer(self._shapefile_names[feature_type], shapeType=shapefile.POLYGONZ) as w:
            self._create_fields(shape_table_def, w)
            for polygon in self._coverage.polygons:
                polygon_points = polygon_orienteer.get_polygon_point_lists(polygon)
                w.polyz(polygon_points)  # Creates the shape
                row = _row_with_feature_id(df, polygon.id)
                w.record(*row)  # Creates the record


def _read_att_table_to_dataframe(att_filename: Path | str, table_def: dict) -> pd.DataFrame:
    """Returns a pandas dataframe created by reading the att table file, with feature ID column set as index.

    We use a pd.dataframe to lookup rows by feature ID since att file may not match coverage order.

    Args:
        att_filename: The att table file path.
        table_def: Dict of the table definition file.

    Returns:
        See description.
    """
    # Set up the dtypes
    string_to_type = {'int': np.int32, 'float': np.float64, 'double': np.float64, 'string': object}
    dtypes = {col['name']: string_to_type[col['type']] for col in table_def['columns']}

    # keep_default_na=False preserves empty strings instead of replacing them with nan
    df = pd.read_csv(att_filename, header=0, float_precision='high', dtype=dtypes, keep_default_na=False)

    # Use feature ID as dataframe index for speed
    df.set_index('ID', drop=False, inplace=True)
    return df


def _row_with_feature_id(df: pd.DataFrame, feature_id: int) -> list[Any]:
    """Returns the dataframe row with the given feature ID, as a list.

    Args:
        df: The dataframe.
        feature_id: The feature ID.

    Returns:
        See description.
    """
    # This code would yield the following warning:
    #
    # DeprecationWarning('np.find_common_type is deprecated.  Please use `np.result_type` or `np.promote_types`.
    # See https://numpy.org/devdocs/release/1.25.0-notes.html and the docs for more information.
    # (Deprecated NumPy 1.25)
    #
    # Apparently the warning is gone in pandas >= 2.0.2.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        return df.loc[feature_id].tolist()  # Get row by feature ID
