"""FUNWAVE control file writer."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
import os
from pathlib import Path

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint import RectilinearGrid2d
from xms.data_objects.parameters import FilterLocation
from xms.datasets.dataset_reader import DatasetReader
from xms.snap import SnapPoint

# 4. Local modules
from xms.funwave.dmi.xms_data import XmsData
from xms.funwave.file_io.input_txt_writer import write_input_txt_file


class ControlWriter:
    """Writer class for the FUNWAVE control file."""

    def __init__(self, xms_data: XmsData):
        """Constructor.

        Args:
            xms_data (:obj:`XmsData`): Simulation data retrieved from SMS
        """
        self._xms_data = xms_data

    def write(self):
        """Write the FUNWAVE control file and its supporting input files.

        Everything is written relative to the current working directory: an ``output`` folder is
        created, ``input.funwave`` is written, and -- only when the grid is a ``RectilinearGrid2d``
        whose cells are all rectangular -- the output stations file, ``input/depth_file.txt`` and
        one text file per simulation dataset under ``input`` are written as well.
        """
        # Loggers must be obtained through getLogger(); instantiating logging.Logger directly creates
        # an orphan logger that bypasses any handlers configured for 'xms.funwave'.
        logger = logging.getLogger('xms.funwave')

        sim_folder = Path(os.getcwd())
        os.makedirs(sim_folder / 'output', exist_ok=True)
        values = self._xms_data.sim_data.values_to_dict()
        grid = self._xms_data.cogrid
        valid_grid = False
        if grid is not None:
            # Presumably returns (dx, dy) when every cell is a rectangle of the same size, else None.
            dx_dy = grid.check_all_cells_are_rectangular_in_xy()
            if dx_dy is not None:
                values['DX'] = dx_dy[0]
                values['DY'] = dx_dy[1]
                if isinstance(grid, RectilinearGrid2d):
                    # Cell counts are one less than the number of grid lines in each direction.
                    values['Mglob'] = len(grid.locations_x) - 1
                    values['Nglob'] = len(grid.locations_y) - 1
                    valid_grid = True

        values['NumberStations'] = 0
        if valid_grid:
            n = values['Nglob']  # y and J direction
            m = values['Mglob']  # x and I direction

            # Update values based on coverages and write station file
            coverages = self._xms_data.retrieve_output_station_data()
            if coverages:
                # Assume it to be a single Output Stations coverage
                output_stations_cov = coverages[0]
                filename = f'{output_stations_cov.name}.txt'
                # Make the coverage name safe to use as a filename.
                filename = filename.replace(':', '_')
                filename = filename.replace(' - ', '-')
                filename = filename.replace(' ', '_')
                filename = filename.replace('__', '_')
                values['STATIONS_FILE'] = filename
                points = output_stations_cov.get_points(FilterLocation.LOC_ALL)
                values['NumberStations'] = len(points)
                write_stations_file(sim_folder, filename, points, grid, m)

            write_elevation_dataset_file(sim_folder / 'input', 'depth_file', grid.cell_elevations, n, m)

            dataset_uuids = self._xms_data.sim_data.get_dataset_uuids()
            datasets = self.get_dataset_readers(dataset_uuids)
            write_dataset_files(sim_folder / 'input', datasets, n, m)
        else:  # Grid invalid
            logger.error('Simulation requires a valid 2D grid.')

        if values['OUTPUT_RES'] != 1:
            logger.warning('The resulting datasets will not be supported in SMS. Change the value of "Output data '
                           'resolution (OUTPUT_RES)" to be compatible with SMS.')

        write_input_txt_file(sim_folder / 'input.funwave', values, valid_grid)

    def get_dataset_readers(self, dataset_uuids) -> dict[str, DatasetReader]:
        """Get the datasets that need to be written.

        Entries whose UUID is falsy, or for which no reader can be obtained, are skipped.

        Args:
            dataset_uuids: Mapping of dataset name to dataset UUID (as returned by
                ``sim_data.get_dataset_uuids()``).

        Returns:
            (:obj:`dict`): Dictionary of lower-cased dataset name to DatasetReader.
        """
        datasets = {}
        for dataset_name in dataset_uuids:
            dataset_uuid = dataset_uuids[dataset_name]
            if not dataset_uuid:
                continue
            dataset_reader = self._xms_data.get_dataset_reader(dataset_uuid)
            if dataset_reader is not None:
                # Lower-cased names are used as the output file stems.
                datasets[dataset_name.lower()] = dataset_reader
        return datasets


def write_stations_file(sim_folder, filename, points, co_grid, row_cell_count):
    """Write the output stations to a txt file.

    Each point is snapped to a grid cell. One line per point is written, containing the 1-based
    ``I`` (x) and ``J`` (y) indices of that cell. Points that do not snap to a cell (snapped id < 0)
    are written as ``-1   -1`` and a warning is logged.

    Args:
        sim_folder (:obj:`Path`): path to the simulation write folder
        filename (:obj:`string`): Filename of the file to create
        points (:obj:`list of Points`): points from the output stations coverage
        co_grid (:obj:`Grid`): The grid
        row_cell_count (:obj:`int`): The number of cells in a row
    """
    snapper = SnapPoint()
    snapper.set_grid(co_grid, True)
    cells = snapper.get_snapped_points(points)

    logger = logging.getLogger('xms.funwave')
    # The context manager guarantees the file is closed even if a write fails part-way through.
    with open(sim_folder / filename, 'w', encoding='utf-8') as file:
        for index, cell in enumerate(cells['id']):
            if cell >= 0:
                # Cell ids are laid out as j * row_cell_count + i (0-based); FUNWAVE wants 1-based I, J.
                j_y, i_x = divmod(cell, row_cell_count)
                i_x += 1
                j_y += 1
            else:
                i_x = -1
                j_y = -1
                logger.warning(f'Station point {index} is not within the grid bounds.')
            file.write(f'{i_x}   {j_y}\n')


def write_dataset_files(output_folder: Path, datasets: dict[str, DatasetReader], n: int, m: int):
    """Save each dataset as an ``n`` x ``m`` whitespace-delimited text file.

    The output folder is created if needed. Each dataset is written to ``<name>.txt`` inside it.

    Args:
        output_folder(:obj:`Path`): The output folder.
        datasets(:obj:`dict[str, DatasetReader]`): Dictionary of dataset names to DatasetReader.
        n(:obj:`int`): The number of grid rows.
        m(:obj:`int`): The number of grid columns.
    """
    os.makedirs(output_folder, exist_ok=True)
    grid_shape = (n, m)
    for dataset_name in datasets:
        reader = datasets[dataset_name]
        # NOTE(review): all of ``values`` is read; the reshape only succeeds when it holds exactly
        # n * m entries (e.g. a single time step) -- confirm for multi-step datasets.
        grid_values = np.array(reader.values[:]).reshape(grid_shape)
        target = output_folder / f'{dataset_name}.txt'
        np.savetxt(str(target), grid_values)


def write_elevation_dataset_file(output_folder: Path, name: str, dataset_values, n: int, m: int):
    """Write grid elevations, sign-negated, to a single ``n`` x ``m`` text file.

    The output folder is always created. Nothing is written when ``dataset_values`` is ``None`` or
    empty. Otherwise the values are negated (presumably converting elevation, positive up, to
    FUNWAVE depth, positive down), reshaped to ``(n, m)`` and saved to ``<name>.txt``.

    Args:
        output_folder(:obj:`Path`): The output folder.
        name (str): The name of the dataset (used as the file stem).
        dataset_values(:obj:`list`): The elevation values; must contain ``n * m`` entries when non-empty.
        n(:obj:`int`): The number of grid rows.
        m(:obj:`int`): The number of grid columns.
    """
    os.makedirs(output_folder, exist_ok=True)

    # np.array(None) is a 0-d array on which len() raises, so treat None as "no data" explicitly.
    if dataset_values is None:
        return
    data = np.array(dataset_values)
    if data.size == 0:
        return

    # Negation allocates a new array, so the caller's values are never modified.
    depths = (-data).reshape((n, m))
    np.savetxt(str(output_folder / f'{name}.txt'), depths)
