"""A writer for STWAVE simulation files."""

# 1. Standard Python modules
import datetime
from io import StringIO
import logging
import shutil

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util
from xms.constraint import read_grid_from_file
from xms.data_objects.parameters import FilterLocation
from xms.guipy.time_format import ISO_DATETIME_FORMAT
from xms.snap.snap_point import SnapPoint

# 4. Local modules
from xms.stwave.data import stwave_consts as const
from xms.stwave.data.simulation_data import SimulationData
import xms.stwave.file_io.dataset_writer


# Adapted from the xml sim file written by ldabell
class SimWriter:
    """A class for writing out an STWAVE simulation (.sim) file.

    The file contents are accumulated section by section in an in-memory buffer (``self.ss``) and flushed to disk at
    the end of :meth:`write`. Some values computed along the way are also stored in ``self.xms_data`` so that callers
    (e.g. CSTORM export) can reuse them.
    """

    def __init__(self, query=None, filename='', xms_data=None, template=False):
        """Constructor.

        Args:
            query (Optional[Query]): The XMS interprocess communicator. If not provided, implies a simulation export.
                Should be provided if called from the partial export dialog.
            filename (Optional[str]): The path to the filename to export. If not provided (simulation export), will be
                the simulation name + .sim
            xms_data (dict or None): dictionary of data from xms. Keys read here: 'query', 'sim_export', 'logger',
                'sim_uuid', 'sim_comp', 'cstorm_export'. Keys written by this writer: 'stwprocs', 'grid_prj',
                'propagation', 'st_times'.
            template (bool): If True will export as a CSTORM template
        """
        self.ss = StringIO()  # Buffer for the file contents; flushed to disk at the end of write()
        self.simulation_name = ''
        self.grid_angle = 0.0  # initialized in write(), don't use until after the grid has been read
        self.case_times = []
        self.filename = filename
        self.template = template
        self.xms_data = xms_data if xms_data is not None else {}
        if query:  # Partial export dialog
            self.query = query
            self.sim_export = False
        else:  # Exporting the entire simulation
            # Only construct a new Query when one was not supplied (dict.get would build it unconditionally).
            self.query = self.xms_data['query'] if 'query' in self.xms_data else Query()
            self.sim_export = self.xms_data.get('sim_export', True)
        self.data = None
        self.grid = None
        self.proj = None
        self.snapper = SnapPoint()
        self.logger = self.xms_data.get('logger', logging.getLogger('xms.stwave'))

    @staticmethod
    def side_text_to_value(cbx_text):
        """Grab the model value based on the text on the Side combo boxes.

        Args:
            cbx_text (str): The side boundary condition option text.

        Returns:
            (int): 2 for specified, 3 for lateral, 0 for anything else.
        """
        val = 0
        if cbx_text == const.I_BC_SPECIFIED:
            val = 2
        elif cbx_text == const.I_BC_LATERAL:
            val = 3
        return val

    @staticmethod
    def get_friction_value(cbx_text):
        """Grab the model value based on the text of the Friction combo box.

        Args:
            cbx_text (str): The friction option text.

        Returns:
            (int): 1/2 for JONSWAP constant/dataset, 3/4 for Manning constant/dataset, 0 for anything else.
        """
        friction = 0
        if cbx_text == const.FRIC_OPT_JONSWAP_CONST:
            friction = 1
        elif cbx_text == const.FRIC_OPT_JONSWAP_DSET:
            friction = 2
        elif cbx_text == const.FRIC_OPT_MANNING_CONST:
            friction = 3
        elif cbx_text == const.FRIC_OPT_MANNING_DSET:
            friction = 4
        return friction

    def _count_disjoint_points(self, cov_uuid):
        """Count the disjoint points in the coverage with the given UUID.

        Best effort: a missing coverage (None) or one whose points cannot be read counts as 0 points.

        Args:
            cov_uuid (str): UUID of the point coverage.

        Returns:
            (int): The number of disjoint points, or 0 on failure.
        """
        cov = self.query.item_with_uuid(cov_uuid)
        try:
            return len(cov.get_points(FilterLocation.PT_LOC_DISJOINT))
        except Exception:  # Missing/unreadable coverage - deliberately treated as having no points
            return 0

    def write_standard_input(self):
        """Write the standard input (&std_parms) section with its header.

        Returns:
            (int): Plane mode - 1 for full-plane, 0 for half-plane.
        """
        attrs = self.data.info.attrs
        self.ss.write("#\n# Standard Input Section\n#\n&std_parms\n")
        plane_mode = 1 if attrs["plane"] == const.PLANE_TYPE_FULL else 0
        propagation = 1 if attrs['source_terms'] == const.SOURCE_PROP_ONLY else 0
        self.ss.write(f"  iplane = {plane_mode},\n  iprp = {propagation},\n")
        # Current interaction is forced off (0) in full-plane mode
        i_cur = 0 if attrs["current_interaction"] == const.OPT_NONE or plane_mode == 1 else 1
        breaking_type = 0
        if attrs['breaking_type'] == const.BREAK_OPT_WRITE:
            breaking_type = 1
        elif attrs['breaking_type'] == const.BREAK_OPT_CALCULATE:
            breaking_type = 2
        self.ss.write(f"  icur = {i_cur},\n  ibreak = {breaking_type},\n  irs = {attrs['rad_stress']},\n")

        # Number of monitor, nesting, and output station points
        monitor_pts = self._count_disjoint_points(attrs['monitoring_uuid'])
        nest_pts = self._count_disjoint_points(attrs['nesting_uuid'])
        out_pts = 0
        if self.template:
            out_pts = '%NSTATIONS%'  # placeholder token for the CSTORM template
        elif attrs['output_stations'] == 1:
            out_pts = self._count_disjoint_points(attrs['output_stations_uuid'])

        # ibnd: 0 for single point (no interpolation), 1 for linear interpolation, 2 for morphic interpolation
        spec_cov = self.query.item_with_uuid(attrs['spectral_uuid'], generic_coverage=True)
        try:
            ibnd = 0
            if len(spec_cov.m_points) > 1:
                interp = attrs['interpolation']
                if interp == const.INTERP_OPT_LINEAR:
                    ibnd = 1
                elif interp == const.INTERP_OPT_MORPHIC:
                    ibnd = 2
        except Exception:  # Missing/unreadable spectral coverage - no interpolation
            ibnd = 0
        self.ss.write(f"  nselct = {monitor_pts},\n  nnest = {nest_pts},\n  nstations = {out_pts},\n  ibnd = {ibnd},\n")

        ifric = self.get_friction_value(attrs['friction'])
        depth = 0
        if attrs['depth'] == const.DEP_OPT_TRANSIENT:
            depth = 2
        elif attrs['depth'] == const.DEP_OPT_COUPLED:
            depth = 102
        surge = 0 if attrs['surge'] == const.OPT_CONST else 1
        self.ss.write(f"  ifric = {ifric},\n  idep_opt = {depth},\n  isurge = {surge},\n")
        i_wind = 0 if attrs['wind'] == const.OPT_CONST else 1
        i_ice = 0 if attrs['ice'] == const.OPT_NONE else 1
        if attrs['boundary_source'] == const.OPT_NONE:
            bc1 = bc2 = bc3 = bc4 = 0
        else:
            bc1 = self.side_text_to_value(attrs["side1"])
            bc2 = self.side_text_to_value(attrs["side2"])
            bc3 = self.side_text_to_value(attrs["side3"])
            bc4 = self.side_text_to_value(attrs["side4"])
        self.ss.write(f"  iwind = {i_wind},\n  i_bc1 = {bc1},\n  i_bc2 = {bc2},\n  i_bc3 = {bc3},\n  i_bc4 = {bc4}"
                      f",\n  iice = {i_ice},\n")
        if i_ice == 1:  # only write threshold if using ice
            self.ss.write(f"  percent_ice_threshold = {attrs['ice_threshold']:g},\n")

        # Number of sub-sample points (only written when the location coverage option is on and it has points)
        if attrs["location_coverage"] == 1:
            subset_pts = self._count_disjoint_points(attrs['location_coverage_uuid'])
            if subset_pts != 0:
                self.ss.write(f"  n_subsample_eng_pts = {subset_pts}\n")
        self.ss.write("/\n")

        return plane_mode

    def write_runtime_section(self, a_plane_mode):
        """Write the runtime parameters (&run_parms) section with its header.

        Also stores the total processor count in ``self.xms_data['stwprocs']``.

        Args:
            a_plane_mode (int): Plane mode - 1 for full-plane, 0 for half-plane.
        """
        attrs = self.data.info.attrs
        if a_plane_mode == 1:
            n_init_iters = attrs["max_init_iters"]
            init_iters_stop = attrs["init_iters_stop_value"]
            init_stop_per = attrs["init_iters_stop_percent"]
            n_final_iters = attrs["max_final_iters"]
            final_iters_stop = attrs["final_iters_stop_value"]
            final_stop_per = attrs["final_iters_stop_percent"]
        else:  # not used for half-plane mode, write out dummies
            n_init_iters = init_iters_stop = init_stop_per = 0
            n_final_iters = final_iters_stop = final_stop_per = 0
        n_grd_part_i = attrs['processors_i'] if a_plane_mode == 1 else 1  # Always 1 if half-plane
        n_grd_part_j = attrs['processors_j']
        num_steps = '%NUMSTEPS%' if self.template else f'{len(self.case_times)}'
        self.ss.write(f"#\n# Runtime Parameters Section\n#\n&run_parms\n  idd_spec_type = -2,\n  "
                      f"numsteps = {num_steps},\n  "
                      f"n_grd_part_i = {n_grd_part_i},\n  "
                      f"n_grd_part_j = {n_grd_part_j},\n  n_init_iters = {n_init_iters},\n  "
                      f"init_iters_stop_value = {init_iters_stop:g},\n  "
                      f"init_iters_stop_percent = {init_stop_per:g},\n  "
                      f"n_final_iters = {n_final_iters},\n  final_iters_stop_value = {final_iters_stop:g},\n  "
                      f"final_iters_stop_percent = {final_stop_per:g}\n/\n"
                      )
        self.xms_data['stwprocs'] = n_grd_part_i * n_grd_part_j

    def write_grid_section(self):
        """Write the spatial grid parameters (&spatial_grid_parms) section with its header."""
        proj_sys = self.proj.coordinate_system
        if proj_sys in ['UTM', 'STATEPLANE']:
            # If using a UTM or State-plane projection, include a card for the projection zone.
            proj_sys = f'"{proj_sys}",\n  SPZONE = {self.proj.coordinate_zone}'
        else:
            # Switch no projection text to LOCAL card for STWAVE.
            proj_sys = '"LOCAL"'
        i_sizes = self.grid.locations_x  # These are offsets from the origin
        j_sizes = self.grid.locations_y
        # Only the first spacing is used for dx/dy; presumably the grid is uniform - confirm
        fixed_i_size = i_sizes[1] - i_sizes[0] if len(i_sizes) > 1 else 0.0
        fixed_j_size = j_sizes[1] - j_sizes[0] if len(j_sizes) > 1 else 0.0
        origin = self.grid.origin
        self.ss.write(
            f"#\n# Spatial Grid Parameters Section\n#\n&spatial_grid_parms\n  coord_sys = {proj_sys},\n  "
            f"x0 = {origin[0]},\n  y0 = {origin[1]},\n  azimuth = {self.grid_angle},"
            f"\n  dx = {fixed_i_size},\n  dy = {fixed_j_size},\n  n_cell_i = {len(i_sizes) - 1},"
            f"\n  n_cell_j = {len(j_sizes) - 1}\n/\n"
        )

    def write_input_files_section(self, plane_mode):
        """Write the input files (&input_files) section with its header.

        Also stores the propagation flag in ``self.xms_data['propagation']`` for CSTORM.

        Args:
            plane_mode (int): Plane mode - 1 for full-plane, 0 for half-plane.
        """
        attrs = self.data.info.attrs
        self.ss.write(f'#\n#\n# Input Files Section\n#\n&input_files\n  DEP = "{self.simulation_name}.dep.in",\n')
        io_types = "  io_type_dep = 1"
        cstorm = self.xms_data.get('cstorm_export', False)
        if attrs["surge"] == const.OPT_DSET or cstorm:
            self.ss.write(f'  SURGE = "{self.simulation_name}.surge.in",\n')
            io_types += ",\n  io_type_surge = 1"
        if attrs["boundary_source"] == const.SPEC_OPT_COV:
            self.ss.write(f'  SPEC = "{self.simulation_name}.eng",\n')
            io_types += ",\n  io_type_spec = 1"
        propagation = 1 if attrs['source_terms'] == const.SOURCE_PROP_ONLY else 0
        self.xms_data['propagation'] = propagation  # for CSTORM
        # Use the same constant as write_standard_input so iwind and the WIND file card always agree.
        wind = 0 if attrs['wind'] == const.OPT_CONST else 1
        if propagation == 0 and (wind == 1 or cstorm):
            self.ss.write(f'  WIND = "{self.simulation_name}.wind.in",\n')
            io_types += ",\n  io_type_wind = 1"
        ifric = self.get_friction_value(attrs['friction'])
        if ifric in (2, 4):  # friction from a dataset (JONSWAP or Manning)
            self.ss.write(f'  FRIC = "{self.simulation_name}.fric.in",\n')
            io_types += ",\n  io_type_fric = 1"
        i_cur = 0 if attrs["current_interaction"] == const.OPT_NONE else 1
        if plane_mode == 0 and i_cur == 1:
            self.ss.write(f'  CURR = "{self.simulation_name}.curr.in",\n')
            io_types += ",\n  io_type_curr = 1"
        ice = 0 if attrs['ice'] == const.OPT_NONE else 1
        if ice == 1:
            self.ss.write(f'  ICE = "{self.simulation_name}.ice.in",\n')
            io_types += ",\n  io_type_ice = 1"
        io_types += "\n/\n"
        self.ss.write(io_types)

    def write_output_files_section(self):
        """Write the output files (&output_files) section with its header."""
        self.ss.write(f'#\n#\n# Output Files Section\n#\n&output_files\n  '
                      f'WAVE = "{self.simulation_name}.wave.out",\n  '
                      f'OBSE = "{self.simulation_name}.obse.out",\n  '
                      f'BREAK = "{self.simulation_name}.break.out",\n  '
                      f'RADS = "{self.simulation_name}.rads.out",\n  '
                      f'SELH = "{self.simulation_name}.selh.out",\n  '
                      f'STATION = "{self.simulation_name}.station.out",\n  '
                      f'NEST = "{self.simulation_name}.nest.out",\n  '
                      f'LOGS = "{self.simulation_name}.log.out",\n  '
                      f'TP = "{self.simulation_name}.Tp.out",\n  '
                      f'XMDF_SPATIAL = "{self.simulation_name}.spatial.out.h5",\n  '
                      f'io_type_tp = 1,\n  io_type_nest = 1,\n  '
                      f'io_type_selh = 1,\n  io_type_rads = 1,\n'
                      )
        # The break file type is 0 only when breaking output is turned off
        io_type_break = 0 if self.data.info.attrs["breaking_type"] == const.BREAK_OPT_NONE else 1
        self.ss.write(f"  io_type_break = {io_type_break},\n  io_type_obse = 1,\n  io_type_wave = 1,\n"
                      f"  io_type_station = 1\n/\n")

    def _write_snapped_cell_indices(self, cov_uuid, i_card, j_card):
        """Snap a coverage's disjoint points to grid cells and write their 1-based (i, j) card pairs.

        Nothing is written if the coverage is not found.

        Args:
            cov_uuid (str): UUID of the point coverage.
            i_card (str): Name of the i-index card (e.g. 'iout').
            j_card (str): Name of the j-index card (e.g. 'jout').
        """
        cov = self.query.item_with_uuid(cov_uuid)
        if not cov:
            return
        snapped_points = self.snapper.get_snapped_points(cov.get_points(FilterLocation.PT_LOC_DISJOINT))
        for pt_idx, cell_idx in enumerate(snapped_points['id'], start=1):
            i, j = self.grid.get_cell_ij_from_index(cell_idx)
            self.ss.write(f"  {i_card}({pt_idx}) = {i}, {j_card}({pt_idx}) = {j},\n")

    def write_points_sections(self):
        """Write the selected (monitor), nest, and station point sections with their headers."""
        self.snapper.set_grid(self.grid, True)

        # snapped monitor points
        self.ss.write("#\n# Selected Points Section\n#\n@select_pts\n")
        self._write_snapped_cell_indices(self.data.info.attrs['monitoring_uuid'], 'iout', 'jout')
        self.ss.write("/\n")

        # snapped nesting points
        self.ss.write("#\n# Nest Points Section\n#\n@nest_pts\n")
        self._write_snapped_cell_indices(self.data.info.attrs['nesting_uuid'], 'inest', 'jnest')
        self.ss.write("/\n")

        # station points (raw coordinates, not snapped)
        self.ss.write("#\n# Station Points Section\n#\n@station_locations\n")
        # NOTE(review): stations are written whenever the coverage exists, even if output_stations is off
        # (nstations = 0 in &std_parms) - confirm STWAVE ignores them in that case.
        cov = self.query.item_with_uuid(self.data.info.attrs['output_stations_uuid'])
        if cov:
            points = cov.get_points(FilterLocation.PT_LOC_DISJOINT)
            for idx, pt in enumerate(points, start=1):
                self.ss.write(f"  stat_xcoor({idx}) = {pt.x}, stat_ycoor({idx}) = {pt.y},\n")
        self.ss.write("/\n")

    def write_const_wind_section(self):
        """Write the spatially constant winds (@const_wind) section with its header.

        Wind directions are converted to the shore normal convention using the grid angle. Dummy zero winds are
        written for every case time when only propagation is modeled.
        """
        self.ss.write("#\n# Spatially Constant Winds Section\n#\n@const_wind\n")
        attrs = self.data.info.attrs
        if attrs["source_terms"] == const.SOURCE_PROP_ONLY:  # fill in dummy values if prop only
            for idx in range(1, len(self.case_times) + 1):
                self.ss.write(f"  umag_const_in({idx}) = 0.0, udir_const_in({idx}) = 0.0,\n")
        elif attrs["wind"] == const.OPT_CONST:
            ang_conv = attrs['angle_convention']
            wind_dirs = self.data.case_times['Wind Direction']
            wind_mags = self.data.case_times['Wind Magnitude']
            dataset_writer = xms.stwave.file_io.dataset_writer
            for idx, (mag, wdir) in enumerate(zip(wind_mags, wind_dirs), start=1):
                if ang_conv == const.ANG_CONV_CARTESIAN:
                    shore_normal_ang = dataset_writer.cart_to_shore_normal(wdir, self.grid_angle)
                elif ang_conv == const.ANG_CONV_METEOROLOGIC:
                    shore_normal_ang = dataset_writer.meteor_to_shore_normal(wdir, self.grid_angle)
                elif ang_conv == const.ANG_CONV_OCEANOGRAPHIC:
                    shore_normal_ang = dataset_writer.ocean_to_shore_normal(wdir, self.grid_angle)
                else:  # already in shore normal; written as-is (no -180..180 normalization)
                    shore_normal_ang = wdir
                self.ss.write(f"  umag_const_in({idx}) = {float(mag)}, udir_const_in({idx}) = "
                              f"{float(shore_normal_ang)},\n")
        self.ss.write("/\n")

    def write_const_surge_section(self):
        """Write the spatial water level adjustment (@const_surge) section; levels only for constant surge."""
        self.ss.write("#\n# Spatial Water Level Adjustment Section\n#\n@const_surge\n")
        if self.data.info.attrs["surge"] == const.OPT_CONST:
            levels = self.data.case_times['Water Level']
            for idx, level in enumerate(levels, start=1):
                self.ss.write(f"  dadd_const_in({idx}) = {float(level)},\n")
        self.ss.write("/\n")

    def write_subsample_locations(self):
        """Write the sub-sample spectral ENG boundary point coordinates (@subsample_speceng) with its header."""
        self.ss.write("#\n# Coordinate Locations of Sub-Sample Spectral ENG Boundary Points\n#\n@subsample_speceng\n")
        # NOTE(review): points are written whenever the coverage exists, even if location_coverage is off
        # (n_subsample_eng_pts is then omitted from &std_parms) - confirm this is intended.
        cov = self.query.item_with_uuid(self.data.info.attrs['location_coverage_uuid'])
        if cov:
            sub_pts = cov.get_points(FilterLocation.PT_LOC_DISJOINT)
            for idx, pt in enumerate(sub_pts, start=1):
                self.ss.write(f"  xloc_subsample({idx}) = {pt.x}, yloc_subsample({idx}) = {pt.y},\n")
        self.ss.write("/\n")

    def _write_const_spec_section(self):
        """Write the constant boundary spectra (&const_spec) section; frequencies only when there is no source."""
        self.ss.write("#\n# Constant Boundary Spectra Section\n#\n&const_spec\n")
        attrs = self.data.info.attrs
        if attrs["boundary_source"] == const.OPT_NONE:
            self.ss.write(
                f"  nfreq = {attrs['num_frequencies']},\n  na = 72,\n  "
                f"f0 = {attrs['min_frequency']},\n  "
                f"df_const = {attrs['delta_frequency']}\n/\n"
            )
        else:
            self.ss.write("/\n")

    def _write_const_fric_section(self):
        """Write the constant bottom friction (&const_fric) section; only for the constant friction options."""
        attrs = self.data.info.attrs
        friction_option = attrs["friction"]
        if friction_option == const.FRIC_OPT_JONSWAP_CONST:
            const_fric = attrs["JONSWAP"]
        elif friction_option == const.FRIC_OPT_MANNING_CONST:
            const_fric = attrs["manning"]
        else:
            return  # section is omitted for dataset/no friction
        self.ss.write(f"#\n# Constant Bottom Friction Section\n#\n&const_fric\n  cf_const = {const_fric},\n/\n")

    def _write_snap_idds_section(self):
        """Write the snap IDDs (@snap_idds) section: one date/time per case time, offset from the reference time.

        The written lines are also stored in ``self.xms_data['st_times']`` for CSTORM.

        Raises:
            OverflowError: If a case time offset from the reference time is out of datetime range.
        """
        self.ss.write("#\n# Snap IDD's Section\n#\n@snap_idds\n")
        st_times = []
        self.xms_data['st_times'] = st_times  # used by CSTORM
        reftime = datetime.datetime.strptime(self.data.info.attrs['reftime'], ISO_DATETIME_FORMAT)
        for idx, offset_seconds in enumerate(self.case_times, start=1):
            try:
                dt = reftime + datetime.timedelta(seconds=offset_seconds)
            except OverflowError as err:
                raise OverflowError('Error converting case times to date time value.  Check case time values.') \
                    from err
            out_str = f"  idds({idx}) = {xms.stwave.file_io.dataset_writer.get_time_string(dt)},\n"
            self.ss.write(out_str)
            st_times.append(out_str)
        self.ss.write("/\n")

    def write(self):
        """Write the complete simulation file.

        Returns:
            The grid data object retrieved from XMS, or None if the simulation or its grid could not be found (in
            which case no file is written).
        """
        # Get the simulation tree item
        self.logger.info('Retrieving simulation data from SMS...')
        if self.sim_export:  # Exporting the entire simulation, Context is at simulation level
            # Only query XMS when the caller did not supply the values.
            sim_uuid = self.xms_data['sim_uuid'] if 'sim_uuid' in self.xms_data else self.query.current_item_uuid()
            sim_comp = self.xms_data.get('sim_comp', None)
            if sim_comp is None:
                sim_comp = self.query.item_with_uuid(sim_uuid, model_name='STWAVE', unique_name='Sim_Component')
        else:  # Partial export dialog, Context is at simulation component level
            sim_uuid = self.query.parent_item_uuid()
            sim_comp = self.query.current_item()
        sim_item = tree_util.find_tree_node_by_uuid(self.query.project_tree, sim_uuid)
        if sim_item is None:
            self.logger.error('Unable to find STWAVE simulation.')
            return None
        # Get the simulation's hidden component data.
        self.data = SimulationData(sim_comp.main_file)
        # Get the grid
        self.logger.info('Retrieving domain grid from SMS...')
        grid_uuid = self.data.info.attrs['grid_uuid']
        do_grid = self.query.item_with_uuid(grid_uuid)
        if do_grid is None:
            self.logger.error('Unable to find STWAVE grid.')
            return None
        self.grid = read_grid_from_file(do_grid.cogrid_file)
        self.grid_angle = self.grid.angle  # store grid angle so it can be used to translate angles
        self.proj = do_grid.native_projection
        self.xms_data['grid_prj'] = self.proj

        # get simulation name - will need it to build filenames
        self.simulation_name = sim_item.name

        # get case times - will need them in various places
        self.case_times = self.data.times_in_seconds()

        # file header
        self.ss.write("# STWAVE_SIM_FILE\n# Written from SMS v13.2\n#\n"
                      "##############################################\n")

        self.logger.info('Writing standard input section...')
        plane_mode = self.write_standard_input()

        self.logger.info('Writing runtime section...')
        self.write_runtime_section(plane_mode)

        self.logger.info('Writing grid parameters section...')
        self.write_grid_section()

        self.logger.info('Writing input files section...')
        self.write_input_files_section(plane_mode)

        self.logger.info('Writing output files section...')
        self.write_output_files_section()

        self.logger.info('Writing time section...')
        self.ss.write('#\n# Time Parameter Section\n#\n&time_parms\n  i_time_inc_units = "mm"\n/\n')

        self.logger.info('Writing boundary spectra section...')
        self._write_const_spec_section()

        # analytic depth profile section - not used by sms
        self.logger.info('Writing analytic depth profile section...')
        self.ss.write("#\n# Analytic Depth Profile Section\n#\n&depth_func\n/\n")

        self.logger.info('Writing bottom friction section...')
        self._write_const_fric_section()

        self.logger.info('Writing bottom snap IDDs section...')
        self._write_snap_idds_section()

        self.logger.info('Writing monitor, nest, and station point section...')
        self.write_points_sections()

        self.logger.info('Writing wind section...')
        self.write_const_wind_section()

        self.logger.info('Writing surge section...')
        self.write_const_surge_section()

        self.logger.info('Writing sub-sample section...')
        self.write_subsample_locations()

        # TMA boundary spectra section - not used by sms
        self.logger.info('Writing TMA boundary spectra section...')
        self.ss.write("#\n# TMA Boundary Spectra Section\n#\n@const_tma_spec\n/\n")

        # flush to file; the context manager closes the handle even if the copy fails
        filename = self.filename if self.filename else f'{self.simulation_name}.sim'
        with open(filename, 'w') as out:
            self.ss.seek(0)
            shutil.copyfileobj(self.ss, out)

        return do_grid
