"""ModflowMapper class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
from xms.core.filesystem import filesystem as fs
from xms.data_objects.parameters import Coverage

# 4. Local modules
from xms.mf6.geom import shapefile_geom
from xms.mf6.mapping.grid_intersector import GridIntersector
from xms.mf6.mapping.mapper_inputs import AttFiles, MapperInputs
from xms.mf6.mapping.maw_package_builder import MawPackageBuilder
from xms.mf6.mapping.package_builder import PackageBuilder
from xms.mf6.mapping.package_builder_base import CovMapInfo, CovMapInfoList, PackageBuilderInputs
from xms.mf6.mapping.shapefile_exporter import ShapefileExporter
from xms.mf6.mapping.transient_data_exporter import TimeInfo, TransientDataExporter
from xms.mf6.misc import log_util


def map_to_modflow(inputs: MapperInputs) -> tuple[bool, dict]:
    """Map coverage data onto a MODFLOW package.

    Convenience entry point: constructs a ModflowMapper for ``inputs`` and runs its pipeline.

    Args:
        inputs: Everything needed to do the mapping.

    Returns:
        (tuple[bool, dict]) True if successful, output data to add to Query.
    """
    return ModflowMapper(inputs).map_to_modflow()


class ModflowMapper:
    """Handles Map to MODFLOW.

    For each entry in ``inputs.coverage_data`` the coverage is exported to shapefiles (plus transient-data .csv
    files where the package has map info), or, when there is no coverage, the user-supplied shapefile is used
    directly. The shapefiles are then intersected with the model grid. Finally the collected intersection info is
    handed to a package builder. Temporary files created along the way are removed whether or not mapping succeeds.
    """
    def __init__(self, inputs: MapperInputs):
        """Initializes the class.

        Args:
            inputs: Everything needed to do the mapping.
        """
        self._inputs: MapperInputs = inputs

        self._log = log_util.get_logger()
        self._cogrid = None
        self._cov: Coverage | None = None  # Current coverage
        self._att_files: AttFiles | None = None  # Current att files
        self._cov_map_info: CovMapInfo | None = None  # Current coverage's mapping info
        self._cov_map_info_list: CovMapInfoList = []
        # Base paths (no extension) of the shapefiles this class exported itself. Only these are deleted during
        # cleanup, so a user-supplied input shapefile is never removed.
        self._temp_shapefile_names: list = []

    def map_to_modflow(self) -> tuple[bool, dict]:
        """Runs the mapping pipeline for every coverage and builds the package.

        Temporary files are cleaned up on both the success and the failure path.

        Returns:
            (tuple[bool, dict]) True if successful, output data to add to Query.
        """
        try:
            for uuid, cov_and_atts in self._inputs.coverage_data.items():
                self._cov = cov_and_atts.coverage  # Current coverage
                self._att_files = cov_and_atts.att_files  # Current att files
                # The coverage may be absent when mapping from a shapefile, so don't dereference it blindly.
                if self._cov:
                    source = f'coverage "{self._cov.name}"'
                else:
                    source = f'shapefile "{self._inputs.shapefilepath}"'
                self._log.info(f'Mapping {source} to package {self._inputs.package.pname}.')

                self._cov_map_info = CovMapInfo(uuid)
                self._convert_xms_data()
                self._set_shapefile_names()
                self._intersect_shapefile_with_grid()
                self._cov_map_info_list.append(self._cov_map_info)

            outputs = self._build_package()
        except Exception as error:
            self._log.exception(error)
            return False, {}
        finally:
            # Cleanup is best-effort: a failure here must not mask the mapping result (or the original error).
            try:
                self._delete_temp_files()
            except Exception as cleanup_error:
                self._log.exception(cleanup_error)
        return True, outputs

    def _convert_xms_data(self) -> None:
        """Convert coverage and AttTable from XMS into shapefiles and, maybe, a .csv file for transient data.

        We thought that perhaps in the future, by using more standard formats like shapefiles and .csv files, users
        could supply their own files. Thus we first convert XMS data to these formats, and then work only with the
        shapefile and .csv files from then on.
        """
        if self._cov:
            self._export_shapefiles()
            self._export_transient_data()

    def _set_shapefile_names(self) -> None:
        """If input is a shapefile (no current coverage), record it in the current CovMapInfo's shapefile_names."""
        if not self._cov:
            feature_type = shapefile_geom.shapefile_feature_type(self._inputs.shapefilepath)
            self._cov_map_info.shapefile_names[feature_type] = self._inputs.shapefilepath

    def _delete_temp_files(self) -> None:
        """Deletes the temporary files.

        Removes ``<name>.shp``, ``.shx``, ``.dbf`` and ``.csv`` for every shapefile this class exported, and for
        every coverage's att files removes ``<file>``, ``<file>.def`` and ``<file>.xy``. A user-supplied input
        shapefile is not deleted.
        """
        for filename in self._temp_shapefile_names:
            for extension in ('.shp', '.shx', '.dbf', '.csv'):
                fs.removefile(f'{filename}{extension}')

        for cov_and_atts in self._inputs.coverage_data.values():
            # att_files may be missing when mapping from a shapefile rather than a coverage.
            for filename in (cov_and_atts.att_files or {}).values():
                fs.removefile(filename)
                fs.removefile(f'{filename}.def')
                fs.removefile(f'{filename}.xy')

    def _export_shapefiles(self) -> None:
        """Exports the current coverage to shapefiles and records them for later cleanup."""
        exporter = ShapefileExporter(self._inputs.package.ftype, self._cov, self._att_files)
        self._cov_map_info.shapefile_names = exporter.export()
        # Track the files now so they are cleaned up even if a later pipeline step fails.
        self._temp_shapefile_names.extend(self._cov_map_info.shapefile_names.values())

    def _export_transient_data(self) -> None:
        """Exports a CSV file with all the transient data, for each feature type the package has map info for."""
        for feature_type in self._cov_map_info.shapefile_names:
            coverage_att_file = self._att_files[feature_type]
            map_info = self._inputs.package.map_info(feature_type)
            if map_info:
                tdis = self._inputs.package.mfsim.tdis
                time_info = TimeInfo(
                    tdis.get_period_times(as_date_times=False), tdis.get_start_date_time(), tdis.get_time_units()
                )
                exporter = TransientDataExporter(coverage_att_file, map_info, time_info)
                exporter.export()

    def _intersect_shapefile_with_grid(self) -> None:
        """Intersects the current shapefiles with the model grid and stores the result in the current CovMapInfo."""
        self._cogrid = self._inputs.package.model.get_cogrid()
        intersector = GridIntersector(self._cogrid, hfb=self._inputs.package.ftype == 'HFB6')
        self._cov_map_info.ix_info = intersector.intersect(
            self._inputs.package.model.get_dis(), self._cov_map_info.shapefile_names
        )

    def _build_package(self) -> dict:
        """Builds the package from the intersection information.

        Returns:
            (dict): Stuff to add to the Query after mapping (UGrid edits for now).
        """
        builder_inputs = PackageBuilderInputs(
            self._inputs.package, self._inputs.append_or_replace, self._cov_map_info_list, self._cogrid,
            self._get_idomain(), self._inputs.override_layers, self._inputs.layer_filepath
        )
        # MAW wells need a specialized builder; every other package uses the generic one.
        if self._inputs.package.ftype == 'MAW6':
            builder = MawPackageBuilder(builder_inputs)
        else:
            builder = PackageBuilder(builder_inputs)
        builder.build()
        return builder.outputs

    def _get_idomain(self) -> list[int] | None:
        """Returns the IDOMAIN from the DIS package, if there is one.

        Returns:
            (list[int] | None): The IDOMAIN values as ints, or None if there is no DIS package, no GRIDDATA block,
            or no IDOMAIN array.
        """
        dis = self._inputs.package.model.get_dis()
        # Check for the DIS package *before* touching its blocks; otherwise a missing DIS raises AttributeError.
        if not dis:
            return None
        griddata = dis.block('GRIDDATA')
        if griddata is not None and griddata.has('IDOMAIN'):
            return [int(v) for v in griddata.array('IDOMAIN').get_values()]
        return None
