"""Module for the control file writing function."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"

# 1. Standard Python modules
from datetime import datetime, timedelta
import logging
from pathlib import Path
import typing

# 2. Third party modules
from shapely import LineString

# 3. Aquaveo modules
from xms.data_objects.parameters import FilterLocation
from xms.gdal.utilities import gdal_utils as gu
from xms.gmi.data.generic_model import Group

# 4. Local modules
from xms.swmm.components.coverage_component import (get_link_groups, get_node_groups, StormDrainLinkComponent,
                                                    StormDrainNodeComponent)
from xms.swmm.data.model import (ConduitShape, FlowUnits, ForceMain, get_swmm_model, LengthUnits, OutfallType,
                                 RoutingModel)
from xms.swmm.dmi.xms_data import SwmmData


def write(xms_data: SwmmData, output_file: typing.Optional[str | Path] = None, logger: logging.Logger = None):
    """Write the SWMM input file for a simulation.

    Builds an InputFileWriter from the simulation data and delegates the actual writing to it.

    Args:
        xms_data (SwmmData): Simulation data.
        output_file: The output filename. Forwarded unchanged to InputFileWriter.write.
        logger: The logger. Forwarded unchanged to InputFileWriter.

    Returns:
        The value returned by InputFileWriter.write.
    """
    return InputFileWriter(xms_data, logger).write(output_file)


def _get_duplicate_names(group_dict: dict) -> list:
    duplicate_names = []
    for name, value in group_dict.items():
        if len(value) > 1:
            duplicate_names.append(name)
    return duplicate_names


def _fix_names(names: list, group_dict: dict):
    for name in names:
        cur_index = 1
        for group in group_dict[name][1:]:
            new_name = group[0].parameter('name').value.replace(' ', '_')
            while new_name in group_dict:
                new_name = f'{new_name}_{cur_index}'
                cur_index += 1
            group[0].parameter('name').value = new_name


class InputFileWriter:
    """Writes the SWMM input file."""

    def __init__(self, xms_data: SwmmData, logger: logging.Logger) -> None:
        """Initializes the class.

        Args:
            xms_data (SwmmData): Simulation data.
            logger: The logger.
        """
        super().__init__()
        self._log = logger or logging.getLogger('xms.swmm')
        self._log.info('Initializing data...')
        self._data = xms_data
        sim_data = self._data.sim_data
        self._global_parameters = get_swmm_model().global_parameters
        self._global_parameters.restore_values(sim_data.global_values)
        self._flow_units = self._global_parameters.group('general').parameter('flow_units').value
        length_units = LengthUnits.METERS
        if self._flow_units in [FlowUnits.CFS, FlowUnits.GPM, FlowUnits.MGD]:
            length_units = LengthUnits.FEET
        self._to_units = gu.UNITS_METERS if length_units == LengthUnits.METERS else gu.UNITS_FEET_US_SURVEY
        link_cov, self._link_component = self._data.link_coverage
        self._link_component: StormDrainLinkComponent
        self._link_coverage_parameters = get_swmm_model().arc_parameters
        self._arcs = link_cov.arcs
        node_cov, self._node_component = self._data.node_coverage
        self._node_component: StormDrainNodeComponent
        self._node_coverage_parameters = get_swmm_model().point_parameters
        self._points = node_cov.get_points(FilterLocation.PT_LOC_DISJOINT)
        # Get the start analysis date/time
        group = self._global_parameters.group('dates')
        date_str = group.parameter('start_analysis').value
        self._start_datetime = datetime.fromisoformat(date_str)
        self._point_locations = {}
        self._num_ts = 0
        self._tidal_curves = {}
        self._ts_curves = {}
        self._node_dict = {}
        self._link_dict = {}

    def write(self, output_file: [str | Path]):
        """Writes the SWMM input file."""
        if output_file is None:
            output_file = 'swmm.inp'
        else:
            output_file = str(output_file)
        with open(output_file, 'w') as f:
            self._log.info(f'Writing {output_file}...')
            self._resolve_duplicate_names()
            self._write_input_header(f)
            self._write_junctions(f)
            self._write_outfalls(f)
            self._write_conduits(f)
            self._write_pumps(f)
            self._write_orifices(f)
            self._write_weirs(f)
            self._write_outlets(f)
            self._write_xsections(f)
            self._write_inflows(f)
            self._write_dwf(f)
            self._write_curves(f)
            self._write_timeseries(f)
            self._write_coordinates(f)
            self._write_vertices(f)

    def _write_input_header(self, file: typing.TextIO):
        """
        Write the SWMM .inp file.

        Args:
            file: The file for writing.
        """
        self._log.info('Writing [TITLE]...')
        file.write('[TITLE]\n\n')
        self._log.info('Writing [OPTIONS]...')
        file.write('[OPTIONS]\n')
        group = self._global_parameters.group('general')
        file.write(f'FLOW_UNITS {self._flow_units}\n')
        file.write('INFILTRATION CURVE_NUMBER\n')
        routing_model = group.parameter('routing_model').value
        file.write('FLOW_ROUTING ')
        if routing_model == RoutingModel.STEADY:
            file.write('STEADY')
        elif routing_model == RoutingModel.KINWAVE:
            file.write('KINWAVE')
        else:
            file.write('DYNWAVE')
        file.write('\n')
        file.write('ALLOW_PONDING ')
        allow_podinig = group.parameter('allow_ponding').value
        if allow_podinig:
            file.write('YES')
        else:
            file.write('NO')
        file.write('\n')
        file.write('IGNORE_RAINFALL YES\n')
        group = self._global_parameters.group('dates')
        _write_date_and_time(file, group, 'start_analysis', 'START')
        _write_date_and_time(file, group, 'start_reporting', 'REPORT_START')
        _write_date_and_time(file, group, 'end_analysis', 'END')
        group = self._global_parameters.group('time_steps')
        _write_time_step(file, group.parameter('reporting_time_step').value * 60, 'REPORT_STEP')
        _write_time_step(file, group.parameter('dry_weather_time_step').value * 60, 'DRY_STEP')
        _write_time_step(file, group.parameter('wet_weather_time_step').value * 60, 'WET_STEP')
        _write_time_step(file, group.parameter('routing_time_step').value, 'ROUTING_STEP')
        group = self._global_parameters.group('dynamic_wave')
        force_main_equation = group.parameter('force_main_equation').value
        # Write this value so we match the SWMM5 interface default output
        file.write('INERTIAL_DAMPING PARTIAL\n')
        file.write('FORCE_MAIN_EQUATION ')
        if force_main_equation == ForceMain.H_W:
            file.write('H-W')
        else:
            file.write('D-W')
        # Write this value so we match the SWMM5 interface default output
        file.write('\n')
        file.write('VARIABLE_STEP 0.75')
        file.write('\n\n')
        file.write('[SUBCATCHMENTS]\n\n')
        file.write('[POLYGONS]\n\n')

    def _write_vertices(self, f: typing.TextIO):
        self._log.info('Writing [VERTICES]...')
        f.write('[VERTICES]\n')
        for arc in self._arcs:
            group = self._link_dict[arc.id]
            if group[0].is_active:
                name = group[0].parameter('name').value.replace(' ', '_')
                for vertex in arc.vertices:
                    f.write(f'{name} {vertex.x} {vertex.y}\n')

    def _write_coordinates(self, f: typing.TextIO):
        self._log.info('Writing [COORDINATES]...')
        f.write('[COORDINATES]\n')
        for point in self._points:
            group = self._node_dict[point.id]
            if group[0].is_active:
                name = group[0].parameter('name').value.replace(' ', '_')
                f.write(f'{name} {point.x} {point.y}\n')
        f.write('\n')

    def _write_timeseries(self, f: typing.TextIO):
        self._log.info('Writing [TIMESERIES]...')
        f.write('[TIMESERIES]\n')
        for name, curve in self._ts_curves.items():
            for value in zip(curve[0], curve[1]):
                cur_datetime = self._start_datetime + timedelta(minutes=value[0])
                f.write(f'{name} {cur_datetime.strftime("%m/%d/%Y")} {cur_datetime.strftime("%H:%M")} {value[1]}\n')
        f.write('\n')

    def _write_curves(self, f: typing.TextIO):
        self._log.info('Writing [CURVES]...')
        f.write('[CURVES]\n')
        for name, curve in self._tidal_curves.items():
            first = True
            for value in zip(curve[0], curve[1]):
                if first:
                    f.write(f'{name} Tidal {value[0]} {value[1]}\n')
                    first = False
                else:
                    f.write(f'{name}       {value[0]} {value[1]}\n')
        f.write('\n')

    def _write_dwf(self, f: typing.TextIO):
        # We're going to assume all dry weather inflows are zero for now. They are usually used for sanitary
        # sewer modeling.
        self._log.info('Writing [DWF]...')
        f.write('[DWF]\n')
        for point in self._points:
            group = self._node_dict[point.id]
            if group[0].is_active:
                name = group[0].parameter('name').value.replace(' ', '_')
                f.write(f'{name} FLOW 0\n')
        f.write('\n')

    def _write_inflows(self, f: typing.TextIO):
        self._log.info('Writing [INFLOWS]...')
        f.write('[INFLOWS]\n')
        for point in self._points:
            group = self._node_dict[point.id]
            if group[0].is_active:
                name = group[0].parameter('name').value.replace(' ', '_')
                f.write(f'{name} FLOW ')
                curve = group[0].parameter('inflows').value
                curve_name = ''
                if curve and (len(curve[1]) > 1 or curve[1][0] != 0.0):
                    self._num_ts += 1
                    curve_name = f'timeseries_{self._num_ts}'
                    self._ts_curves[curve_name] = curve
                f.write(f'"{curve_name}" FLOW 1.0 1.0 0\n')
        f.write('\n')

    def _write_xsections(self, f: typing.TextIO):
        self._log.info('Writing [XSECTIONS]...')
        f.write('[XSECTIONS]\n')
        for arc in self._arcs:
            group = self._link_dict[arc.id]
            if group[0].is_active:
                name = group[0].parameter('name').value.replace(' ', '_')
                f.write(f'{name} ')
                shape = group[0].parameter('shape').value
                height_diameter = group[0].parameter('height_diameter').value
                number_barrels = group[0].parameter('number_barrels').value
                filled_depth = group[0].parameter('filled_depth').value
                width = group[0].parameter('width').value
                if shape == ConduitShape.CIRCULAR:
                    f.write(f'CIRCULAR {height_diameter} 0 0 0 {number_barrels}')
                elif shape == ConduitShape.FILLED_CIRCULAR:
                    f.write(f'FILLED_CIRCULAR {height_diameter} {filled_depth} 0 0 {number_barrels}')
                elif shape == ConduitShape.CLOSED_RECTANGULAR:
                    f.write(f'CLOSED_RECTANGULAR {width} {height_diameter} 0 0 {number_barrels}')
                elif shape == ConduitShape.HORIZONTAL_ELLIPTICAL:
                    f.write(f'HORIZONTAL_ELLIPTICAL {width} {height_diameter} 0 0 {number_barrels}')
                elif shape == ConduitShape.VERTICAL_ELLIPTICAL:
                    f.write(f'VERTICAL_ELLIPTICAL {height_diameter} {width} 0 0 {number_barrels}')
                f.write('\n')
        f.write('\n')

    @staticmethod
    def _write_outlets(f):
        f.write('[OUTLETS]\n\n')

    @staticmethod
    def _write_weirs(f):
        f.write('[WEIRS]\n\n')

    @staticmethod
    def _write_orifices(f):
        f.write('[ORIFICES]\n\n')

    @staticmethod
    def _write_pumps(f):
        f.write('[PUMPS]\n\n')

    def _write_conduits(self, f: typing.TextIO):
        self._log.info('Writing [CONDUITS]...')
        f.write('[CONDUITS]\n')
        horiz_units = gu.get_horiz_unit_from_wkt(self._data.display_wkt)
        convert = gu.get_vertical_multiplier(horiz_units, self._to_units)
        for arc in self._arcs:
            group = self._link_dict[arc.id]
            if group[0].is_active:
                name = group[0].parameter('name').value.replace(' ', '_')
                arc_points = arc.get_points(FilterLocation.PT_LOC_ALL)
                pts = [[pt.x, pt.y, pt.z] for pt in arc_points]
                length = LineString(pts).length * convert
                roughness = group[0].parameter('roughness').value
                inlet_offset = group[0].parameter('inlet_offset').value
                outlet_offset = group[0].parameter('outlet_offset').value
                initial_flow = group[0].parameter('initial_flow').value
                max_flow = group[0].parameter('max_flow').value
                start_location = (arc.start_node.x, arc.start_node.y)
                end_location = (arc.end_node.x, arc.end_node.y)
                f.write(f'{name} {self._point_locations[start_location]} '
                        f'{self._point_locations[end_location]} {length} ')
                f.write(f'{roughness} {inlet_offset} {outlet_offset} {initial_flow} {max_flow}\n')
        f.write('\n')

    def _write_outfalls(self, f: typing.TextIO):
        self._log.info('Writing [OUTFALLS]...')
        f.write('[OUTFALLS]\n')
        num_tidal = 0
        self._tidal_curves = {}
        self._num_ts = 0
        self._ts_curves = {}
        convert = gu.get_vertical_multiplier(self._data.vertical_units, self._to_units)
        for point in self._points:
            group = self._node_dict[point.id]
            if group[1] == 'outfall':
                name = group[0].parameter('name').value.replace(' ', '_')
                self._point_locations[(point.x, point.y)] = name
                invert_elevation = point.z * convert
                outfall_type = group[0].parameter('type').value
                f.write(f'{name} {invert_elevation} {outfall_type} ')
                if outfall_type == OutfallType.TIDAL:
                    num_tidal += 1
                    curve_name = f'tidal_{num_tidal}'
                    f.write(f'{curve_name} ')
                    self._tidal_curves[curve_name] = group[0].parameter('tidal_outfall_curve').value
                elif outfall_type == OutfallType.FIXED:
                    f.write(f'''{group[0].parameter('fixed_stage').value} ''')
                elif outfall_type == OutfallType.TIMESERIES:
                    self._num_ts += 1
                    curve_name = f'timeseries_{self._num_ts}'
                    f.write(f'{curve_name} ')
                    self._ts_curves[curve_name] = group[0].parameter('timeseries_curve').value
                tide_gate = group[0].parameter('tide_gate').value
                if tide_gate:
                    f.write('YES')
                else:
                    f.write('NO')
                f.write('\n')
        f.write('\n')

    def _write_junctions(self, f: typing.TextIO):
        self._log.info('Writing [JUNCTIONS]...')
        f.write('[JUNCTIONS]\n')
        self._point_locations = {}
        convert = gu.get_vertical_multiplier(self._data.vertical_units, self._to_units)
        for point in self._points:
            group = self._node_dict[point.id]
            if group[1] == 'junction':
                name = group[0].parameter('name').value.replace(' ', '_')
                self._point_locations[(point.x, point.y)] = name
                invert_elevation = point.z * convert
                max_depth = group[0].parameter('max_depth').value
                f.write(f'{name} {invert_elevation} {max_depth}\n')
        f.write('\n')

    def _resolve_duplicate_names(self):
        node_dict = get_node_groups(self._points, self._node_component, self._node_coverage_parameters)
        link_dict = get_link_groups(self._arcs, self._link_component, self._link_coverage_parameters)
        dup_nodes = _get_duplicate_names(node_dict)
        dup_links = _get_duplicate_names(link_dict)
        if dup_nodes or dup_links:
            message = ''
            if dup_nodes:
                message = 'The following duplicate node names exist:'
                for dup_node in dup_nodes:
                    message += f' {dup_node}'
                message += '.'
            if dup_links:
                if dup_nodes:
                    message += ' '
                message += 'The following duplicate link names exist:'
                for dup_link in dup_links:
                    message += f' {dup_link}'
                message += '.'
            message += ('\nYou cannot save a SWMM file with duplicate node and/or link names.\nThe names were '
                        'automatically fixed in the SWMM input file.')
            self._log.warning(message)
            if dup_nodes:
                _fix_names(dup_nodes, node_dict)
            if dup_links:
                _fix_names(dup_links, link_dict)
        for groups in node_dict.values():
            for group in groups:
                self._node_dict[group[2]] = (group[0], group[1])
        for groups in link_dict.values():
            for group in groups:
                self._link_dict[group[2]] = (group[0], group[1])


def _write_time_step(file: typing.TextIO, total_seconds: float, card: str):
    hours, remainder = divmod(int(total_seconds), 3600)
    minutes, seconds = divmod(int(remainder), 60)
    file.write(f'{card} {hours:02}:{minutes:02}:{seconds:02}\n')


def _write_date_and_time(file: typing.TextIO, group: Group, gmi_name: str, card: str):
    date_str = group.parameter(gmi_name).value
    date = datetime.fromisoformat(date_str)
    file.write(f'{card}_DATE {date.strftime("%m/%d/%Y")}\n')
    file.write(f'{card}_TIME {date.strftime("%H:%M:%S")}\n')
