"""Writer for ADCIRC fort.14 geometry files."""

# 1. Standard Python modules
from io import StringIO
import shutil

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint import read_grid_from_file
from xms.constraint.ugrid_boundaries import UGridBoundaries
from xms.grid.ugrid import UGrid as XmUGrid

# 4. Local modules
from xms.adcirc.data import bc_data as bcd
from xms.adcirc.data.mapped_bc_data import MappedBcData
from xms.adcirc.feedback.xmlog import XmLog
from xms.adcirc.file_io.grid_crc import compute_grid_crc

# (boundary type combobox index, display label) pairs; looked up by ibtype_idx_to_text when reporting
# boundary assignment problems.
_IBTYPE_LABELS = (
    (bcd.UNASSIGNED_INDEX, 'Unassigned'),
    (bcd.OCEAN_INDEX, 'Ocean'),
    (bcd.MAINLAND_INDEX, 'Mainland'),
    (bcd.ISLAND_INDEX, 'Island'),
    (bcd.RIVER_INDEX, 'River'),
    (bcd.LEVEE_OUTFLOW_INDEX, 'Levee outflow'),
    (bcd.LEVEE_INDEX, 'Levee'),
    (bcd.RADIATION_INDEX, 'Radiation'),
    (bcd.ZERO_NORMAL_INDEX, 'Zero normal'),
    (bcd.FLOW_AND_RADIATION_INDEX, 'Flow and radiation'),
)
IBTYPE_STR_MAP = dict(_IBTYPE_LABELS)


def ibtype_idx_to_text(type_list):
    """Convert list of ibtype combobox indices to human-friendly text.

    Args:
        type_list (list[int]): Boundary type combobox indices (keys of ``IBTYPE_STR_MAP``).

    Returns:
        list[str]: The display label of each index, in the same order. An index with no entry in
        ``IBTYPE_STR_MAP`` is rendered as ``'Unknown (<index>)'`` rather than raising, because this is
        only used to build diagnostic messages.
    """
    return [IBTYPE_STR_MAP.get(index, f'Unknown ({index})') for index in type_list]


def export_mapped_bc_data(filename, cogrid_file, grid_name, mapped_bc_datas):
    """Write the complete fort.14 (mesh geometry followed by boundary conditions).

    Nothing is written if any mapped BC data has a stored ``grid_crc`` attribute that differs from the
    CRC of ``cogrid_file``. The whole file is first assembled in an in-memory stream and only copied to
    ``filename`` after every section has been generated.

    Args:
        filename (:obj:`str`): Output file location
        cogrid_file (:obj:`str`): Path to the domain CoGrid file
        grid_name (:obj:`str`): Name of the ADCIRC mesh tree item
        mapped_bc_datas (list[str]): Absolute paths to the simulation's mapped boundary conditions datas.
            Surrounding single/double quote characters are stripped before each path is passed to
            ``MappedBcData``.
    """
    current_crc: str = compute_grid_crc(cogrid_file)
    bc_components = [MappedBcData(path.strip("'\"")) for path in mapped_bc_datas]

    # Every coverage must have been applied to this exact version of the mesh.
    stale = any(component.info.attrs.get('grid_crc', '') != current_crc for component in bc_components)
    if stale:
        XmLog().instance.error(
            'Cannot export boundary conditions. The mesh was edited or replaced since the BC coverage was applied. '
            'Remove the applied BC coverage and apply a new one to export.'
        )
        return

    fort14_buffer = StringIO()
    XmLog().instance.info('Exporting mesh geometry.')
    loops = export_geometry_to_fort14(fort14_buffer, cogrid_file, grid_name)
    XmLog().instance.info('Exporting boundary conditions.')
    export_boundary_conditions_to_fort14(fort14_buffer, bc_components, loops)

    XmLog().instance.info('Saving in-memory stream to file.')
    fort14_buffer.seek(0)
    with open(filename, 'w') as out_file:
        shutil.copyfileobj(fort14_buffer, out_file, 1000000)
    XmLog().instance.info('Successfully exported fort.14 file.')


def export_geometry_to_fort14(ss, grid_file, grid_name):
    """Export grid geometry to the fort.14.

    Writes, in order:
    * the mesh name line;
    * the ``<cell count> <point count>`` line;
    * one line per node: 1-based id, x, y, and the negated z value (presumably elevation converted
      to ADCIRC depth);
    * one line per element: 1-based id, ``3``, and its three 1-based node ids.

    Args:
        ss (:obj:`file-like object`): Stream to write to.
        grid_file (:obj:`str`): Path to the xmsconstraint file containing the ADCIRC mesh.
        grid_name (:obj:`str`): Name of the ADCIRC mesh tree item

    Returns:
        (dict): The node boundaries of the mesh, as returned by ``UGridBoundaries.get_loops()``.

    Raises:
        RuntimeError: If the mesh contains a cell that is not a triangle (ADCIRC only supports triangles).
    """
    XmLog().instance.info('Reading mesh file created by SMS.')
    co_grid = read_grid_from_file(grid_file)
    ugrid = co_grid.ugrid
    XmLog().instance.info('Finding boundary nodes of mesh.')
    boundaries = UGridBoundaries(grid=ugrid, target_cells=False, cell_values=None)
    loops = boundaries.get_loops()
    XmLog().instance.info('Converting SMS mesh to ADCIRC format.')
    ss.write(f'{grid_name}\n')
    ss.write(f'{ugrid.cell_count} {ugrid.point_count}\n')

    # Write all the node definitions
    XmLog().instance.info('Writing mesh node locations.')
    for node_id, location in enumerate(ugrid.locations, start=1):
        ss.write(f'{node_id:10d} {location[0]:15.10f} {location[1]:15.10f} {location[2] * -1:15.10f}\n')

    # Write all the cell definitions. Each triangle occupies 5 cellstream entries:
    # [cell type, <skipped entry - presumably the point count>, node 1, node 2, node 3].
    # The fixed stride is only valid while every cell is a triangle, which is checked before each use.
    XmLog().instance.info('Writing mesh cell definitions.')
    cell_stream = ugrid.cellstream
    triangle_type = XmUGrid.cell_type_enum.TRIANGLE
    stride = 5
    for cell_id, index in enumerate(range(0, len(cell_stream), stride), start=1):
        if cell_stream[index] != triangle_type:
            raise RuntimeError(f'Invalid triangle found in domain mesh: Cell ID={cell_id}')
        ss.write(
            f'{cell_id:5d} 3 {cell_stream[index + 2] + 1:5d} {cell_stream[index + 3] + 1:5d} '
            f'{cell_stream[index + 4] + 1:5d}\n'
        )
    return loops


def export_boundary_conditions_to_fort14(ss, mapped_bc_datas, boundary_loops):
    """Export boundary condition nodestrings to the fort.14.

    Lines are gathered from every mapped BC coverage first, because each section header needs totals
    across all coverages. The node assignments are then checked, and everything is written to ``ss``.

    Args:
        ss (file-like object): Stream to write to.
        mapped_bc_datas (list[MappedBcData]): The simulation's mapped boundary conditions data.
        boundary_loops (dict): The node boundaries of the mesh. Each value is indexed with ``'id'`` to get
            the node ids of that loop.
    """
    class NodestringData:
        """Struct to pass along nodestring data.

        All state lives on the instance (not on the class), so no mutable containers are shared through
        class attributes.
        """

        def __init__(self, loops):
            """Create empty line buffers and counters, and register every mesh boundary node as unassigned.

            Args:
                loops (dict): The node boundaries of the mesh.
            """
            self.ocean_lines = []
            self.land_lines = []
            self.generic_lines = []
            self.num_ocean_boundaries = 0  # Number of open boundaries on all coverages
            self.num_ocean_nodes = 0  # Number of open boundary nodes on all coverages
            self.num_land_boundaries = 0  # Number of land boundaries on all coverages
            self.num_land_nodes = 0  # Number of land boundary nodes on all coverages
            self.num_generic_boundaries = 0  # Number of generic nodestrings on all coverages
            self.num_generic_nodes = 0  # Number of generic boundary nodes on all coverages
            self.ocean_header = ''
            self.land_header = ''
            self.generic_header = ''
            # Flatten the boundary loops: boundary node id -> list of boundary types assigned to it.
            self.assigned_nodes = {node: [] for loop_id in loops for node in loops[loop_id]['id']}

    data = NodestringData(boundary_loops)
    for mapped_bc in mapped_bc_datas:
        # Store all the nodestrings for a particular boundary coverage. We will then export all coverages in the
        # fort.14. I believe we should indicate whether the mapped item is the "main" or "base" append item.
        _gather_ocean_lines(mapped_bc, data)
        ibtypes = _gather_nonlevee_lines(mapped_bc, data)
        _gather_levee_lines(mapped_bc, data, ibtypes)
        _gather_generic_lines(mapped_bc, data)

    # These are the header lines for each type of boundary. Can't write them until we have totals from all coverages.
    data.ocean_header = (
        f'{data.num_ocean_boundaries} = Number of open boundaries\n'
        f'{data.num_ocean_nodes} = Total number of open boundary nodes\n'
    )
    data.land_header = (
        f'{data.num_land_boundaries} = Number of land boundaries\n'
        f'{data.num_land_nodes} = Total number of land boundary nodes\n'
    )
    data.generic_header = (
        f'{data.num_generic_boundaries} = Number of generic boundaries\n'
        f'{data.num_generic_nodes} = Total number of generic boundary nodes\n'
    )
    _check_node_assignments(data)
    _dump_to_stream(ss, data)


def _dump_to_stream(ss, data):
    """Write all the gathered lines to the fort.14 stream.

    Sections are written in this order: the open (ocean) boundary header and lines, then the land
    boundary header and lines. The generic boundary header and lines follow only when at least one
    generic line exists (ADCIRC does not use them).

    Args:
        ss (StringIO): The stream to write to.
        data (NodestringData): The gathered headers and lines to write.
    """
    ss.write(data.ocean_header)
    ss.writelines(data.ocean_lines)
    ss.write(data.land_header)
    ss.writelines(data.land_lines)
    if data.generic_lines:  # Don't worry about these if there are none. Not used by ADCIRC.
        ss.write(data.generic_header)
        ss.writelines(data.generic_lines)


def _gather_ocean_lines(mapped_bc, data):
    """Get all the lines for open boundaries.

    ``get_ocean_node_ids()`` returns the open-boundary node ids of the coverage as one flat sequence,
    plus the node count of each boundary. The flat sequence is consumed one boundary segment at a time.

    Args:
        mapped_bc (MappedBcData): The mapped BC component.
        data (NodestringData): Struct to fill in.
    """
    ocean_nodes, node_counts = mapped_bc.get_ocean_node_ids()
    data.num_ocean_nodes += sum(node_counts)
    node_start_idx = 0
    for node_count in node_counts:
        data.num_ocean_boundaries += 1
        data.ocean_lines.append(f'{node_count} = Number of nodes for open boundary {data.num_ocean_boundaries}\n')
        segment = [ocean_nodes[node_start_idx + node_idx] for node_idx in range(node_count)]
        # Mark only this boundary's own nodes. Marking the whole flat list on every pass would record each
        # node once per open boundary, and _check_node_assignments would then report them as having
        # multiple boundary assignments.
        for node_id in segment:
            data.assigned_nodes[node_id - 1].append(bcd.OCEAN_INDEX)
        data.ocean_lines.append(''.join(f'{node_id}\n' for node_id in segment))
        node_start_idx += node_count


def _gather_nonlevee_lines(mapped_bc, data):
    """Get all the lines for non-levee land boundaries.

    Each boundary is written as a header line (node count and IBTYPE) followed by one node id per line.
    Every written node is recorded in ``data.assigned_nodes``.

    An island boundary whose first and last node ids differ is closed by repeating its first node at
    the end. The repeated node is included both in the header count and in ``data.num_land_nodes``.

    Args:
        mapped_bc (MappedBcData): The mapped BC component.
        data (NodestringData): Struct to fill in.

    Returns:
        dict: Mapping of component id to boundary type (IBTYPE), as returned by
        ``mapped_bc.source_data.get_ib_types()``.
    """
    ibtypes = mapped_bc.source_data.get_ib_types()
    land_nodes, node_counts, comp_ids = mapped_bc.get_nonlevee_land_node_ids()
    node_start_idx = 0
    for comp_id, node_count in zip(comp_ids, node_counts):
        data.num_land_boundaries += 1
        # Arc type index (compared against the bcd.*_INDEX constants), not the numeric IBTYPE written out.
        ibtype = mapped_bc.source_data.arcs.type.loc[comp_id].item()
        node_end_idx = node_start_idx + node_count

        # Close island loops that do not already end on their starting node.
        has_nodes = node_count > 0 and node_end_idx <= len(land_nodes)
        head_node = land_nodes[node_start_idx] if has_nodes else None
        tail_node = land_nodes[node_end_idx - 1] if has_nodes else None
        make_self_closing = bool(has_nodes and ibtype == bcd.ISLAND_INDEX and head_node != tail_node)

        # If we repeat the start node to close the loop, it counts toward this boundary's node total.
        written_count = node_count + 1 if make_self_closing else node_count
        data.num_land_nodes += written_count
        data.land_lines.append(
            f'{written_count} '
            f'{ibtypes[comp_id]} = Number of nodes for land boundary {data.num_land_boundaries}, IBTYPE\n'
        )
        for node_idx in range(node_start_idx, node_end_idx):
            node_id = land_nodes[node_idx]
            # Mark the node as assigned
            data.assigned_nodes[node_id - 1].append(int(ibtype))
            data.land_lines.append(f'{node_id}\n')
        if make_self_closing:  # Ensure closed island loops
            data.assigned_nodes[head_node - 1].append(int(ibtype))
            data.land_lines.append(f'{head_node}\n')
        # Advance by the real segment length; the repeated closing node is not part of land_nodes.
        node_start_idx = node_end_idx
    return ibtypes


def _levee_value(levee_dset, column, row):
    """Return the Python scalar stored at position ``row`` of column ``column`` in a levee dataset."""
    return levee_dset[column].data[row].item()


def _gather_levee_lines(mapped_bc, data, ibtypes):
    """Get all the lines for levee boundaries.

    One land boundary is emitted per distinct levee component id. The last digit of the IBTYPE selects
    the record layout:

    * ...3 (levee outflow): node id, crest elevation, supercritical coefficient.
    * ...4 (levee pair): both node ids, crest elevation, subcritical and supercritical coefficients.
    * ...5 (levee pair with pipes): the ...4 columns plus pipe height, bulk coefficient and pipe diameter.
      A ...4 levee is promoted to ...5 when any of its rows has the pipe toggle on. Rows without a pipe
      get a pipe height of 100.0.

    Args:
        mapped_bc (MappedBcData): The mapped BC component.
        data (NodestringData): Struct to fill in.
        ibtypes (dict): The IBTYPES for the levees.
    """
    for levee_comp in set(mapped_bc.levees.comp_id.data.tolist()):
        ibtype = ibtypes[levee_comp]
        last_digit = abs(ibtype) % 10
        levee = mapped_bc.levees.sel(comp_id=levee_comp)
        row_count = levee.sizes['comp_id']
        data.num_land_boundaries += 1

        if last_digit == 4 and any(levee['Pipe'].data.tolist()):
            # At least one pipe toggle is on: switch to the matching IBTYPE ending in 5.
            ibtype = ibtype + 1
            last_digit = 5

        data.land_lines.append(
            f'{row_count} {ibtype} = Number of nodes for land boundary {data.num_land_boundaries}, IBTYPE\n'
        )

        if last_digit == 3:
            # Levee outflow: NodeId, Zcrest, Supercritical coef
            for row in range(row_count):
                node_id = _levee_value(levee, 'Node1 Id', row)
                data.assigned_nodes[node_id - 1].append(bcd.LEVEE_OUTFLOW_INDEX)
                crest = _levee_value(levee, 'Zcrest (m)', row)
                supercritical = _levee_value(levee, 'Supercritical __new_line__ Flow Coef', row)
                data.land_lines.append(f'{node_id:<10d} {crest:<10.3f} {supercritical:<0.3f}\n')
            data.num_land_nodes += row_count
        elif last_digit == 4:
            # Levee pair: Node1Id, Node2Id, Zcrest, Subcritical coef, Supercritical coef
            for row in range(row_count):
                node1 = _levee_value(levee, 'Node1 Id', row)
                node2 = _levee_value(levee, 'Node2 Id', row)
                data.assigned_nodes[node1 - 1].append(bcd.LEVEE_INDEX)
                data.assigned_nodes[node2 - 1].append(bcd.LEVEE_INDEX)
                crest = _levee_value(levee, 'Zcrest (m)', row)
                subcritical = _levee_value(levee, 'Subcritical __new_line__ Flow Coef', row)
                supercritical = _levee_value(levee, 'Supercritical __new_line__ Flow Coef', row)
                data.land_lines.append(
                    f'{node1:<10d} {node2:<10d} {crest:<10.3f} {subcritical:<10.3f} {supercritical:<0.3f}\n'
                )
            data.num_land_nodes += row_count * 2
        elif last_digit == 5:
            # Levee pair with pipes: the pair columns plus pipe height, pipe coef, pipe diameter
            for row in range(row_count):
                pipe_height = _levee_value(levee, 'Zpipe (m)', row) if _levee_value(levee, 'Pipe', row) else 100.0
                node1 = _levee_value(levee, 'Node1 Id', row)
                node2 = _levee_value(levee, 'Node2 Id', row)
                data.assigned_nodes[node1 - 1].append(bcd.LEVEE_INDEX)
                data.assigned_nodes[node2 - 1].append(bcd.LEVEE_INDEX)
                crest = _levee_value(levee, 'Zcrest (m)', row)
                subcritical = _levee_value(levee, 'Subcritical __new_line__ Flow Coef', row)
                supercritical = _levee_value(levee, 'Supercritical __new_line__ Flow Coef', row)
                bulk = _levee_value(levee, 'Bulk __new_line__ Coefficient', row)
                diameter = _levee_value(levee, 'Pipe __new_line__ Diameter (m)', row)
                data.land_lines.append(
                    f'{node1:<10d} {node2:<10d} {crest:<10.3f} {subcritical:<10.3f} {supercritical:<10.3f} '
                    f'{pipe_height:<10.3f} {bulk:<10.3f} {diameter:<0.3f}\n'
                )
            data.num_land_nodes += row_count * 2


def _gather_generic_lines(mapped_bc, data):
    """Append the header and node lines for each generic nodestring of a mapped BC coverage.

    ``get_generic_node_ids()`` supplies the node ids of all generic nodestrings as one flat sequence,
    together with the node count of each nodestring. The sequence is split into consecutive segments
    of those sizes.

    Args:
        mapped_bc (MappedBcData): The mapped BC component
        data (NodestringData): Struct to fill in.
    """
    # Export the generic arcs - do we really need to do this? I don't think ADCIRC reads them, but we do.
    generic_nodes, node_counts = mapped_bc.get_generic_node_ids()
    if not node_counts:
        return
    offset = 0
    for count in node_counts:
        data.num_generic_boundaries += 1
        data.num_generic_nodes += count
        data.generic_lines.append(f'{count} = Number of nodes for generic boundary {data.num_generic_boundaries}\n')
        data.generic_lines.extend(f'{generic_nodes[offset + k]}\n' for k in range(count))
        offset += count


def _check_node_assignments(data):
    """Checks that all the boundary nodes have been properly defined.

    This only logs diagnostics; it does not modify ``data``.

    * A boundary node with no assignment produces a warning.
    * A node with more than one assignment is reported, unless one of its assignments is Mainland or
      Island. It is a warning if every assignment is a levee type (users do this often), otherwise an
      error.

    Args:
        data (NodestringData): The gathered nodestring data to check.
    """
    levee_types = (bcd.LEVEE_INDEX, bcd.LEVEE_OUTFLOW_INDEX)
    for node_idx, type_list in data.assigned_nodes.items():
        # Check for unassigned nodes.
        if not type_list:
            # NOTE(review): despite the message, this function writes no default Mainland boundary for the
            # node (nothing is appended to data.land_lines). Confirm whether a default boundary should be
            # exported.
            XmLog().instance.warning(f'Boundary node {node_idx + 1} is unassigned. Defaulting to Mainland.')
            continue
        # Check for multiple assignments. Nodes that are also on a Mainland or Island boundary are not reported.
        if len(type_list) < 2 or bcd.MAINLAND_INDEX in type_list or bcd.ISLAND_INDEX in type_list:
            continue
        msg = f'Node {node_idx + 1} has multiple boundary assignments {ibtype_idx_to_text(type_list)}'
        if all(ib_type in levee_types for ib_type in type_list):
            # Report as a warning if a levee. Users shouldn't do this, but they do all the time.
            XmLog().instance.warning(msg)
        else:
            XmLog().instance.error(msg)
