"""BridgeFootprintTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from typing import List

# 2. Third party modules
from geopandas import GeoDataFrame
from shapely.geometry import LineString, Point

# 3. Aquaveo modules
from xms.constraint import UGridBuilder
from xms.tool_core import Argument, IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.mesh_2d.bridge_footprint import (ArcType, BridgeFootprint, PierEndType, PierType)


def _get_arcs_from_coverage(coverage: GeoDataFrame) -> list:
    """Collect every arc feature of a coverage as a plain dictionary.

    Only rows whose 'geometry_types' column equals 'Arc' are considered.

    Args:
        coverage (GeoDataFrame): The coverage.

    Returns:
        (list): One dict per arc with the keys 'id', 'arc_pts' (list of coordinate tuples),
        'start_node', and 'end_node'.
    """
    arc_mask = coverage['geometry_types'] == 'Arc'
    return [
        {
            'id': row.id,
            'arc_pts': list(row.geometry.coords),
            'start_node': row.start_node,
            'end_node': row.end_node,
        }
        for row in coverage[arc_mask].itertuples()
    ]


class BridgeFootprintTool(Tool):
    """Tool to create a mesh or coverage for bridge piers.

    The input coverage must contain one bridge arc and any number of arcs that each cross it at a single point.
    By default the bridge arc is the longest arc. If ``arc_id_to_type`` is set, it is the arc mapped to 'Bridge'.
    When abutments are requested, the first and last crossing arcs along the bridge become abutments. All other
    crossing arcs become piers.
    """
    ARG_INPUT_COVERAGE = 0
    # Bridge
    ARG_BRIDGE_WIDTH = 1
    ARG_BRIDGE_WRAPPING_WIDTH = 2
    ARG_SPECIFY_SEGMENT_COUNT = 3
    ARG_SEGMENT_COUNT = 4
    # Abutments
    ARG_HAS_ABUTMENTS = 5
    ARG_PIER_TYPE = 6
    # Group Pier
    ARG_PIER_DIAMETER = 7
    ARG_GROUP_WRAPPING_WIDTH = 8
    ARG_PIER_GROUP_COUNT = 9
    ARG_PIER_GROUP_SPACING = 10
    # Wall Pier
    ARG_WALL_WIDTH = 11
    ARG_WALL_WRAPPING_WIDTH = 12
    ARG_WALL_PIER_LENGTH = 13
    ARG_SIDE_ELEMENT_COUNT = 14
    ARG_PIER_END_TYPE = 15
    # Output
    ARG_OUTPUT_GRID = 16
    ARG_OUTPUT_COVERAGE = 17

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Bridge Footprint')
        # Forwarded as-is to BridgeFootprint through get_inputs().
        self.center_line_num_segments = None
        # Copied from the worker by run() when an output grid is requested.
        self.bridge_upstream_line = None
        self.bridge_downstream_line = None
        self.coverage_elevation = 0.0
        # Cached arc data. run() only builds it from the input coverage while this is still None.
        self.arc_data = None
        # Optional mapping of arc ID -> arc type string. The entry typed 'Bridge' selects the bridge arc.
        self.arc_id_to_type = None
        # Forwarded as-is to BridgeFootprint through get_inputs().
        self.wrap_upstream = True
        self.wrap_downstream = True
        self.bridge_num_side_elem = None

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.coverage_argument(name='input_coverage', description='Input coverage'),

            # bridge settings
            self.float_argument(name='bridge_width', description='Bridge width', min_value=0.0, value=0.0),
            self.float_argument(name='bridge_wrapping_width', description='Bridge wrapping width', min_value=0.0,
                                value=0.0),
            self.bool_argument(name='specify_segment_count', description='Specify number of segments', value=False),
            self.integer_argument(name='segment_count', description='Number of segments', min_value=1, optional=True,
                                  value=1),
            self.bool_argument(name='has_abutments', description='Has abutments', value=False),

            # piers
            self.string_argument(name='pier_type', description='Pier type',
                                 choices=[Argument.NONE_SELECTED, PierType.GROUP.value, PierType.WALL.value],
                                 optional=True),

            # group pier
            self.float_argument(name='pier_diameter', description='Pier diameter', min_value=0.0, optional=True),
            self.float_argument(name='group_wrapping_width', description='Element wrapping width', min_value=0.0,
                                optional=True),
            self.integer_argument(name='pier_group_count', description='Number of piers in group', value=1,
                                  min_value=1, optional=True),
            self.float_argument(name='pier_group_spacing', description='Pier group spacing', min_value=0.0,
                                optional=True),

            # wall pier
            self.float_argument(name='wall_width', description='Wall width', min_value=0.0, optional=True),
            self.float_argument(name='element_wrapping_width', description='Element wrapping width', min_value=0.0,
                                optional=True),
            self.float_argument(name='wall_pier_length', description='Wall pier length', min_value=0.0, optional=True),
            self.integer_argument(name='side_element_count', description='Wall pier number of side elements',
                                  min_value=0, optional=True),
            self.string_argument(name='pier_end_type', description='Pier end type',
                                 choices=[PierEndType.SQUARE.value, PierEndType.ROUND.value, PierEndType.SHARP.value],
                                 optional=True),

            # outputs (at least one is required; see validate_arguments)
            self.grid_argument(name='output_grid', description='Output grid', io_direction=IoDirection.OUTPUT,
                               optional=True),
            self.coverage_argument(name='output_coverage', description='Output coverage',
                                   io_direction=IoDirection.OUTPUT, optional=True)
        ]
        self.enable_arguments(arguments)
        return arguments

    def enable_arguments(self, arguments: List[Argument]):
        """Called to show/hide arguments, change argument values and add new arguments.

        Args:
            arguments(list): The tool arguments.
        """
        # show/hide segment count
        arguments[self.ARG_SEGMENT_COUNT].show = arguments[self.ARG_SPECIFY_SEGMENT_COUNT].value

        # show/hide pier group arguments
        type_is_group = arguments[self.ARG_PIER_TYPE].value == PierType.GROUP.value
        arguments[self.ARG_PIER_DIAMETER].show = type_is_group
        arguments[self.ARG_GROUP_WRAPPING_WIDTH].show = type_is_group
        arguments[self.ARG_PIER_GROUP_COUNT].show = type_is_group
        arguments[self.ARG_PIER_GROUP_SPACING].show = type_is_group

        # show/hide pier wall arguments
        type_is_wall = arguments[self.ARG_PIER_TYPE].value == PierType.WALL.value
        arguments[self.ARG_WALL_WIDTH].show = type_is_wall
        arguments[self.ARG_WALL_WRAPPING_WIDTH].show = type_is_wall
        arguments[self.ARG_WALL_PIER_LENGTH].show = type_is_wall
        arguments[self.ARG_SIDE_ELEMENT_COUNT].show = type_is_wall
        arguments[self.ARG_PIER_END_TYPE].show = type_is_wall

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name.
        """
        errors = {}

        def _require_positive(index: int, message: str):
            """Record `message` when the argument has no value (None) or is not strictly positive."""
            # Optional arguments may have no value yet; comparing None with a number would raise TypeError.
            value = arguments[index].value
            if value is None or value <= 0:
                errors[arguments[index].name] = message

        _require_positive(self.ARG_BRIDGE_WIDTH, 'Bridge width must be greater than zero.')
        _require_positive(self.ARG_BRIDGE_WRAPPING_WIDTH, 'Bridge wrapping width must be greater than zero.')

        # Compare against the enum values so this stays consistent with enable_arguments().
        pier_type = arguments[self.ARG_PIER_TYPE].value
        if pier_type == PierType.GROUP.value:
            _require_positive(self.ARG_PIER_DIAMETER, 'Pier diameter must be greater than zero.')
            _require_positive(self.ARG_GROUP_WRAPPING_WIDTH, 'Pier group wrapping width must be greater than zero.')
            _require_positive(self.ARG_PIER_GROUP_COUNT, 'Number of piers in group must be greater than zero.')
            _require_positive(self.ARG_PIER_GROUP_SPACING, 'Pier group spacing must be greater than zero.')
        elif pier_type == PierType.WALL.value:
            _require_positive(self.ARG_WALL_WIDTH, 'Pier wall width must be greater than zero.')
            _require_positive(self.ARG_WALL_WRAPPING_WIDTH, 'Pier wall wrapping width must be greater than zero.')
            _require_positive(self.ARG_WALL_PIER_LENGTH, 'Pier wall length must be greater than zero.')
            _require_positive(self.ARG_SIDE_ELEMENT_COUNT,
                              'Pier wall number of side elements must be greater than zero.')
            # add_pier_data converts this value with PierEndType(), which needs a real choice.
            if arguments[self.ARG_PIER_END_TYPE].value is None:
                errors[arguments[self.ARG_PIER_END_TYPE].name] = 'Must select a pier end type.'

        # must have an output
        if not arguments[self.ARG_OUTPUT_GRID].value and not arguments[self.ARG_OUTPUT_COVERAGE].value:
            errors[arguments[self.ARG_OUTPUT_GRID].name] = 'Must have either an output grid or an output coverage.'
            errors[arguments[self.ARG_OUTPUT_COVERAGE].name] = 'Must have either an output grid or an output coverage.'
        return errors

    def run(self, arguments):
        """Override to run the tool.

        Builds the arc data (unless it was already set), runs BridgeFootprint, and sends the requested
        output grid and/or coverage.

        Args:
            arguments (list): The tool arguments.
        """
        input_coverage = self.get_input_coverage(arguments[self.ARG_INPUT_COVERAGE].value)
        output_grid = arguments[self.ARG_OUTPUT_GRID].value
        output_coverage = arguments[self.ARG_OUTPUT_COVERAGE].value
        if self.arc_data is None:
            self.arc_data = self.get_arc_data(input_coverage, arguments)
        inputs = self.get_inputs(arguments)
        if self.arc_data is None:
            # build_arc_data reports its failure via self.fail before returning None.
            return

        worker = BridgeFootprint(self.arc_data, inputs, self.logger)
        worker.coverage_elevation = self.coverage_elevation
        try:
            worker.generate()
        except RuntimeError as err:
            self.fail(str(err))
            # Never write outputs from a worker whose generation failed.
            return
        if output_grid:
            new_grid = _co_grid_from_ugrid(worker.out_ugrid)
            self.set_output_grid(new_grid, arguments[self.ARG_OUTPUT_GRID], force_ugrid=False)
            self.bridge_upstream_line = worker.bridge_upstream_line
            self.bridge_downstream_line = worker.bridge_downstream_line
        if output_coverage:
            coverage = worker.out_coverage
            self.set_output_coverage(coverage, arguments[self.ARG_OUTPUT_COVERAGE])

    def get_inputs(self, arguments: List[Argument]) -> dict:
        """Get the inputs dictionary for BridgeFootprint.

        Args:
            arguments (List[Argument]): The arguments.

        Returns:
            (dict): Inputs for BridgeFootprint.
        """
        inputs = dict()
        wkt = '' if self.default_wkt is None else self.default_wkt
        inputs['wkt'] = wkt
        inputs['new_coverage_name'] = arguments[self.ARG_OUTPUT_COVERAGE].value
        inputs['create_mesh'] = bool(arguments[self.ARG_OUTPUT_GRID].text_value)
        inputs['center_line_num_segments'] = self.center_line_num_segments
        inputs['wrap_upstream'] = self.wrap_upstream
        inputs['wrap_downstream'] = self.wrap_downstream
        inputs['bridge_num_side_elem'] = self.bridge_num_side_elem
        return inputs

    def get_arc_data(self, coverage: GeoDataFrame, arguments: List[Argument]):
        """Get the arc data for a coverage.

        Args:
            coverage (GeoDataFrame): The coverage.
            arguments (List[Argument]): The arguments.

        Returns:
            (Optional[list]): Arc data for BridgeFootprint, or None if it could not be built.
        """
        arcs = _get_arcs_from_coverage(coverage)
        return self.build_arc_data(arcs, arguments)

    def _split_bridge_arc(self, arcs: List[dict]):
        """Separate the bridge arc from the remaining arcs.

        Args:
            arcs (List[dict]): Arc dictionaries sorted by ascending 'length'.

        Returns:
            (tuple): (bridge arc dict, or None if it could not be found, list of the remaining arcs).
        """
        if not self.arc_id_to_type:
            # Without an explicit mapping, the longest arc is treated as the bridge.
            return arcs[-1], arcs[:-1]
        id_to_index = {arc['id']: idx for idx, arc in enumerate(arcs)}
        for arc_id, arc_type in self.arc_id_to_type.items():
            if arc_type == 'Bridge':
                idx = id_to_index.get(arc_id)
                if idx is None:
                    return None, arcs
                return arcs[idx], arcs[:idx] + arcs[idx + 1:]
        return None, arcs

    def build_arc_data(self, arcs: List[dict], arguments: List[Argument]) -> list:
        """Build arc data from arcs and arguments.

        Each arc dict is updated in place with its arc type and type-specific settings.

        Args:
            arcs (List[dict]): The arcs already extracted from the coverage.
            arguments (List[Argument]): The arguments.

        Returns:
            (Optional[list]): The arc dicts (bridge first, then abutments, then piers) ready to be used by
            BridgeFootprint, or None after reporting a failure via self.fail.
        """
        if len(arcs) == 0:
            self.fail("The coverage doesn't contain any arcs.")
            return None

        # add shapely line string and length (temporary keys, removed before returning)
        for arc in arcs:
            line_string = LineString(arc['arc_pts'])
            arc['line_string'] = line_string
            arc['length'] = line_string.length

        # find bridge arc
        arcs = sorted(arcs, key=lambda k: k['length'])
        bridge_arc, arcs = self._split_bridge_arc(arcs)
        if bridge_arc is None:
            self.fail('Unable to find the bridge arc in the coverage.')
            return None
        bridge_arc['arc_type'] = ArcType.BRIDGE
        bridge_arc['bridge_width'] = arguments[self.ARG_BRIDGE_WIDTH].value
        bridge_arc['bridge_wrapping_width'] = arguments[self.ARG_BRIDGE_WRAPPING_WIDTH].value
        bridge_arc['bridge_num_segments'] = arguments[self.ARG_SEGMENT_COUNT].value
        bridge_arc['bridge_specify_num_segments'] = 1 if arguments[self.ARG_SPECIFY_SEGMENT_COUNT].value else 0
        arc_data = [bridge_arc]

        # make sure all other arcs cross the bridge arc exactly once
        bridge_line_string = bridge_arc['line_string']
        for arc in arcs:
            crossing = bridge_line_string.intersection(arc['line_string'])
            arc_id = arc['id']
            if not isinstance(crossing, Point) or crossing.is_empty:
                message = f'Arc ID {arc_id} fails to cross the bridge arc. All arcs must cross the bridge arc.'
                self.fail(message)
                return None
            # Normalized distance *along* the bridge arc (0 at its first point, 1 at its last). Unlike the
            # straight-line distance to the first point, this stays correct for bent bridge arcs.
            arc['t_value'] = bridge_line_string.project(crossing, normalized=True)

        # find abutment arcs: the first and last crossings along the bridge
        if arguments[self.ARG_HAS_ABUTMENTS].value:
            if len(arcs) < 2:
                message = 'Must have at least two arcs crossing the bridge to have abutments.'
                self.fail(message)
                return None
            arcs = sorted(arcs, key=lambda k: k['t_value'])
            abutments = [arcs[0], arcs[-1]]
            for abutment in abutments:
                abutment['arc_type'] = ArcType.ABUTMENT
            arc_data.extend(abutments)
            arcs = arcs[1:-1]

        # add pier data to every remaining arc
        pier_type = arguments[self.ARG_PIER_TYPE].value
        if arcs and pier_type in (None, Argument.NONE_SELECTED):
            self.fail('Must select a Pier type when the coverage has pier arcs.')
            return None
        for arc in arcs:
            self.add_pier_data(arc, arguments)
        arc_data.extend(arcs)

        # remove shapely line string and length
        for arc in arc_data:
            arc.pop('line_string')
            arc.pop('length')
        return arc_data

    def add_pier_data(self, arc_data, arguments):
        """Add the pier data to the arc.

        Args:
            arc_data (dict): The arc data. Updated in place.
            arguments (List[Argument]): The arguments.

        Returns:
            (dict): The same arc data dict, with the pier settings added.
        """
        arc_data['arc_type'] = ArcType.PIER
        arc_data['pier_type'] = PierType(arguments[self.ARG_PIER_TYPE].value)
        if arc_data['pier_type'] == PierType.GROUP:
            arc_data['pier_size'] = arguments[self.ARG_PIER_DIAMETER].value
            arc_data['pier_element_wrap_width'] = arguments[self.ARG_GROUP_WRAPPING_WIDTH].value
            arc_data['number_piers'] = arguments[self.ARG_PIER_GROUP_COUNT].value
            arc_data['pier_spacing'] = arguments[self.ARG_PIER_GROUP_SPACING].value
        elif arc_data['pier_type'] == PierType.WALL:
            arc_data['pier_size'] = arguments[self.ARG_WALL_WIDTH].value
            arc_data['pier_element_wrap_width'] = arguments[self.ARG_WALL_WRAPPING_WIDTH].value
            arc_data['pier_length'] = arguments[self.ARG_WALL_PIER_LENGTH].value
            arc_data['pier_num_side_elements'] = arguments[self.ARG_SIDE_ELEMENT_COUNT].value
            arc_data['pier_end_type'] = PierEndType(arguments[self.ARG_PIER_END_TYPE].value)
        return arc_data


def _co_grid_from_ugrid(ugrid):
    """Wrap a UGrid in a 2D constrained grid.

    Args:
        ugrid (xms.Grid.UGrid): The unstructured grid to wrap.

    Returns:
        (xms.constraint.Grid): The constrained grid produced by UGridBuilder.
    """
    builder = UGridBuilder()
    # Mark the grid as 2D before handing over the geometry.
    builder.set_is_2d()
    builder.set_ugrid(ugrid)
    constrained_grid = builder.build_grid()
    return constrained_grid

# def main():
#     """Main function, for testing."""
#     from xms.tool_gui.tool_dialog import ToolDialog
#     from xms.guipy.dialogs.xms_parent_dlg import ensure_qapplication_exists
#
#     qapp = ensure_qapplication_exists()
#     tool = BridgeFootprintTool()
#     arguments = tool.initial_arguments()
#     tool_dialog = ToolDialog(None, arguments, 'Tool', tool=tool)
#     if tool_dialog.exec():
#         tool.run_tool(tool_dialog.tool_arguments)
#
#
# if __name__ == "__main__":
#     main()
