"""PestObsWriter class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from datetime import datetime
import os

# 2. Third party modules

# 3. Aquaveo modules
from xms.core.filesystem import filesystem as fs

# 4. Local modules
from xms.mf6.data import data_util
from xms.mf6.data.grid_info import DisEnum
from xms.mf6.file_io.list_package_writer import ListPackageWriter
from xms.mf6.file_io.pest import pest_obs_data_generator
from xms.mf6.simulation_runner import sim_runner
from xms.mf6.simulation_runner.sim_runner import ExeKey


class PestObsWriter(ListPackageWriter):
    """Writes the PEST observations file and the PEST utility batch/.in files that post-process model output."""

    # Extensions of the PEST input files copied from the component folder into the '<model>_pest' folder.
    _PEST_FILE_EXTENSIONS = frozenset({
        '.b2b', '.b2map', '.blisting', '.bsamp', '.bwt', '.fsamp', '.fwt', '.n2b', '.mftimes-bsamp',
        '.mftimes-fsamp'
    })

    # Single-letter time unit codes written to the PEST .in files, keyed by the TDIS time units string.
    _TIME_UNIT_CODES = {
        'SECONDS': 's',
        'MINUTES': 'm',
        'HOURS': 'h',
        'DAYS': 'd',
        'YEARS': 'y',
    }

    def __init__(self):
        """Initializes the class."""
        super().__init__()
        self._data = None  # PestObsData
        self._pest_sub_dir = ''
        self._files = {}  # Dict of copied file paths (in the pest dir) keyed by extension

    def _copy_pest_files(self):
        """Copies the PEST input files (.b2b, .blisting, .bsamp, .fsamp, .n2b, etc.) to the '<model>_pest' dir.

        The '<model>_pest' dir is derived from the writer's output dir by replacing the text from its last '_output'
        onward with '_pest' (or by appending '_pest' when the output dir name does not contain '_output'). The dir
        is created, or cleared if it already exists. Only files directly inside the component folder are copied.
        """
        output_dir = self._writer_options.output_dir
        head, sep, _ = output_dir.rpartition('_output')
        # rpartition() returns ('', '', output_dir) when '_output' is absent; fall back to the whole output dir so we
        # never create a bare relative '_pest' dir in the current working directory.
        self._pest_sub_dir = (head if sep else output_dir) + '_pest'
        fs.make_or_clear_dir(self._pest_sub_dir)

        self._files = {}  # Don't let entries from a previous write leak into this one
        component_folder = self._data.component_folder
        # Top level only: source paths are built relative to the component folder and copied flat into the pest dir.
        # The default keeps a missing component folder a silent no-op.
        _, _, file_names = next(os.walk(component_folder), (component_folder, [], []))
        for file_name in file_names:
            extension = os.path.splitext(file_name)[1]
            if extension not in self._PEST_FILE_EXTENSIONS:
                continue
            to_file = os.path.join(self._pest_sub_dir, file_name)
            fs.copyfile(os.path.join(component_folder, file_name), to_file)
            self._files[extension] = to_file

    def _write_pest_batch_files(self):
        """Writes all the batch files used by PEST, their corresponding .in files, obs.bat and settings.fig."""
        pest_path = self._get_pest_path()
        self._write_batch_and_in_file(pest_path, 'mf6bud2smp.exe', 'mf6bud2smp', mftimes='')
        self._write_batch_and_in_file(pest_path, 'mf6mod2obs.exe', 'mf6mod2obs', mftimes='')
        self._write_batch_and_in_file(pest_path, 'mf6mod2obs.exe', 'mf6mod2obs', mftimes='mftimes-bsamp')
        # NOTE: no 'mftimes-fsamp' variant of mf6mod2obs is written - the original author doubted that flows work
        # that way. Flows go through mf6bud2smp (.b2b) and smp2smp (.fsamp) instead.
        self._write_batch_and_in_file(pest_path, 'smp2smp.exe', 'smp2smp', mftimes='')
        self._write_obs_batch_file()
        self._write_settings_file()

    def _write_settings_file(self):
        """Writes the "settings.fig" file used by PEST and which defines the date format."""
        settings_file_name = os.path.join(self._pest_sub_dir, 'settings.fig')
        with open(settings_file_name, 'w') as file:
            file.write('date=mm/dd/yyyy\n')
            file.write('colrow=yes\n')

    def _write_obs_batch_file(self):
        """Writes the "obs.bat" file, which calls the four utility batch files written by this class."""
        batch_path = os.path.join(self._pest_sub_dir, 'obs.bat')
        with open(batch_path, 'w') as file:
            file.write('call mf6mod2obs.bat\n')
            file.write('call mf6mod2obs-mftimes-bsamp.bat\n')
            # NOTE: no mf6mod2obs-mftimes-fsamp.bat (see _write_pest_batch_files)
            file.write('call mf6bud2smp.bat\n')
            file.write('call smp2smp.bat\n')

    def _get_pest_path(self):
        """Returns the folder containing the PEST executables (the dir of the configured PEST for MODFLOW 6 exe)."""
        exe_paths = sim_runner.get_gms_mf6_executable_paths()
        return os.path.dirname(exe_paths[ExeKey.PEST_FOR_MODFLOW_6])

    @staticmethod
    def _mftimes_suffix(mftimes) -> str:
        """Returns '-<mftimes>' to append to batch/.in file base names, or '' when mftimes is empty."""
        return f'-{mftimes}' if mftimes else ''

    def _write_batch_and_in_file(self, pest_path, exe_name, batch_filename_base, mftimes):
        """Writes a batch file and its corresponding .in file."""
        self._write_batch_file(pest_path, exe_name, batch_filename_base, mftimes)
        self._write_in_file(batch_filename_base, mftimes)

    def _write_batch_file(self, pest_path, exe_name, batch_filename_base, mftimes):
        """Writes a batch file that runs exe_name with its .in file redirected to stdin."""
        exe_path = os.path.join(pest_path, exe_name)
        base_name = f'{batch_filename_base}{self._mftimes_suffix(mftimes)}'
        batch_path = os.path.join(self._pest_sub_dir, f'{base_name}.bat')
        with open(batch_path, 'w') as file:
            file.write(f'"{exe_path}" < {base_name}.in')

    def _write_in_file(self, batch_filename_base, mftimes):
        """Writes the .in file for the given utility; unknown utilities are ignored."""
        if batch_filename_base == 'mf6bud2smp':
            self._write_mf6bud2smp_in_file(batch_filename_base)
        elif batch_filename_base == 'mf6mod2obs':
            self._write_mf6mod2obs_in_file(batch_filename_base, mftimes=mftimes)
        elif batch_filename_base == 'smp2smp':
            self._write_smp2smp_in_file(batch_filename_base)

    def _file_name_line_from_extension(self, extension) -> str:
        """Returns the .in file line naming the copied file with this extension, or "''" if none was copied."""
        copied_file = self._files.get(extension)
        return f'{os.path.basename(copied_file)}\n' if copied_file else "''\n"

    def _output_filepath(self, out_file) -> str:
        """Returns the path of out_file in the '<model>_output' folder, relative to the '<model>_pest' folder."""
        filepath = os.path.join(self._writer_options.output_dir, os.path.basename(out_file))
        return os.path.normpath(fs.compute_relative_path(self._pest_sub_dir, filepath))

    def _budget_filepath(self) -> str:
        """Returns the path to the budget file."""
        return self._output_filepath(data_util.get_budget_file_name(self._data.model))

    def _dep_var_filepath(self) -> str:
        """Returns the path to the dependent variable (head, concentration, or temperature) file."""
        return self._output_filepath(data_util.get_dv_filename(self._data.model))

    def _dis_str(self) -> str:
        """Returns 's' if using DIS, 'v' if DISV, or 'u' if DISU."""
        dis_enum = self._data.grid_info().dis_enum
        if dis_enum == DisEnum.DIS:
            return 's'
        elif dis_enum == DisEnum.DISV:
            return 'v'
        else:
            return 'u'  # pragma no cover - TODO: use PEST Obs stuff with a DISU model

    def _max_output_times(self) -> int:
        """Returns the maximum number of output times for this model (total over all stress periods)."""
        output_times = pest_obs_data_generator.get_output_times(self._data.model)
        return sum(len(ts_list) for ts_list in output_times.values())

    def _time_unit_string(self) -> str:
        """Returns the time unit string (y/d/h/m/s) by looking at the TDIS package.

        Returns:
            See description. Returns '' if the TDIS time units are not one of the recognized units.
        """
        return self._TIME_UNIT_CODES.get(self._data.mfsim.tdis.get_time_units(), '')

    def _zero_date_string(self) -> str:
        """Returns the PEST zero date string, which is the TDIS START_DATE_TIME or 01/01/1950 if that's not used."""
        start_date_time = self._data.mfsim.tdis.get_start_date_time()
        if isinstance(start_date_time, datetime):
            return start_date_time.strftime(pest_obs_data_generator.PEST_STRFTIME_DATE)
        return pest_obs_data_generator.PEST_ZERO_DATE

    def _dot_samp_file_name_line(self) -> str:
        """Returns the .in file line naming the .samp file, which is an output file and doesn't exist yet.

        The name is the first copied file's name with its extension replaced by '.samp', or "''" if nothing was
        copied.
        """
        if not self._files:
            return "''\n"
        first_copied = next(iter(self._files.values()))
        samp_filename = os.path.splitext(first_copied)[0] + '.samp'
        return f'{os.path.basename(samp_filename)}\n'

    def _write_mf6mod2obs_in_file(self, batch_filename_base, mftimes):
        """Writes the mf6mod2obs.in (or mf6mod2obs-<mftimes>.in) file."""
        in_path = os.path.join(self._pest_sub_dir, f'{batch_filename_base}{self._mftimes_suffix(mftimes)}.in')
        dep_var_filepath = self._dep_var_filepath()
        with open(in_path, 'w') as file:
            # Enter name of node-to-bore interpolation file:
            file.write(self._file_name_line_from_extension('.n2b'))
            #  Enter name of bore listing file:
            file.write(self._file_name_line_from_extension('.blisting'))
            # Reapportion interpolation factors if any dry/inactive nodes [y/n]?
            file.write('y\n')
            #  Enter name of existing bore sample file:
            extension = f'.{mftimes}' if mftimes else '.bsamp'
            file.write(self._file_name_line_from_extension(extension))
            # Enter name of binary MF6-generated file:
            file.write(f'{dep_var_filepath}\n')
            # Enter value signifying inactive cells [1.0E30]:
            file.write('1.0E30\n')
            # Enter value signifying dry cells [-1.0E30]:
            file.write('-1.0E30\n')
            # Is the model grid structured (i.e. DIS), DISV or DISU?  [s/v/u]:
            dis_str = self._dis_str()
            file.write(f'{dis_str}\n')
            if dis_str != 'u':
                #  How many layers in the model?
                file.write(f'{self._data.grid_info().nlay}\n')
            #  Enter time units used by model (yr/day/hr/min/sec) [y/d/h/m/s]:
            file.write(f'{self._time_unit_string()}\n')
            # Enter simulation starting date [mm/dd/yyyy]:
            file.write(f'{self._zero_date_string()}\n')  # 01/01/1950 or Tdis START_DATE_TIME
            #  Enter simulation starting time [hh:mm:ss]:
            file.write('0:0:0\n')
            # If a sample time does not lie between model output times, or if there is
            #    only one model output time, value at the sample time can equal that at
            #    nearest model output time:-
            #    Enter extrapolation limit in days (fractional if necessary):
            file.write('1e20\n')
            # Enter name for bore sample output file:
            file.write(f'{os.path.basename(self._files.get(extension, ""))}.out\n')

    def _write_mf6bud2smp_in_file(self, batch_filename_base):
        """Writes the mf6bud2smp.in file."""
        in_path = os.path.join(self._pest_sub_dir, f'{batch_filename_base}.in')
        budget_filepath = self._budget_filepath()
        with open(in_path, 'w') as file:
            # Enter name of MODFLOW6 binary budget output file:
            file.write(f'{budget_filepath}\n')
            #  Enter maximum number of output times featured in this file:
            file.write(f'{self._max_output_times()}\n')
            # Enter name of bore-to-budget file:
            file.write(self._file_name_line_from_extension('.b2b'))
            # Enter simulation starting date [mm/dd/yyyy]:
            file.write(f'{self._zero_date_string()}\n')  # 01/01/1950 or Tdis START_DATE_TIME
            #  Enter simulation starting time [hh:mm:ss]:
            file.write('0:0:0\n')
            #  Enter time units employed by model [y/d/h/m/s]:
            file.write(f'{self._time_unit_string()}\n')
            # Enter name for bore sample output file:
            file.write(self._dot_samp_file_name_line())
            # Enter flow rate factor:
            file.write('1.0\n')
            # Assign flows to beginning, middle or finish of time step?  [b/m/f]:
            file.write('f\n')
            #  Enter name for run record file:
            file.write('temp.rec\n')

    def _write_smp2smp_in_file(self, batch_filename_base):
        """Writes the smp2smp.in file."""
        in_path = os.path.join(self._pest_sub_dir, f'{batch_filename_base}.in')
        with open(in_path, 'w') as file:
            # Enter name of observation bore sample file:
            file.write(self._file_name_line_from_extension('.fsamp'))
            # Enter name of model-generated bore sample file:
            file.write(self._dot_samp_file_name_line())  # .samp output file created by mf6bud2smp.bat
            # Enter extrapolation threshold in days (fractional if necessary):
            file.write('1e20\n')
            # Enter name for new bore sample file:
            file.write(f'{os.path.basename(self._files.get(".fsamp", ""))}.out\n')

    def _write_package(self, data):
        """Writes the package file.

        When writing the XMS component (dmi_sim_dir is set) only an empty component main file (.pobs) is written.
        Otherwise (native text export) the PEST input files are copied to the '<model>_pest' dir and the PEST
        batch/.in/settings files are written there.
        """
        self._data = data
        if self._writer_options.dmi_sim_dir:  # Only write component file when not exporting native text
            with open(self._data.filename, 'w') as _:
                pass  # There's currently nothing in the component main file (.pobs)
        else:
            self._copy_pest_files()
            self._write_pest_batch_files()
