"""Functions and a helper class for building UGrid representations of 3D bridge and culvert structures."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from io import StringIO
import json
import os
import shutil

# 2. Third party modules
import pandas as pd
from shapely.geometry import LineString

# 3. Aquaveo modules
from xms.constraint.ugrid_builder import UGridBuilder
from xms.core.filesystem import filesystem as xmf
from xms.data_objects.parameters import FilterLocation, UGrid as DoUGrid
from xms.guipy.data.target_type import TargetType
from xms.tool_core.exceptions import ToolError

# 4. Local modules
from xms.bridge.grid_builder import GridBuilder


def update_data_frame_arc_ids_from_component(df, comp):
    """Swap the component ids in a dataframe's first column for XMS arc ids.

    Every id in the first column is passed to ``comp.get_xms_ids``. When that lookup
    returns -1 the row is discarded; otherwise the first XMS id returned is used.
    The first column of ``df`` itself is overwritten before the filtered frame is
    returned.

    Args:
        df (:obj:`pandas.Dataframe`): the dataframe; its first column holds component ids
        comp (:obj:`StructureComponent`): the structure

    Returns:
        (:obj:`pandas.Dataframe`): the rows of ``df`` that map to an arc, with arc ids in the first column
    """
    id_column = df.columns[0]

    def _arc_id_for(component_id):
        xms_ids = comp.get_xms_ids(TargetType.arc, component_id)
        return xms_ids[0] if xms_ids != -1 else -1

    df[id_column] = [_arc_id_for(component_id) for component_id in df[id_column].tolist()]
    keep = df[id_column] != -1
    return df[keep]


def _centerline_arc_id_from_component(struct_comp):
    """Find the id of the arc flagged as the structure centerline.

    Culverts read the 'culvert_arc_properties' table (enabled by
    'specify_culvert_arc_properties'). Every other structure type reads
    'arc_properties' (enabled by 'specify_arc_properties'). The table is CSV text:
    the first column holds ids, converted to XMS arc ids via
    update_data_frame_arc_ids_from_component, and the second column holds the arc
    type. The first arc typed 'Bridge' or 'Embankment' is returned.

    Args:
        struct_comp (:obj:`StructureComponent`): the structure

    Returns:
        (:obj:`int`): the arc id, or -1 if the table is disabled, empty, or has no matching arc
    """
    settings = struct_comp.data.data_dict
    if settings.get('structure_type', 'Bridge') == 'Culvert':
        use_table = settings.get('specify_culvert_arc_properties', 0)
        csv_text = settings.get('culvert_arc_properties', '')
    else:
        use_table = settings.get('specify_arc_properties', 0)
        csv_text = settings.get('arc_properties', '')

    if not (use_table and csv_text):
        return -1

    table = update_data_frame_arc_ids_from_component(pd.read_csv(StringIO(csv_text)), struct_comp)
    id_column, type_column = table.columns[0], table.columns[1]
    for arc_id, arc_type in zip(table[id_column].tolist(), table[type_column].tolist()):
        if arc_type in ('Bridge', 'Embankment'):
            return arc_id
    return -1


def bridge_centerline_from_coverage(cover, struct_comp):
    """Pick the coverage arc that represents the bridge centerline.

    The arc identified by the structure's arc properties is returned when it is present
    in the coverage. Otherwise the arc with the greatest planar length is returned.

    Args:
        cover (:obj:`Coverage`): data_objects coverage; may be None
        struct_comp (:obj:`StructureComponent`): the structure

    Returns:
        (:obj:`tuple(Arc, list[x,y])`): data_objects arc, list of coordinates; the coordinates
        are None when no arc is found
    """
    candidates = cover.arcs if cover else []
    arcs_by_id = {arc.id: arc for arc in candidates}
    flagged = arcs_by_id.get(_centerline_arc_id_from_component(struct_comp), None)
    if flagged:
        return flagged, [(pt.x, pt.y) for pt in flagged.get_points(FilterLocation.PT_LOC_ALL)]

    # Fall back to the longest arc; best_arc starts from the lookup result (normally None).
    best_arc, best_pts, best_length = flagged, None, 0
    for arc in candidates:
        coords = [(pt.x, pt.y) for pt in arc.get_points(FilterLocation.PT_LOC_ALL)]
        length = LineString(coords).length
        if length > best_length:
            best_arc, best_pts, best_length = arc, list(coords), length
    return best_arc, best_pts


def offset_centerline(bridge_cl, bridge_width):
    """Offset the bridge centerline by half the bridge width on each side using shapely.

    Args:
        bridge_cl (:obj:`list[x,y]`): points defining the centerline of the bridge
        bridge_width (:obj:`float`): width of the bridge

    Returns:
        (:obj:`tuple(list[x,y], list[x,y], str)`): right side offset, left side offset, and an
        error message ('' on success). An offset is None when shapely does not return a single
        LineString for that side.
    """
    half_width = bridge_width / 2.0
    if half_width <= 0.0:
        return None, None, 'Bridge width must be greater than 0.0.'

    centerline = LineString(bridge_cl)

    def _coords_or_none(geom):
        return list(geom.coords) if type(geom) is LineString else None

    right_side = _coords_or_none(centerline.parallel_offset(distance=half_width, side='right'))
    left_side = _coords_or_none(centerline.parallel_offset(distance=half_width, side='left'))
    # No degenerate-offset check is done: with Shapely > 2.0 no case producing a missing or
    # too-short offset has been found (see test_dlg_misc in test_structure_dialog.py).
    return right_side, left_side, ''


def compute_weir_elevation(struct_comp):
    """Return the elevation to use for a weir bc associated with the structure.

    The value is the distance-weighted average elevation of the structure's top profile curve.

    Args:
        struct_comp (:obj:`StructureComponent`): the structure

    Returns:
        (:obj:`float`): the weir elevation
    """
    top_profile = struct_comp.data.curves['top_profile']
    return compute_ave_y_from_xy_data(top_profile)


def compute_ave_y_from_xy_data(df):
    """Compute the distance-weighted average elevation of an xy profile.

    The profile is integrated with the trapezoidal rule and divided by its total span
    (last distance minus first distance), so unevenly spaced stations are weighted by
    the distance they cover.

    Args:
        df (:obj:`pandas.Dataframe`): xy data with 'Distance' and 'Elevation' columns

    Returns:
        (:obj:`float`): the average elevation. A single-row profile returns its only
        elevation. A profile whose first and last distances are equal returns the
        arithmetic mean of its elevations.

    Raises:
        IndexError: if the dataframe has no rows
    """
    dist = df['Distance'].to_list()
    elev = df['Elevation'].to_list()
    if len(dist) < 2:
        return elev[0]

    total_dx = dist[-1] - dist[0]
    if total_dx == 0:
        # Zero-length span: the weighted average is undefined (division by zero), so fall
        # back to the plain mean of the elevations.
        return sum(elev) / len(elev)

    weighted_sum = sum(
        (x1 - x0) * ((y0 + y1) / 2)
        for x0, x1, y0, y1 in zip(dist, dist[1:], elev, elev[1:])
    )
    return weighted_sum / total_dx


def setup_structure_generation(xms_data=None):
    """Create a scratch folder for the structure tools and collect their inputs.

    A 'structure_dlg' folder, with 'coverages' and 'grids' subfolders, is created (or
    emptied) in the parent of the folder that holds the structure coverage's file, and
    the coverage file is copied into 'coverages'.

    Args:
        xms_data (:obj:`XmsData`): Object for retrieving data from XMS. Required despite the
            ``None`` default; ``None`` fails on the first attribute access.

    Returns:
        (:obj:`dict`): data for use by structure tools
    """
    coverage = xms_data.coverage
    component = xms_data.structure_component
    source_file = coverage.filename_and_group[0]

    work_dir = os.path.join(os.path.dirname(os.path.dirname(source_file)), 'structure_dlg')
    xmf.make_or_clear_dir(work_dir)
    coverages_dir = os.path.join(work_dir, 'coverages')
    grids_dir = os.path.join(work_dir, 'grids')
    copied_file = os.path.join(coverages_dir, f'{coverage.name}.h5')
    for sub_dir in (coverages_dir, grids_dir):
        xmf.make_or_clear_dir(sub_dir)
    shutil.copyfile(source_file, copied_file)

    wkt = xms_data.display_projection_wkt
    raster_uuid = component.data.data_dict['elev_raster']
    if raster_uuid != '':
        raster_file = xms_data.raster_file_from_uuid(raster_uuid)
        elev_file = '' if raster_file is None else raster_file
    else:
        elev_file = ''

    result = {
        'tmp_dir': work_dir,
        'cov_file': copied_file,
        'xms_data': xms_data,
        'coverage': coverage,
        'arc': None,
        'arc_pts': None,
        'main_file': component.main_file,
        'projection': xms_data.display_projection,
        'wkt': wkt,
        'vertical_units': xms_data.vertical_units,
        'elev_file': elev_file,
        'component': component
    }
    if coverage:
        result['arc'], result['arc_pts'] = bridge_centerline_from_coverage(coverage, component)
    return result


class _UGridBuilder:
    """Validates a structure and builds its top, bottom, and 3D UGrids.

    Validation problems are recorded in ``err_msg`` and footprint-tool failures in ``tool_err``;
    ``build_grids`` stops at, and returns, the first one that is set.
    """

    def __init__(self, struct_data, arc_data, footprint_calc):
        """Initializes the builder.

        Args:
            struct_data (:obj:`StructureData`): the structure data
            arc_data (:obj:`dict`): dict with various data values; updated in place during the build
            footprint_calc (:obj:`FootprintCalculator`): calculates the structure footprint
        """
        self.struct_data = struct_data
        self.arc_data = arc_data
        self.footprint_calc = footprint_calc
        self.arc_data['match_parametric_values'] = True
        # Discard top profiles from any earlier build so build_3d_ugrid only picks up a 'top_up'
        # produced during this run (presumably by footprint_calc.load_tool_results).
        self.arc_data.pop('top_up', None)
        self.arc_data.pop('top_dn', None)
        self.err_msg = ''
        self.tool_err = ''

    def build_grids(self):
        """Validates the structure, runs the footprint tool, and writes the 3D UGrid.

        Returns:
            (:obj:`tuple(bool, str)`): flag if successful, error msg
        """
        self.check_errors_full_stop()
        if self.err_msg:
            return False, self.err_msg
        self.check_profiles()
        # Report invalid profiles before deleting the previously generated UGrid files or running
        # the footprint tool; the build fails with this message either way.
        if self.err_msg:
            return False, self.err_msg
        self.remove_old_ugrid_files()
        self.run_bridge_tool()
        if self.tool_err:
            return False, self.tool_err
        self.build_3d_ugrid()
        return True, ''

    def check_errors_full_stop(self):
        """Checks for problems that make building impossible, setting ``err_msg`` if one is found."""
        if self.arc_data['arc'] is None:
            self.err_msg = 'Unable to generate UGrid. Structure requires a coverage with one arc.'
            return
        if self.struct_data.data_dict['bridge_width'] <= 0.0:
            self.err_msg = 'Unable to generate UGrid. Bridge width is set to 0.0.'
            return

    def check_profiles(self):
        """Validates the profile curves, setting ``err_msg`` on the first problem found.

        The top (crest) profile must have at least two points; for bridges every profile curve
        must. The upstream and downstream low chord profiles must lie strictly below the top
        profile at every parametric station of the three profiles.
        """
        curves = self.struct_data.curves
        if curves['top_profile'].shape[0] < 2:
            self.err_msg = 'Unable to generate UGrid. Top profile line is missing.'
            return

        if self.struct_data.data_dict['structure_type'] == 'Bridge':
            for crv in self.struct_data.curve_names:
                if curves[crv].shape[0] < 2:
                    self.err_msg = 'Unable to generate UGrid. Profile lines are missing.'
                    return

        grid_builder = GridBuilder()
        top = grid_builder.get_parametric_line(curves['top_profile'])
        up = grid_builder.get_parametric_line(curves['upstream_profile'])
        down = grid_builder.get_parametric_line(curves['downstream_profile'])
        # An empty profile has no elevation to compare against, so it is skipped.
        low_chords = [line for line in (up, down) if len(line) > 0]
        all_t = {p[0] for p in top} | {p[0] for p in up} | {p[0] for p in down}
        for t in all_t:
            # Compare elevations at the same parametric station. LineString.interpolate() would
            # measure distance along the polyline instead, which differs from the station as soon
            # as the profile has any slope.
            top_elev = self._elevation_at(top, t)
            if any(top_elev <= self._elevation_at(line, t) for line in low_chords):
                self.err_msg = 'Unable to generate UGrid. Low chord elevation profiles must be below the Crest profile.'
                return

    @staticmethod
    def _elevation_at(pts, t):
        """Linearly interpolates the elevation of a (t, elevation) polyline at parameter ``t``.

        Args:
            pts (:obj:`list`): polyline points where p[0] is the parameter and p[1] the elevation.
                Assumed to be ordered by increasing parameter (TODO confirm for
                GridBuilder.get_parametric_line).
            t (:obj:`float`): the parameter value

        Returns:
            (:obj:`float`): the interpolated elevation, clamped to the end elevations outside the polyline
        """
        if t <= pts[0][0]:
            return pts[0][1]
        for p0, p1 in zip(pts, pts[1:]):
            if t <= p1[0]:
                span = p1[0] - p0[0]
                if span == 0:
                    return p1[1]
                return p0[1] + (p1[1] - p0[1]) * (t - p0[0]) / span
        return pts[-1][1]

    def remove_old_ugrid_files(self):
        """Deletes the known UGrid output files (e.g. 'ugrid.xmc') from the component folder if present."""
        comp_dir = os.path.dirname(self.arc_data['main_file'])
        comp_files = ['bottom2d.prj', 'bottom2d.xmugrid', 'top2d.prj', 'top2d.xmugrid', 'ugrid.xmc', 'ugrid.prj']
        for f in comp_files:
            fname = os.path.join(comp_dir, f)
            if os.path.isfile(fname):
                os.remove(fname)

    def run_bridge_tool(self):
        """Runs the footprint tool and loads its results, setting ``tool_err`` on failure."""
        self.footprint_calc.culvert_calc.err_msg = ''
        self.footprint_calc.setup_tool(self.struct_data)
        if self.footprint_calc.culvert_calc.err_msg:
            self.tool_err = 'Unable to generate UGrid. ' + self.footprint_calc.culvert_calc.err_msg
            return
        errors = self.footprint_calc.tool.validate_arguments(self.footprint_calc.args)
        # NOTE(review): argument validation errors are not reported; the tool run is simply skipped.
        # For culverts the bridge mesh fallback below may cover this - confirm it is intended for bridges.
        if len(errors) < 1:
            try:
                self.footprint_calc.tool.run(self.footprint_calc.args)
                self.footprint_calc.load_tool_results(self.arc_data['elev_file'], self.arc_data)
            except (ToolError, RuntimeError) as e:
                self.tool_err = f'Unable to generate UGrid. {e}'
                return

        s_type = self.struct_data.data_dict['structure_type']
        if self.footprint_calc.bridge_mesh is None and s_type == 'Culvert':
            self.footprint_calc.calc_bridge_mesh_elevations(self.arc_data['elev_file'], self.arc_data)

    def build_3d_ugrid(self):
        """Builds the top/bottom grids and the 3D UGrid, and writes the 3D grid to 'ugrid.xmc'.

        Sets ``arc_data['output_ugrid']`` and, when it is empty, the structure's 'srh_mapping_info'.
        """
        up_arc, down_arc, _msg = self.footprint_calc.get_up_down_arcs()
        # The error message is ignored: no case producing one has been found with Shapely > 2.0
        # (see test_dlg_misc in test_structure_dialog.py).

        self.arc_data['up_stream'] = self.struct_data.curves['upstream_profile']
        self.arc_data['down_stream'] = self.struct_data.curves['downstream_profile']
        self.arc_data['top'] = self.struct_data.curves['top_profile']
        self.arc_data['up_arc'] = up_arc
        self.arc_data['down_arc'] = down_arc
        self.arc_data['bridge_mesh'] = self.footprint_calc.bridge_mesh
        self.arc_data['pier_elev'] = self.struct_data.data_dict['bridge_pier_base_elev']
        if 'top_up' in self.arc_data:
            self.arc_data['top'] = self.arc_data['top_up']

        grid_builder = GridBuilder()
        grid_builder.build_top_and_bottom_from_dict(self.arc_data)

        ug_file = os.path.join(os.path.dirname(self.arc_data['main_file']), 'ugrid.xmc')
        self.arc_data['ug_filename'] = ug_file
        ug3d = grid_builder.build_from_dict(self.arc_data)
        co_builder = UGridBuilder()
        co_builder.set_ugrid(ug3d)
        co_grid = co_builder.build_grid()
        co_grid.write_to_file(ug_file)
        cov = self.arc_data['coverage']
        self.arc_data['output_ugrid'] = DoUGrid(
            cogrid_file=ug_file, name=cov.name, projection=self.arc_data['projection']
        )

        if self.struct_data.data_dict['srh_mapping_info'] == '':
            srh_data = {
                'up_arc': up_arc,
                'down_arc': down_arc,
                'wkt': self.arc_data['wkt'],
                'culvert_poly': None,
            }
            self.struct_data.data_dict['srh_mapping_info'] = json.dumps(srh_data)


def generate_ugrids_from_structure(struct_data, arc_data, footprint_calc):
    """Build the top, bottom, and 3D UGrids for a structure.

    Args:
        struct_data (:obj:`StructureData`): the structure data
        arc_data (:obj:`dict`): dict with various data values
        footprint_calc (:obj:`FootprintCalculator`): calculates the structure footprint

    Returns:
        (:obj:`tuple(bool, str)`): (True, '') on success, otherwise (False, error message)
    """
    return _UGridBuilder(struct_data, arc_data, footprint_calc).build_grids()
