"""DefaultPackageCreator class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from contextlib import contextmanager
import logging
import os
from pathlib import Path
import shutil

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint import Grid, read_grid_from_file
from xms.core.filesystem import filesystem as fs
from xms.data_objects.parameters import Projection, UGrid as DoGrid
from xms.testing.type_aliases import Pathlike

# 4. Local modules
from xms.mf6.components import dis_builder
from xms.mf6.data import data_util, package_factory
from xms.mf6.data.base_file_data import BaseFileData
from xms.mf6.data.data_type_aliases import DisX
from xms.mf6.data.data_util import extension_from_ftype
from xms.mf6.data.grid_info import DisEnum
from xms.mf6.data.mfsim_data import MfsimData
from xms.mf6.file_io import io_factory
from xms.mf6.file_io.writer_options import WriterOptions
from xms.mf6.misc import log_util


class DefaultPackageCreator:
    """Creates default MODFLOW packages for a new simulation."""
    def __init__(self):
        """Initializes the class."""
        self._folder = ''  # Folder where files will be created.
        self._prefix = ''  # Prefix of the files to be created.
        self._extension = ''  # Package extension
        self.running_creator_tests = False  # indicates if tests are running
        self._base_model_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), 'resources', 'base_model'))
        self._log = log_util.get_logger()

    def create_simulation(self, folder, data, dogrid: DoGrid, model_str) -> tuple[MfsimData, list]:
        """Given a folder, a prefix, and list of ftypes, create default files.

        Args:
            folder (str): Folder where files will be created.
            data (dict): Dict with info on how to create the simulation from the New Simulation dialog.
            dogrid: The data_objects grid.
            model_str (str): If coming from Add Package > GWF (or GWT), is 'GWF' or 'GWT' respectively

        Returns:
            MfsimData and list of messages.
        """
        messages = []

        if not os.path.isdir(folder):
            os.mkdir(folder)

        # Set the correct DIS package
        cogrid = None
        if dogrid and dogrid.cogrid_file != '':
            cogrid, message = read_cogrid(dogrid.cogrid_file)
            if message:
                messages.append(message)
                return None, messages
        dis_type = dis_builder.get_grid_dis_type(cogrid)
        if dis_type != 'DIS6':
            if 'gwf_ftypes' in data:
                data['gwf_ftypes'].remove('DIS6')
                data['gwf_ftypes'].add(dis_type)
            if 'gwt_ftypes' in data:
                data['gwt_ftypes'].remove('DIS6')
                data['gwt_ftypes'].add(dis_type)
            if 'gwe_ftypes' in data:
                data['gwe_ftypes'].remove('DIS6')
                data['gwe_ftypes'].add(dis_type)
            if 'prt_ftypes' in data:
                data['prt_ftypes'].remove('DIS6')
                data['prt_ftypes'].add(dis_type)

        dtype = {'DIS6': 'dis', 'DISV6': 'disv', 'DISU6': 'disu'}
        mfsim_nam = os.path.join(self._base_model_dir, dtype[dis_type], 'mfsim.nam')
        reader = io_factory.reader_from_ftype('MFSIM6')
        with disable_logging() as _:  # Don't display all the log messages we like to see when importing
            mfsim = reader.read(filename=mfsim_nam)
        mfsim.pname = data['name']

        # Remove things that aren't needed
        self._remove_unneeded_exchanges(mfsim, data)
        self._remove_solution_if_not_needed(mfsim, data)
        self._remove_unused_packages(mfsim, data, model_str)

        # Replace default DIS package with one that matches the grid. Do this early so that other packages (zones)
        # can be setup correctly.
        projection = dogrid.projection if dogrid else None
        self._replace_dis(cogrid, projection, mfsim, dis=None)

        if 'gwf_ftypes' in data and 'ZONE6' in data['gwf_ftypes']:
            self._add_zone_file(mfsim)
        pobs_in_gwf = 'gwf_ftypes' in data and 'POBS6' in data['gwf_ftypes']
        pobs_in_gwt = 'gwt_ftypes' in data and 'POBS6' in data['gwt_ftypes']
        pobs_in_gwe = 'gwe_ftypes' in data and 'POBS6' in data['gwe_ftypes']
        if pobs_in_gwf or pobs_in_gwt or pobs_in_gwe:
            self._add_pest_obs_file(mfsim)

        # Set fnames before writing so they get written to settings json files
        self._set_fnames(mfsim, data, model_str)

        # write sim with each package in a separate folder
        writer_options = WriterOptions(
            use_open_close=True, dmi_sim_dir=os.path.join(folder, '..'), running_io_tests=self.running_creator_tests
        )
        writer = io_factory.writer_from_ftype('MFSIM6', writer_options)
        with disable_logging() as _:  # Don't display all the log messages we like to see when importing
            writer.write(mfsim)

        # Move files to the original component folder and delete the folder we just created
        fs.clear_folder(folder)
        if model_str:
            filename = mfsim.models[0].filename
            component_filename = os.path.join(folder, os.path.basename(filename))
            # shutil.copyfile(filename, component_filename)
            shutil.copytree(os.path.dirname(filename), os.path.dirname(component_filename), dirs_exist_ok=True)
            shutil.rmtree(os.path.dirname(filename))
            mfsim.models[0].filename = component_filename
            shutil.rmtree(os.path.dirname(mfsim.filename))  # Also delete the sim folder
        else:
            # shutil.copyfile(WriterOptions._dmi_file_list[0], os.path.join(folder, 'mfsim.nam'))
            shutil.copytree(os.path.dirname(mfsim.filename), folder, dirs_exist_ok=True)
            shutil.rmtree(os.path.dirname(WriterOptions._dmi_file_list[0]))
            WriterOptions._dmi_file_list[0] = os.path.join(folder, 'mfsim.nam')
        return mfsim, messages

    def create_package(
        self, mfsim_dir: Pathlike, filename: Pathlike, unique_name, parent_name, dis_enum, overwrite=True, model=None
    ):
        """Creates a new package by copying a default file to the new location.

        Not called when importing a simulation, or on New Simulation. Only when adding a package to an existing
        simulation.

        Args:
            mfsim_dir: Path to directory of mfsim.nam.
            filename: Filepath of file to be created.
            unique_name: tree_node.unique_name, which is the same as ftype except with EVTA6 and RCHA6.
            parent_name (str): Name of the parent tree node.
            dis_enum (DisEnum): Tells what type of DIS/DISV/DISU we're dealing with.
            overwrite (bool): If true and a file with the name we want already exists, overwrites it. Otherwise finds
                a unique filename.
            model: Tells us whether the package is a flow, transport, or energy package.

        Returns:
            (str): new main file (absolute paths)
        """
        ftype = data_util.fix_ftype(unique_name, with_the_a=False)
        folder = os.path.dirname(filename)
        if not os.path.isdir(folder):
            os.mkdir(folder)

        # Get path to default mfsim.nam
        dis_name = {DisEnum.DIS: 'dis', DisEnum.DISV: 'disv', DisEnum.DISU: 'disu'}[dis_enum]
        mfsim_nam = Path(self._base_model_dir) / dis_name / 'mfsim.nam'

        # This function is NOT called when importing a simulation or on New Simulation - those operations use
        # DefaultPackageCreator.create_simulation. On New Simulation we do look at the UGrid and create a DIS, DISV or
        # DISU package based on the UGrid.
        #
        # This function is only called if a new package (including a new GWF model) is added to an existing simulation.
        # All packages except for ZONES and GWF models have the same default files regardless of DIS, DISV or DISU, so
        # it doesn't matter which DIS* it is.
        #
        # If ftype == DIS*, we need to create it for the current UGrid.
        #
        # If ftype == ZONE6, we have to know the size of the current grid when the package is created.
        #
        # If ftype == GWF6, we don't know when it's created whether it will have a DIS, DISV, or DISU package
        # - the user will have to add that later and add a UGrid link.

        # Read the default sim from the resources directory
        reader = io_factory.reader_from_ftype('MFSIM6')
        mfsim_defaults = reader.read(filename=mfsim_nam)

        # Replace default DIS package with one that matches the grid. Do this early so that other packages (zones)
        # can be setup correctly.
        if model:
            if ftype in {'DIS6', 'DISV6', 'DISU6'}:
                # Creating a DIS* package. Create from scratch using the grid
                dis = None
                cogrid = model.get_cogrid()
            else:
                # Not creating DIS* package. Use existing DIS* package so we don't have to read the grid
                dis = model.get_dis()
                cogrid = None
            self._replace_dis(cogrid=cogrid, projection=None, mfsim=mfsim_defaults, dis=dis)

        # Find the default package with the correct ftype.
        packages = []
        # First see if we can get it from the model, then try the sim
        if model:
            # Find the default model whose ftype matches the model passed in
            for default_model in mfsim_defaults.models:
                if default_model.ftype == model.ftype:
                    packages = default_model.packages_from_ftype(ftype)
                    break
        if not packages:
            packages = mfsim_defaults.packages_from_ftype(ftype)

        packages = _keep_readasarrays(unique_name, packages)
        if not packages and model:
            # Create it from scratch
            if model.ftype == 'GWF6':
                packages.append(
                    package_factory.package_from_ftype(ftype, mfsim=mfsim_defaults, model=mfsim_defaults.models[0])
                )
            elif model.ftype == 'GWT6':  # pragma no cover - Currently no GMS specific packages for GWT yet
                packages.append(
                    package_factory.package_from_ftype(
                        ftype,
                        mfsim=mfsim_defaults,  # pragma no cover
                        model=mfsim_defaults.models[1]
                    )
                )  # pragma no cover
            else:  # pragma no cover - Currently no GMS specific packages for GWE yet (like POBS6, ZONE6)
                packages.append(
                    package_factory.package_from_ftype(
                        ftype,
                        mfsim=mfsim_defaults,  # pragma no cover
                        model=mfsim_defaults.models[2]
                    )
                )  # pragma no cover
            packages[-1]._base._grid_info = model.grid_info()  # Hack it to use correct GridInfo (for ZONE package)

        if not packages:
            packages.append(
                package_factory.package_from_ftype(ftype, mfsim=mfsim_defaults, model=mfsim_defaults.models[0])
            )

        _reset_exchange_mnames(packages[0])
        _reset_solutiongroup_mnames(packages[0])

        # Write it to the new location. Don't just copy the file as we want to write the settings.

        # All component main files will be of the form 'model.ext', but the fnames can be more unique
        orig_name = packages[0].filename
        model_dot_extension = f'model{os.path.splitext(orig_name)[1]}'
        candidate = os.path.join(folder, model_dot_extension)
        if overwrite:
            packages[0].filename = candidate
        else:
            packages[0].filename = fs.make_filename_unique(candidate)
        packages[0].fname = self._make_fname(parent_name, os.path.basename(packages[0].filename))
        dmi_sim_dir = os.path.normpath(os.path.join(folder, '..'))
        writer_options = WriterOptions(
            mfsim_dir=mfsim_dir,
            use_open_close=True,
            dmi_sim_dir=dmi_sim_dir,
            running_io_tests=self.running_creator_tests
        )
        writer = io_factory.writer_from_ftype(ftype, writer_options)
        writer.write(packages[0])
        return packages[0].filename

    def _replace_dis(self, cogrid: Grid, projection: Projection, mfsim: MfsimData, dis: DisX | None):
        """Replaces the DIS* packages in the simulation with one that matches the UGrid.

        Args:
            cogrid: The constrained grid.
            projection: The projection.
            mfsim: MfsimData object.
            dis: The dis package to use as the replacement.
        """
        new_dis = None
        for model in mfsim.models:
            if projection:
                model.projection_wkt = projection.well_known_text
            if dis:
                model.packages[0] = dis
            else:
                if not new_dis:  # Only build the dis package once since it can take time and will be the same
                    if cogrid:
                        units = projection.horizontal_units
                        new_dis = dis_builder.build_dis_package(cogrid, cogrid.ugrid, model.get_dis(), units)
                        model.packages[0] = new_dis
                else:  # Copy the package
                    model.packages[0] = new_dis.make_copy()
                    new_dis.model = model

    def _make_fname(self, parent_name, filename):
        """Returns the fname of the new package.

        The fname is currently hardcoded to be the {model name}.ext, if in a model, or {sim name}.ext if under sim.

        Args:
            parent_name (str): Name of the parent tree node.

        Returns:
            fname (str): The fname.
        """
        fname = os.path.basename(filename.replace('model', parent_name))
        return fname

    def _set_fnames(self, mfsim, data, model_str):
        """Sets base.fname for the sim and all packages in it.

        Args:
            mfsim (MfsimData): MfsimData object.
            data (dict): Dict with info on how to create the simulation from the New Simulation dialog.
            model_str (str): If coming from Add Package > GWF (or GWT), is 'GWF' or 'GWT' respectively
        """
        mfsim.fname = self._make_fname(data['name'], mfsim.filename)
        if mfsim.tdis:
            mfsim.tdis.fname = self._make_fname(data['name'], mfsim.tdis.filename)
        if mfsim.exchanges:
            for exchange in mfsim.exchanges:
                exchange.fname = self._make_fname(data['name'], exchange.filename)
        if mfsim.solution_groups:
            for solution_group in mfsim.solution_groups:
                for solution in solution_group.solution_list:
                    orig_name = solution.filename
                    extension = extension_from_ftype(solution.ftype)
                    solution.filename = os.path.join(os.path.dirname(solution.filename), f'model{extension}')
                    solution.fname = self._make_fname(data['name'], orig_name)
        for model in mfsim.models:
            if model_str:
                model_name = data['name']
            else:
                model_name = {'GWF6': 'flow', 'GWE6': 'energy', 'GWT6': 'trans', 'PRT6': 'track'}[model.ftype]
            model.pname = model_name
            model.fname = self._make_fname(model_name, model.filename)
            for package in model.packages:
                package.fname = self._make_fname(model_name, package.filename)

    def _add_zone_file(self, mfsim):
        """Adds the .zon file to the simulation.

        Args:
            mfsim(MfsimData): MfsimData object.
        """
        zone_file = os.path.join(self._base_model_dir, 'files', 'model.zon')
        reader = io_factory.reader_from_ftype('ZONE6')
        zone_data = reader.read(filename=zone_file, mfsim=mfsim, model=mfsim.models[0])
        mfsim.models[0].packages.append(zone_data)

    def _add_pest_obs_file(self, mfsim):
        """Adds the PEST observation file to the simulation.

        Args:
            mfsim(MfsimData): MfsimData object.
        """
        pest_obs_file = os.path.join(self._base_model_dir, 'files', 'model.pobs')
        reader = io_factory.reader_from_ftype('POBS6')
        pest_obs_data = reader.read(filename=pest_obs_file, mfsim=mfsim, model=mfsim.models[0])
        mfsim.models[0].packages.append(pest_obs_data)

    def _remove_unused_packages(self, mfsim, data, model_str):
        """Remove packages that should not be included.

        Args:
            mfsim (MfsimData): modflow simulation.
            data (dict): Dict with info on how to create the simulation from the New Simulation dialog.
            model_str (str): If coming from Add Package > GWF (or GWT), is 'GWF' or 'GWT' respectively
        """
        if model_str:  # Adding a model, not a sim, so get rid of these sim level packages
            mfsim.tdis = None
            mfsim.exchanges = []
            mfsim.solution_groups = []

        # Remove the GWF or GWT if necessary
        mfsim.models = [
            m for m in mfsim.models if (
                (m.ftype == 'GWF6' and 'gwf_ftypes' in data)  # noqa W503
                or (m.ftype == 'GWT6' and 'gwt_ftypes' in data)  # noqa W503
                or (m.ftype == 'GWE6' and 'gwe_ftypes' in data)  # noqa W503
                or (m.ftype == 'PRT6' and 'prt_ftypes' in data)  # noqa W503
            )
        ]  # noqa W503

        for model in mfsim.models:
            # Rename the model if necessary
            if model_str:
                model.pname = data['name']

            # Get the ftypes of the packages to include
            ftypes = {
                'GWF6': data.get('gwf_ftypes'),
                'GWE6': data.get('gwe_ftypes'),
                'GWT6': data.get('gwt_ftypes'),
                'PRT6': data.get('prt_ftypes')
            }[model.ftype]

            packages = []
            for p in model.packages:
                package_ftype = data_util.fix_ftype(p.ftype, with_the_a=p.readasarrays)
                if package_ftype in ftypes:
                    packages.append(p)
            model.packages = packages

    def _remove_unneeded_exchanges(self, mfsim, data):
        """Removes exchanges from the simulation if they are not needed.

        Args:
            mfsim (MfsimData): The simulation.
            data (dict): Dict with info on how to create the simulation from the New Simulation dialog.
        """
        # This is only called from create_simulation(), so we will never need GWF6-GWF6, GWT6-GWT6, or GWE6-GWE6.
        exchanges = []
        for exchange in mfsim.exchanges:
            if exchange.ftype == 'GWF6-GWT6' and 'gwf_ftypes' in data and 'gwt_ftypes' in data:
                exchanges.append(exchange)
            elif exchange.ftype == 'GWF6-GWE6' and 'gwf_ftypes' in data and 'gwe_ftypes' in data:
                exchanges.append(exchange)
            elif exchange.ftype == 'GWF6-PRT6' and 'gwf_ftypes' in data and 'prt_ftypes' in data:
                exchanges.append(exchange)
        mfsim.exchanges = exchanges

    def _remove_solution_if_not_needed(self, mfsim: MfsimData, data: dict):
        """Keeps only the solution packages that are needed.

        The default sim includes 2, with the first for the flow model and the second for the trans model.

        Args:
            mfsim: The simulation.
            data: Dict with info on how to create the simulation from the New Simulation dialog.
        """
        new_solution_list = []
        if 'gwf_ftypes' in data:
            new_solution_list.append(mfsim.solution_groups[0].solution_list[0])
        if 'gwt_ftypes' in data:
            new_solution_list.append(mfsim.solution_groups[0].solution_list[1])
        if 'gwe_ftypes' in data:
            new_solution_list.append(mfsim.solution_groups[0].solution_list[2])
        if 'prt_ftypes' in data:
            new_solution_list.append(mfsim.solution_groups[0].solution_list[3])
        mfsim.solution_groups[0].solution_list = new_solution_list


@contextmanager
def disable_logging():
    """Context manager to disable logging."""
    logging.disable()
    try:
        yield
    finally:
        logging.disable(logging.NOTSET)


def read_cogrid(ugrid_filename):
    """Reads the UGrid.

    Args:
        ugrid_filename: The UGrid file.

    Returns:
        (tuple): tuple containing:
            - cogrid: The constrained grid.
            - errors: A tuple of error messages, or None if no errors.
    """
    ugrid_filename = str(ugrid_filename)
    if not os.path.isfile(ugrid_filename):
        return None, ('ERROR', f'File not found: "{ugrid_filename}"')

    log = log_util.get_logger()
    log.info('Reading the grid')
    cogrid = read_grid_from_file(ugrid_filename)

    if not cogrid:
        return cogrid, ('ERROR', f'Grid could not be read from: "{ugrid_filename}"')

    # TODO: Check if grid does not have top and bottom elevations and fix it if necessary.

    log.info('Checking that the grid is 3D and vertically prismatic')
    if not cogrid.check_all_cells_3d():
        return cogrid, ('ERROR', 'Only 3D UGrids are used with MODFLOW (all cells must be 3D cells).')
    if not cogrid.check_all_cells_vertically_prismatic():
        return cogrid, ('ERROR', 'Invalid UGrid. All cells must be vertically prismatic.')
    return cogrid, None


def _reset_exchange_mnames(package) -> None:
    """If package is an exchange, set exgmnamea and exgmnameb to '' because we don't know what they should be yet.

    Args:
        package: The package.
    """
    if package.ftype in data_util.exchange_ftypes():
        package.exgmnamea = ''
        package.exgmnameb = ''


def _reset_solutiongroup_mnames(package) -> None:
    """If package is IMS6 or EMS6, reset slnmnames to [] because we don't know what they should be yet.

    Args:
        package: The package.
    """
    if package.ftype in {'IMS6', 'EMS6'}:
        package.slnmnames = []


def _keep_readasarrays(unique_name: str, packages: list[BaseFileData]) -> list[BaseFileData]:
    """If unique_name is among those that can have READASARRAYS, remove the list-based or array-based versions.

    Args:
        unique_name: tree_node.unique_name, which is the same as ftype except with EVTA6 and RCHA6.
        packages: List of packages.

    Returns:
        The list, possibly altered.
    """
    new_packages = []
    readasarrays_ftype_with = unique_name in data_util.readasarrays_ftypes(with_the_a=True)
    readasarrays_ftype_without = unique_name in data_util.readasarrays_ftypes(with_the_a=False)
    if readasarrays_ftype_with or readasarrays_ftype_without:
        for p in packages:
            if readasarrays_ftype_with and p.readasarrays or readasarrays_ftype_without and not p.readasarrays:
                new_packages.append(p)
    else:
        new_packages = packages
    return new_packages
