Skip to content

Commit 4f7805b

Browse files
committed
refactor: gdino tools
1 parent ec1149b commit 4f7805b

File tree

2 files changed

+6
-17
lines changed

2 files changed

+6
-17
lines changed

src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,10 @@
1616

1717
import numpy as np
1818
import sensor_msgs.msg
19-
from langchain_core.tools import BaseTool
2019
from pydantic import BaseModel, Field
21-
from rai.communication.ros2 import ROS2Connector
2220
from rai.communication.ros2.api import convert_ros_img_to_ndarray
2321
from rai.communication.ros2.ros_async import get_future_result
22+
from rai.tools.ros2.base import BaseROS2Tool
2423
from rclpy.exceptions import (
2524
ParameterNotDeclaredException,
2625
ParameterUninitializedException,
@@ -78,9 +77,7 @@ class DistanceMeasurement(NamedTuple):
7877

7978

8079
# --------------------- Tools ---------------------
81-
class GroundingDinoBaseTool(BaseTool):
82-
connector: ROS2Connector = Field(..., exclude=True)
83-
80+
class GroundingDinoBaseTool(BaseROS2Tool):
8481
box_threshold: float = Field(default=0.35, description="Box threshold for GDINO")
8582
text_threshold: float = Field(default=0.45, description="Text threshold for GDINO")
8683

@@ -89,7 +86,7 @@ def _call_gdino_node(
8986
) -> Future:
9087
cli = self.connector.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME)
9188
while not cli.wait_for_service(timeout_sec=1.0):
92-
self.node.get_logger().info(
89+
self.connector.node.get_logger().info(
9390
f"service {GDINO_SERVICE_NAME} not available, waiting again..."
9491
)
9592
req = RAIGroundingDino.Request()

src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,13 @@
1818
import numpy as np
1919
import rclpy
2020
import sensor_msgs.msg
21-
from langchain_core.tools import BaseTool
2221
from pydantic import BaseModel, Field
2322
from rai.communication.ros2.api import (
2423
convert_ros_img_to_base64,
2524
convert_ros_img_to_ndarray,
2625
)
27-
from rai.communication.ros2.connectors import ROS2Connector
2826
from rai.communication.ros2.ros_async import get_future_result
27+
from rai.tools.ros2.base import BaseROS2Tool
2928
from rclpy import Future
3029
from rclpy.exceptions import (
3130
ParameterNotDeclaredException,
@@ -67,12 +66,7 @@ class GetGrabbingPointInput(BaseModel):
6766

6867

6968
# --------------------- Tools ---------------------
70-
class GetSegmentationTool:
71-
connector: ROS2Connector = Field(..., exclude=True)
72-
73-
name: str = ""
74-
description: str = ""
75-
69+
class GetSegmentationTool(BaseROS2Tool):
7670
box_threshold: float = Field(default=0.35, description="Box threshold for GDINO")
7771
text_threshold: float = Field(default=0.45, description="Text threshold for GDINO")
7872

@@ -194,9 +188,7 @@ def depth_to_point_cloud(
194188
return points
195189

196190

197-
class GetGrabbingPointTool(BaseTool):
198-
connector: ROS2Connector = Field(..., exclude=True)
199-
191+
class GetGrabbingPointTool(BaseROS2Tool):
200192
name: str = "GetGrabbingPointTool"
201193
description: str = "Get the grabbing point of an object"
202194
pcd: List[Any] = []

0 commit comments

Comments
 (0)