"""ProjectWriter class."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
from datetime import datetime
import logging
from pathlib import Path

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import TreeNode
from xms.coverage.xy import xy_util
from xms.coverage.xy.xy_series import XySeries
from xms.coverage.xy.xy_util import XySeriesDict
from xms.gmi.data.generic_model import Group

# 4. Local modules
from xms.gssha.components import dmi_util
from xms.gssha.components.sim_component import SimComponent
from xms.gssha.data import bc_generic_model, bc_util, data_util, sim_generic_model
from xms.gssha.data.bc_generic_model import olf_type_params
from xms.gssha.data.bc_util import BcData
from xms.gssha.data.sim_generic_model import InfilType, PrecipitationType, Routing
from xms.gssha.file_io import (
    cif_file_writer, cmt_file_writer, gag_file_writer, gst_file_writer, ihg_file_writer, ihl_file_writer, io_util,
    smt_file_writer, xys_file_writer
)
from xms.gssha.file_io.block import Block

# Constants
_ALIGN = 25  # Width cards are padded to, so values start at column 26 (e.g. 'CARD                     value')


def write(main_file: str | Path, query: Query, out_dir: Path, sim_node: TreeNode) -> Path | None:
    """Create a ProjectWriter for the simulation and run it.

    This writes the .gssha project file together with the other files it references.

    Args:
        main_file (str | Path): Simulation component main file.
        query (Query): Object for communicating with XMS.
        out_dir (Path): Path to the output directory.
        sim_node (TreeNode): Simulation tree node.

    Returns:
        Path | None: The value returned by ProjectWriter.write() (the .gssha file path).
    """
    return ProjectWriter(main_file, query, out_dir, sim_node).write()


class ProjectWriter:
    """Writes the .gssha project file and all the files it references for a GSSHA simulation."""
    def __init__(self, main_file: str | Path, query: Query, out_dir: Path, sim_node: TreeNode) -> None:
        """Initializes the class.

        Args:
            main_file (str | Path): Simulation component main file.
            query (Query): Object for communicating with XMS
            out_dir (Path): Path to the output directory.
            sim_node (TreeNode): Simulation tree node.
        """
        super().__init__()
        self._main_file_path: Path = Path(main_file)
        self._query = query
        self._out_dir = out_dir
        self._sim_node = sim_node

        self._log = logging.getLogger('xms.gssha')
        self._gssha_file_path: Path = Path()  # Placeholder until _set_gssha_filename() runs
        self._file = None  # .gssha file object (assigned, and closed, by write())
        self._sim_comp = None  # SimComponent
        self._gm = None  # Generic model
        self._align = _ALIGN  # Same value as _ALIGN; the Block writers in this class use _ALIGN directly
        self._co_grid = None
        self._ugrid = None  # Result of 'co_grid.ugrid' so we do it only once, because it's costly
        self._elevation_dataset = None
        self._bc_cov = None
        self._bc_comp = None
        self._stream_data: BcData | None = None  # Data on the stream bcs
        self._olf_data: BcData | None = None  # Data on the overland flow bcs
        self._xy_dict: XySeriesDict = {}  # With "new" xy series ids (after combining to get unique xy series)
        # Flags set by _write_channel_routing() and read by _write_output_general() to decide on optional output
        self._routing = False
        self._soil_erosion = False
        self._transport = False
        self._start_date_time = None  # Start date/time defined in Precipitation group

    def write(self) -> Path | None:
        """Writes the .gssha file and all the files it references.

        Any exception raised while gathering data or writing is caught and logged instead of being propagated.

        Returns:
            The path of the .gssha file, or None if an error occurred.
        """
        self._log.info('Writing .gssha file...')
        try:
            # Gather everything the section writers need before the .gssha file is opened
            self._set_gssha_filename()
            self._read_sim_component_data()
            self._read_grid()
            self._find_elevation_dataset()
            self._find_coverage()
            self._get_bc_data()
            self._compute_link_numbers()
            self._collect_xy_series()
            self._get_start_date_time()
            with open(self._gssha_file_path, 'w') as self._file:
                self._file.write('GSSHAPROJECT\n')
                self._write_general()
                self._write_overland_flow()
                self._write_infiltration()
                self._write_channel_routing()  # Must precede _write_output(): sets the routing/erosion/transport flags
                self._write_output()
                self._write_precipitation()
                self._log.info('Writing finished')

        except Exception as error:  # Top-level boundary: report through the log rather than crash XMS
            self._log.error(str(error))
            # Don't hand back a partially written file, or the placeholder Path(), as if writing had succeeded
            return None

        return self._gssha_file_path

    def _set_gssha_filename(self) -> None:
        """Sets the .gssha file path to '<out_dir>/<simulation name>.gssha'.

        Raises:
            RuntimeError: If there is no simulation tree node.
        """
        if not self._sim_node:
            raise RuntimeError('No simulation node found.')

        self._gssha_file_path = self._out_dir / f'{self._sim_node.name}.gssha'

    def _read_sim_component_data(self) -> None:
        """Loads the sim component and restores its global and model values into a fresh generic model."""
        self._sim_comp = SimComponent(str(self._main_file_path))
        self._gm = sim_generic_model.create(default_values=False)
        self._gm.global_parameters.restore_values(self._sim_comp.data.global_values)
        self._gm.model_parameters.restore_values(self._sim_comp.data.model_values)

    def _read_grid(self) -> None:
        """Reads the simulation's grid and caches its ugrid."""
        self._co_grid = dmi_util.read_co_grid(self._query, self._sim_node)
        self._ugrid = self._co_grid.ugrid  # We only want to do this once

    def _find_elevation_dataset(self) -> None:
        """Saves the ELEVATION dataset because we use it in a couple of places.

        Raises:
            RuntimeError: If the ELEVATION dataset is not set or cannot be found.
        """
        group = self._gm.global_parameters.group('overland_flow')
        self._elevation_dataset, _ = dmi_util.get_dataset(group, 'ELEVATION', self._query)
        if not self._elevation_dataset:
            raise RuntimeError('ELEVATION dataset not set, or could not be found')  # pragma no cover

    def _find_coverage(self) -> None:
        """Finds and stores the linked boundary condition coverage and its component."""
        self._bc_cov, self._bc_comp = dmi_util.find_bc_coverage_and_component(self._sim_node, self._query)

    def _get_bc_data(self) -> None:
        """Gets data about the stream and overland flow bcs that we will need later.

        Although this gets the current arc links data, arc links are recalculated later.
        Nothing is stored if there is no bc coverage or component.
        """
        if self._bc_cov is not None and self._bc_comp:
            self._stream_data = bc_util.get_bc_data(self._bc_cov, self._bc_comp, {'channel'})
            self._olf_data = bc_util.get_bc_data(self._bc_cov, self._bc_comp, {'overland_flow'})

    def _compute_link_numbers(self) -> None:
        """Computes the stream link numbers and saves them to the stream data and the bc component.

        Raises:
            RuntimeError: If link numbers could not be computed, or not every stream feature got one.
        """
        self._log.info('Computing link numbers...')
        if self._stream_data and self._stream_data.feature_bc and self._bc_comp:
            rv = bc_util.compute_link_numbers(self._stream_data)  # rv will be ArcLinks or an error string
            if isinstance(rv, str):
                raise RuntimeError(rv)
            if len(rv) != len(self._stream_data.feature_bc):
                raise RuntimeError('Link number assignment failed. Check that arcs are pointing downstream.')

            self._stream_data.arc_links = rv  # Add it to the stream data for convenience
            self._bc_comp.data.set_arc_links(rv)
            self._bc_comp.data.commit()

    def _collect_xy_series(self) -> None:
        """Collects the variable overland flow bc xy series into self._xy_dict for later use.

        Identical series are combined by xy_util.add_or_match, which also assigns each series its id.
        """
        if not self._olf_data:
            return

        self._log.info('Collecting XY series...')
        ts_id = 1  # Used to name the time series by number
        for group in self._olf_data.feature_bc.values():
            olf_type_param = group.parameter('overland_flow_type')
            # Only need to consider the variable types
            if olf_type_param.value in bc_generic_model.get_olf_types('variable'):
                param_name = olf_type_params[olf_type_param.value]
                x, y = group.parameter(param_name).value
                xy_name = f'ts_{ts_id}'  # This is how WMS names them
                ts_id += 1
                xy_series = XySeries(x, y, name=xy_name)
                xy_util.add_or_match(xy_series, self._xy_dict)  # This also sets the XySeries id

    def _get_start_date_time(self) -> None:
        """Parses the precipitation start date/time, which we need in multiple places.

        Raises:
            ValueError: If the stored 'start_date_time' is not an ISO 8601 string.
        """
        group = self._gm.global_parameters.group('precipitation')
        date_time_str = group.parameter('start_date_time').value
        self._start_date_time = datetime.fromisoformat(date_time_str)

    def _write_general(self) -> None:
        """Writes the General Simulation Data block and the watershed mask (.msk) file."""
        with Block(self._file, self._log, 'General Simulation Data', _ALIGN) as block:
            group = self._gm.global_parameters.group('general')
            watershed_mask_file_path = self._gssha_file_path.with_suffix('.msk')
            io_util.write_mask_file(self._co_grid, self._ugrid, watershed_mask_file_path)
            # Seems like we don't need to write the PROJECT_PATH card
            # block.write(group=None, name='PROJECT_PATH', value=f'"{str(self._gssha_file_path.parent)}"')
            block.write(group=None, name='WATERSHED_MASK', value=f'"{watershed_mask_file_path.name}"')
            block.write(group=None, name='NON_ORTHO_CHANNELS')
            block.write(group, 'TOT_TIME')
            block.write(group, 'TIMESTEP')
            block.write(group, 'OUTSLOPE')
            block.write(group=None, name='GRIDSIZE', value=self._cell_size())
            rows, columns = self._num_rows_columns()
            block.write(group=None, name='ROWS', value=rows)
            block.write(group=None, name='COLS', value=columns)
            outrow, outcol = self._get_outlet_row_col()
            block.write(group=None, name='OUTROW', value=str(outrow))
            block.write(group=None, name='OUTCOL', value=str(outcol))

    def _num_rows_columns(self) -> tuple[int, int]:
        """Returns the number of grid rows and columns.

        Each count is one less than the number of y/x locations (locations appear to be cell edges).
        """
        return len(self._co_grid.locations_y) - 1, len(self._co_grid.locations_x) - 1

    def _write_overland_flow(self) -> None:
        """Writes the Overland Flow block, the elevation (.ele) file, and any bc/mapping table files."""
        with Block(self._file, self._log, 'Overland Flow', _ALIGN) as block:
            group = self._gm.global_parameters.group('overland_flow')
            block.write(group, 'OVERTYPE')

            # The following are not standard GSSHA cards, but they allow us to serialize/restore these settings
            block.write(group, 'ROUGH_EXP')
            block.write(group, 'INTERCEPTION')
            block.write(group, 'RETENTION')
            block.write(group, 'AREA_REDUCTION')

            # ELEVATION
            elev_file_path = self._gssha_file_path.with_suffix('.ele')
            io_util.write_grass_file(self._co_grid, self._ugrid, self._elevation_dataset, elev_file_path, ints=False)
            block.write(group=None, name='ELEVATION', value=f'"{elev_file_path.name}"')

            # Overland flow bcs
            if self._overland_flow_bcs_exist():
                block.write(None, 'OV_BOUNDARY')
                self._write_time_series(block)  # TIME_SERIES_FILE

            self._write_mapping_table(block)  # MAPPING_TABLE

    def _overland_flow_bcs_exist(self) -> bool:
        """Returns True if any overland flow bcs exist."""
        return bool(self._olf_data and self._olf_data.feature_bc)

    def _write_infiltration(self) -> None:
        """Writes the Infiltration block.

        Returns early, writing no cards, when infiltration is off. Otherwise writes the card for the chosen
        method, the Richards parameters (Richards only) and the optional depth cards.
        """
        # Pass _ALIGN like every other block so the values line up across the whole .gssha file
        with Block(self._file, self._log, 'Infiltration', _ALIGN) as block:
            group = self._gm.global_parameters.group('infiltration')
            infiltration_type = group.parameter('infil_type').value
            if infiltration_type == InfilType.NO_INFILTRATION:  # 'No infiltration'
                return
            elif infiltration_type == InfilType.INF_REDIST:
                block.write(group=None, name='INF_REDIST')
            elif infiltration_type == InfilType.INF_LAYERED_SOIL:
                block.write(group=None, name='INF_LAYERED_SOIL')
            elif infiltration_type == InfilType.INF_RICHARDS:
                block.write(group=None, name='INF_RICHARDS')

            if infiltration_type == InfilType.INF_RICHARDS:  # "Richard's infiltration (INF_RICHARDS)"
                block.write(group, 'RICHARDS_WEIGHT')
                block.write(group, 'RICHARDS_DTHETA_MAX')
                block.write(group, 'RICHARDS_C_OPTION')
                block.write(group, 'RICHARDS_K_OPTION')
                block.write(group, 'RICHARDS_UPPER_OPTION')
                block.write(group, 'RICHARDS_ITER_MAX')
                block.write(group, 'MAX_NUMBER_CELLS')

            if group.parameter('SOIL_MOIST_DEPTH_CHK').value:
                block.write(group, 'SOIL_MOIST_DEPTH')
            if infiltration_type == InfilType.INF_REDIST and group.parameter('TOP_LAYER_DEPTH_CHK').value:
                block.write(group, 'TOP_LAYER_DEPTH')

    def _write_channel_routing(self) -> None:
        """Writes the Channel Routing block and, if enabled, the Soil Erosion and Transport blocks.

        Also sets self._routing, self._soil_erosion and self._transport, which _write_output_general() reads.
        When routing is off, nothing else (including erosion and transport) is written.
        """
        with Block(self._file, self._log, 'Channel Routing', _ALIGN) as block:
            group = self._gm.global_parameters.group('channel_routing')
            p = group.parameter('routing')
            if p.value == Routing.NO_ROUTING:
                return

            block.write(group=None, name='DIFFUSIVE_WAVE')
            self._routing = True
            if self._stream_data and self._stream_data.feature_bc and self._bc_comp:
                self._write_channel_input(block)  # CHANNEL_INPUT
                self._write_stream_cell(block)  # STREAM_CELL
                self._write_in_hyd_location(block)  # IN_HYD_LOCATION
                self._write_stream_mapping_table(block)  # ST_MAPPING_TABLE
                self._write_chan_point_input(block)  # CHAN_POINT_INPUT
            block.write(group, 'OVERBANK_FLOW')
            block.write(group, 'OVERLAND_BACKWATER')
            if group.parameter('HEAD_BOUND').value:
                block.write(group, 'HEAD_BOUND')
                block.write(group, 'BOUND_DEPTH')
            if group.parameter('write_chan_hotstart_chk').value:
                block.write(group, 'WRITE_CHAN_HOTSTART', quotes=True)
            if group.parameter('read_chan_hotstart_chk').value:
                block.write(group, 'READ_CHAN_HOTSTART', quotes=True)

        # The erosion and transport settings live in the 'channel_routing' group fetched above
        if group.parameter('SOIL_EROSION').value:
            with Block(self._file, self._log, 'Soil Erosion', _ALIGN) as erosion_block:
                self._soil_erosion = True
                erosion_block.write(group, 'SOIL_EROSION', 2)  # Hardwired to Kilinc Richardson for now
                erosion_block.write(group, 'SED_POROSITY')
                erosion_block.write(group, 'WATER_TEMP')
                erosion_block.write(group, 'SAND_SIZE')

        if group.parameter('COMPUTE_CONTAMINANT_TRANSPORT').value:
            with Block(self._file, self._log, 'Transport', _ALIGN) as transport_block:
                self._transport = True
                transport_block.write(None, 'OV_CON_TRANS')
                transport_block.write(None, 'SOIL_CONTAM')
                transport_block.write(None, 'CHAN_CON_TRANS')
                transport_block.write(group, 'CHAN_DECAY_COEF')
                transport_block.write(group, 'CHAN_DISP_COEF')
                transport_block.write(group, 'INIT_CHAN_CONC')
                # MIXING_LAYER_DEPTH optional, and not yet supported

    def _write_output(self) -> None:
        """Writes the general, gridded and link/node output blocks."""
        self._write_output_general()
        self._write_output_gridded()
        self._write_output_link_node()

    def _write_output_general(self) -> None:
        """Writes the general Output block (frequencies, units, output files and routing-dependent output)."""
        with Block(self._file, self._log, 'Output', _ALIGN) as block:
            group = self._gm.global_parameters.group('output_general')

            # Write frequency
            block.write(group, 'MAP_FREQ')

            # Hydrograph
            block.write(group, 'HYD_FREQ')
            if group.parameter('output_units').value == 0:
                block.write(group=None, name='METRIC')
            else:
                block.write(group=None, name='QOUT_CFS')

            # Other
            block.write(group, 'QUIET')
            block.write(group, 'SUPER_QUIET')
            block.write(group, 'STRICT_JULIAN_DATE')
            block.write(group, 'ALL_CONTAM_OUTPUT')
            # 0 Arc/Info ASCII maps. 1 WMS maps, ASCII (default). 2 WMS maps, binary. 4 XMDF maps (input and output).
            block.write(group=None, name='MAP_TYPE', value='1')

            # Output files
            block.write(group=None, name='SUMMARY', value=f'"{self._gssha_file_path.with_suffix(".sum").name}"')
            block.write(group=None, name='OUTLET_HYDRO', value=f'"{self._gssha_file_path.with_suffix(".otl").name}"')

            # Optional output (flags are set by _write_channel_routing)
            if self._routing:
                block.write(None, 'OUT_HYD_LOCATION', f'"{self._gssha_file_path.stem}.ohl"')
                if self._soil_erosion:
                    block.write(None, 'OUTLET_SED_FLUX', f'"{self._gssha_file_path.stem}.sed"')
                    block.write(None, 'OUTLET_SED_TSS', f'"{self._gssha_file_path.stem}.oss"')
                if self._transport:
                    block.write(None, 'OUT_MASS_LOCATION', f'"{self._gssha_file_path.stem}.oml"')
                    block.write(None, 'OUT_CON_LOCATION', f'"{self._gssha_file_path.stem}.ocl"')

    def _write_output_gridded(self) -> None:
        """Writes the cards for the enabled gridded dataset outputs."""
        with Block(self._file, self._log, 'Output - gridded datasets', _ALIGN) as block:
            group = self._gm.global_parameters.group('output_gridded')

            # Gridded datasets
            self._write_output_file(group, 'DIS_RAIN', '.drn', block)
            self._write_output_file(group, 'DEPTH', '.dep', block)

            # Infiltration depth output only makes sense when infiltration is on
            infil_group = self._gm.global_parameters.group('infiltration')
            if infil_group.parameter('infil_type').value != InfilType.NO_INFILTRATION:
                self._write_output_file(group, 'INF_DEPTH', '.inf', block)

            self._write_output_file(group, 'RATE_OF_INFIL', '.fav', block)
            self._write_output_file(group, 'SURF_MOIST', '.sur', block)
            self._write_output_file(group, 'FLOOD_GRID', '.gfl', block)

    def _write_output_link_node(self) -> None:
        """Writes the cards for the enabled link/node dataset outputs."""
        with Block(self._file, self._log, 'Output - link/node datasets', _ALIGN) as block:
            group = self._gm.global_parameters.group('output_link_node')

            # Link / Node data sets
            self._write_output_file(group, 'CHAN_DEPTH', '.cdp', block)
            self._write_output_file(group, 'CHAN_DISCHARGE', '.cdq', block)
            self._write_output_file(group, 'CHAN_VELOCITY', '.vel', block)
            self._write_output_file(group, 'CHAN_SED_FLUX', '.cfx', block)
            self._write_output_file(group, 'FLOOD_STREAM', '.cfl', block)
            self._write_output_file(group, 'CHAN_STAGE', '.wse', block)
            self._write_output_file(group, 'PIPE_FLOW', '.pfl', block)
            self._write_output_file(group, 'PIPE_HEAD', '.pnd', block)
            self._write_output_file(group, 'PIPE_TILE', '.pio', block)
            self._write_output_file(group, 'SUPERLINK_JUNC_FLOW', '.sjf', block)
            self._write_output_file(group, 'SUPERLINK_NODE_FLOW', '.snf', block)

    def _write_precipitation(self) -> None:
        """Writes the Precipitation block: uniform rainfall cards, or a hyetograph precipitation file."""
        with Block(self._file, self._log, 'Precipitation', _ALIGN) as block:
            group = self._gm.global_parameters.group('precipitation')
            precipitation_type = group.parameter('precipitation_type')
            if precipitation_type.value == PrecipitationType.PRECIP_UNIF:  # Uniform
                block.write(group=None, name='PRECIP_UNIF')
                block.write(group, 'RAIN_INTENSITY')
                block.write(group, 'RAIN_DURATION')
                block.write(group=None, name='START_DATE', value=self._start_date_time.strftime('%Y %m %d'))
                block.write(group=None, name='START_TIME', value=self._start_date_time.strftime('%H %M'))
            else:  # Hyetograph
                self._write_precip_file(block)

    def _write_output_file(self, group: Group, parameter_name: str, extension: str, block: Block) -> None:
        """Writes the output card, naming '<project stem><extension>', if the parameter is turned on.

        Args:
            group (Group): Group holding the on/off parameter.
            parameter_name (str): Parameter name, which is also the card name.
            extension (str): Output file extension, including the leading dot.
            block (Block): Block to write the card to.
        """
        if group.parameter(parameter_name).value:
            block.write(group, name=parameter_name, value=f'"{self._gssha_file_path.with_suffix(extension).name}"')

    def _write_mapping_table(self, block: Block) -> None:
        """Writes the mapping table file and, if one was written, the MAPPING_TABLE card."""
        file_path = cmt_file_writer.write(
            self._gssha_file_path, self._gm, self._co_grid, self._ugrid, self._olf_data, self._xy_dict, self._query
        )
        if file_path:
            block.write(group=None, name='MAPPING_TABLE', value=f'"{file_path.name}"')

    def _write_stream_mapping_table(self, block: Block) -> None:
        """Writes the stream mapping table file and, if one was written, the ST_MAPPING_TABLE card."""
        file_path = smt_file_writer.write(self._gssha_file_path)
        if file_path:
            block.write(group=None, name='ST_MAPPING_TABLE', value=f'"{file_path.name}"')

    def _write_time_series(self, block: Block) -> None:
        """Writes the xy series file and, if one was written, the TIME_SERIES_FILE card."""
        file_path = xys_file_writer.write(self._gssha_file_path, self._xy_dict, self._start_date_time)
        if file_path:
            block.write(group=None, name='TIME_SERIES_FILE', value=f'"{file_path.name}"')

    def _write_channel_input(self, block: Block) -> None:
        """Writes the channel input file and, if one was written, the CHANNEL_INPUT card."""
        file_path = cif_file_writer.write(self._gssha_file_path, self._stream_data, self._bc_comp.data)
        if file_path:
            block.write(group=None, name='CHANNEL_INPUT', value=f'"{file_path.name}"')

    def _write_stream_cell(self, block: Block) -> None:
        """Writes the stream cell file and, if one was written, the STREAM_CELL card."""
        file_path = gst_file_writer.write(self._gssha_file_path, self._stream_data, self._co_grid, self._ugrid)
        if file_path:
            block.write(group=None, name='STREAM_CELL', value=f'"{file_path.name}"')

    def _write_in_hyd_location(self, block: Block) -> None:
        """Writes the input hydrograph location file and, if one was written, the IN_HYD_LOCATION card."""
        file_path = ihl_file_writer.write(self._gssha_file_path, self._stream_data)
        if file_path:
            block.write(group=None, name='IN_HYD_LOCATION', value=f'"{file_path.name}"')

    def _write_chan_point_input(self, block: Block) -> None:
        """Writes the channel point input file and, if one was written, the CHAN_POINT_INPUT card."""
        file_path = ihg_file_writer.write(
            self._gssha_file_path, self._stream_data, self._bc_comp.data, self._start_date_time
        )
        if file_path:
            block.write(group=None, name='CHAN_POINT_INPUT', value=f'"{file_path.name}"')

    def _write_precip_file(self, block: Block) -> None:
        """Writes the gage file and, if one was written, the PRECIP_FILE and RAIN_INV_DISTANCE cards.

        This is used with hyetographs.
        """
        file_path = gag_file_writer.write(self._gssha_file_path, self._sim_comp.data, self._gm, self._start_date_time)
        if file_path:
            block.write(group=None, name='PRECIP_FILE', value=f'"{file_path.name}"')
            block.write(group=None, name='RAIN_INV_DISTANCE')

    def _get_outlet_row_col(self) -> tuple[int, int]:
        """Returns the outlet row and column cell numbers.

        The outlet cell is the cell with the lowest elevation. The elevations are masked using the grid's
        on/off cells, presumably so that inactive cells are ignored by argmin.
        """
        on_off_cells = data_util.get_on_off_cells(self._co_grid, self._ugrid)
        masked_array = io_util.create_masked_array(self._elevation_dataset, on_off_cells)
        cell_idx = masked_array.argmin()
        return self._co_grid.get_cell_ij_from_index(cell_idx)

    def _cell_size(self) -> float:
        """Returns the size of a grid cell.

        This is the spacing between the first two x locations, so it assumes uniform cells.
        """
        return self._co_grid.locations_x[1] - self._co_grid.locations_x[0]
