Skip to content

Commit 2d43296

Browse files
authored
Add online learning pipeline for action selection (#153)
* [WIP] Add depth image to segment_from_point for use in context calculation * Add empty data folder * Revert "Add empty data folder" This reverts commit 30b3819. * Added way to set data directory * [WIP] Adding spanet + SAM context adapters + HapticNet posthoc adapter * [WIP] Added SPANet ContextAdapter * [WIP] Added HapticNet * [WIP] Added hapticnet and spanet adapters * [WIP] Finished adapters, updated Policy interface for checkpoints, next step: implement checkpointing and data record * [WIP] Added basic linear policy * [WIP] Added all linear policies * [WIP] Finish adding checkpointing and data record, UNTESTED * [WIP] Initial CoRL tuning for 0.3rad/s angular speed * Tested on real robot, did not yet test reports * Switch to cartesian controller instead of MoveIt Servo for acquisition * Return items required for AcquisitionReport * Register blackboard actions, tested in sim * Delete existing symlink in set_data_folder if exists * Stressed test in real: working flip food frame upon MoveAbove/MoveInto planning failure * Add debug for AcquisitionSelect and AcquisitionReport, plus flip food frame by default for Action 1 * Black format * Added ability to pass in constant action * Add constant policy override * Add launch file action override
1 parent a124116 commit 2d43296

32 files changed

+2605
-610
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ __pycache__/
3333

3434
# Temp files
3535
*~
36+
*.bak
3637

3738
# VIM
3839
[._]*.s[a-v][a-z]
@@ -64,3 +65,4 @@ trees/debug.xml
6465
*.pth
6566
ada_feeding_perception/test/food_img/
6667
ada_feeding_perception/test/output/
68+
ada_feeding_action_select/data

.pylintrc

+1-1
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ contextmanager-decorators=contextlib.contextmanager
580580
# List of members which are set dynamically and missed by pylint inference
581581
# system, and so shouldn't trigger E1101 when accessed. Python regular
582582
# expressions are accepted.
583-
generated-members=cv2
583+
generated-members=cv2, torch.*
584584

585585
# Tells whether to warn about missing members when the owner of the attribute
586586
# is inferred to be None.

ada_feeding/ada_feeding/behaviors/acquisition/compute_action_constraints.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def update(self) -> py_trees.common.Status:
324324
linear_stamped, self.moveit2.base_link_name
325325
).vector
326326

327-
### Move Angular to Approach Frame
327+
### Move Angular to Base Link Frame
328328
# Get TF EE frame -> base link frame
329329
if not self.tf_buffer.can_transform(
330330
self.moveit2.base_link_name,

ada_feeding/ada_feeding/behaviors/acquisition/compute_food_frame.py

+5
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def blackboard_inputs(
5353
timestamp: Union[BlackboardKey, rclpy.time.Time] = rclpy.time.Time(),
5454
food_frame_id: Union[BlackboardKey, str] = "food",
5555
world_frame: Union[BlackboardKey, str] = "world",
56+
flip_food_frame: Union[BlackboardKey, bool] = False,
5657
) -> None:
5758
"""
5859
Blackboard Inputs
@@ -66,6 +67,7 @@ def blackboard_inputs(
6667
food_frame_id (string): If len>0, TF frame to publish static transform
6768
(relative to world_frame)
6869
world_frame (string): ID of the TF frame to represent the food frame in
70+
flip_food_frame (bool): whether to rotate the food frame 180 about Z
6971
"""
7072
# pylint: disable=unused-argument, duplicate-code
7173
# Arguments are handled generically in base class.
@@ -264,6 +266,9 @@ def update(self) -> py_trees.common.Status:
264266
point2 = pyrealsense2.rs2_deproject_pixel_to_point(
265267
self.intrinsics, [point2[0], point2[1]], mask.average_depth
266268
)
269+
# Flip X if requested
270+
if self.blackboard_get("flip_food_frame"):
271+
point1, point2 = point2, point1
267272
x_pos = Vector3Stamped()
268273
x_pos.header.frame_id = camera_frame
269274
x_pos.vector.x = point1[0] - point2[0]

ada_feeding/ada_feeding/trees/acquire_food_tree.py

+186-101
Large diffs are not rendered by default.

ada_feeding/ada_feeding/trees/start_servo_tree.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(
3333
node: Node,
3434
servo_controller_name: str = "jaco_arm_servo_controller",
3535
move_group_controller_name: str = "jaco_arm_controller",
36+
start_moveit_servo: bool = True,
3637
) -> None:
3738
"""
3839
Initializes the behavior tree.
@@ -47,6 +48,7 @@ def __init__(
4748
super().__init__(node=node)
4849
self.servo_controller_name = servo_controller_name
4950
self.move_group_controller_name = move_group_controller_name
51+
self.start_moveit_servo = start_moveit_servo
5052

5153
@override
5254
def create_tree(
@@ -100,16 +102,16 @@ def create_tree(
100102
)
101103
],
102104
)
105+
children = [switch_controllers]
106+
if self.start_moveit_servo:
107+
children.append(start_servo)
103108

104109
# Put them together in a sequence
105110
# pylint: disable=duplicate-code
106111
return py_trees.trees.BehaviourTree(
107112
root=py_trees.composites.Sequence(
108113
name=name,
109114
memory=True,
110-
children=[
111-
switch_controllers,
112-
start_servo,
113-
],
115+
children=children,
114116
)
115117
)

ada_feeding/ada_feeding/trees/stop_servo_tree.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,14 @@ class StopServoTree(TriggerTree):
3636
4. Calls the `~/switch_controller` service to turn off the servo controller.
3737
"""
3838

39+
# pylint: disable=too-many-arguments
3940
def __init__(
4041
self,
4142
node: Node,
4243
base_frame_id: str = "j2n6s200_link_base",
4344
servo_controller_name: str = "jaco_arm_servo_controller",
4445
delay: float = 0.5,
46+
stop_moveit_servo: bool = True,
4547
) -> None:
4648
"""
4749
Initializes the behavior tree.
@@ -57,6 +59,7 @@ def __init__(
5759
self.base_frame_id = base_frame_id
5860
self.servo_controller_name = servo_controller_name
5961
self.delay = delay
62+
self.stop_moveit_servo = stop_moveit_servo
6063

6164
@override
6265
def create_tree(
@@ -156,16 +159,18 @@ def create_tree(
156159

157160
# Put them together in a sequence with memory
158161
# pylint: disable=duplicate-code
162+
children = [
163+
update_timestamp,
164+
twist_pub,
165+
delay_behavior,
166+
]
167+
if self.stop_moveit_servo:
168+
children.append(stop_servo)
169+
children.append(stop_controllers)
159170
return py_trees.trees.BehaviourTree(
160171
root=py_trees.composites.Sequence(
161172
name=name,
162173
memory=True,
163-
children=[
164-
update_timestamp,
165-
twist_pub,
166-
delay_behavior,
167-
stop_servo,
168-
stop_controllers,
169-
],
174+
children=children,
170175
)
171176
)

ada_feeding/config/ada_feeding_action_servers_current.yaml

+5-1
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,9 @@
22
ada_feeding_action_servers:
33
ros__parameters:
44
current:
5+
MoveToMouth.tree_kwargs.plan_distance_from_mouth:
6+
- 0.025
7+
- 0.0
8+
- -0.01
59
overridden_parameters:
6-
- ""
10+
- MoveToMouth.tree_kwargs.plan_distance_from_mouth

ada_feeding/launch/ada_feeding_launch.xml

+9-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
<arg name="run_web_bridge" default="false" description="Whether to run the web bridge nodes" />
33
<arg name="use_estop" default="true" description="Whether to use the e-stop. Should only be set false in sim, since we don't have a way of simulating the e-stop button." />
44
<arg name="log_level" default="info" description="Log Level to pass to create_action_servers: debug, info, warn" />
5-
<arg name="default_action_select" default="true" description="Whether to use the default action select: constant policy 0" />
5+
<arg name="policy" default="constant" description="Which policy to use" />
6+
<arg name="action" default="0" description="Which action to use with constant policy" />
67

78
<group if="$(var run_web_bridge)">
89
<!-- The ROSBridge Node -->
@@ -26,22 +27,25 @@
2627
<remap from="~/ft_topic" to="/wireless_ft/ftSensor1" />
2728
<remap from="~/set_force_gate_controller_parameters" to="/jaco_arm_controller/set_parameters" />
2829
<remap from="~/set_servo_controller_parameters" to="/jaco_arm_servo_controller/set_parameters" />
30+
<remap from="~/set_cartesian_controller_parameters" to="/jaco_arm_cartesian_controller/set_parameters" />
2931
<remap from="~/clear_octomap" to="/clear_octomap" />
3032
<remap from="~/toggle_face_detection" to="/toggle_face_detection" />
3133
<remap from="~/face_detection" to="/face_detection" />
3234
<remap from="~/switch_controller" to="/controller_manager/switch_controller" />
3335
<remap from="~/start_servo" to="/servo_node/start_servo" />
3436
<remap from="~/servo_twist_cmds" to="/servo_node/delta_twist_cmds" />
37+
<remap from="~/cartesian_twist_cmds" to="/jaco_arm_cartesian_controller/twist_cmd" />
3538
<remap from="~/servo_status" to="/servo_node/status" />
3639
<remap from="~/stop_servo" to="/servo_node/stop_servo" />
3740
<remap from="~/action_select" to="/ada_feeding_action_select/action_select" />
3841
<remap from="~/action_report" to="/ada_feeding_action_select/action_report" />
3942
</node>
4043

41-
<!-- Include Default Action Selection Server -->
42-
<group if="$(var default_action_select)">
43-
<include file="$(find-pkg-share ada_feeding_action_select)/launch/ada_feeding_action_select_launch.xml"/>
44-
</group>
44+
<!-- Include Action Selection Server -->
45+
<include file="$(find-pkg-share ada_feeding_action_select)/launch/ada_feeding_action_select_launch.xml">
46+
<arg name="policy" value="$(var policy)" />
47+
<arg name="action" value="$(var action)" />
48+
</include>
4549

4650
<!-- Populate the planning scene -->
4751
<node pkg="ada_feeding" exec="ada_planning_scene.py" name="ada_planning_scene">

ada_feeding_action_select/ada_feeding_action_select/adapters/__init__.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
# Just returns [0]
1010
from .base_adapters import NoContext
1111

12-
# TODO: Posthoc Passthrough
12+
# TODO: SegmentAnything Context
1313

14-
# TODO: SPANet Context
14+
# SPANet Context
15+
from .spanet_adapter import SPANetContext
1516

16-
# TODO: HapticNet Posthoc
17+
# HapticNet Posthoc
18+
from .hapticnet_adapter import HapticNetPosthoc

ada_feeding_action_select/ada_feeding_action_select/adapters/base_adapters.py

+9-23
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77

88
# Standard imports
99
from abc import ABC, abstractmethod
10-
from typing import Optional
1110

1211
# Third-party imports
12+
from overrides import override
1313
import numpy as np
1414
import numpy.typing as npt
1515

@@ -22,17 +22,6 @@ class ContextAdapter(ABC):
2222
An interface to translate a visual Mask to a context vector.
2323
"""
2424

25-
def __init__(self):
26-
"""
27-
Default self properties
28-
"""
29-
30-
# These attributes are used by the policy service
31-
# To determine whether to pass image/depth data
32-
# to get_context.
33-
self.need_rgb = False
34-
self.need_depth = False
35-
3625
@property
3726
@abstractmethod
3827
def dim(self) -> int:
@@ -42,17 +31,13 @@ def dim(self) -> int:
4231
raise NotImplementedError("dimension not implemented")
4332

4433
@abstractmethod
45-
def get_context(
46-
self, mask: Mask, image: Optional[npt.NDArray], depth: Optional[npt.NDArray]
47-
) -> npt.NDArray:
34+
def get_context(self, mask: Mask) -> npt.NDArray:
4835
"""
4936
Create the context vector from the provided visual info
5037
5138
Parameters
5239
----------
5340
mask: See Mask.msg
54-
image: Full camera image, None if self.need_rgb is False
55-
image: Full depth image, None if self.need_depth is False
5641
5742
Returns
5843
-------
@@ -96,13 +81,14 @@ class NoContext(ContextAdapter, PosthocAdapter):
9681
"""
9782

9883
@property
84+
@override
9985
def dim(self) -> int:
100-
return 1
86+
return 0
10187

102-
def get_context(
103-
self, mask: Mask, image: Optional[npt.NDArray], depth: Optional[npt.NDArray]
104-
) -> npt.NDArray:
105-
return np.array([0.0])
88+
@override
89+
def get_context(self, mask: Mask) -> npt.NDArray:
90+
return np.array([])
10691

92+
@override
10793
def get_posthoc(self, data: npt.NDArray) -> npt.NDArray:
108-
return np.array([0.0])
94+
return np.array([])
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
"""
4+
This module defines the HapticNet context adapter.
5+
6+
"""
7+
8+
# Standard imports
9+
import os
10+
11+
# Third-party imports
12+
from ament_index_python.packages import get_package_share_directory
13+
import numpy as np
14+
import numpy.typing as npt
15+
from overrides import override
16+
import torch
17+
18+
# Local imports
19+
from ada_feeding_action_select.helpers import logger
20+
from .models import HapticNetConfig, HapticNet
21+
from .base_adapters import PosthocAdapter
22+
23+
24+
class HapticNetPosthoc(PosthocAdapter):
25+
"""
26+
An adapter to run force/torque data through HapticNet
27+
and extract features.
28+
"""
29+
30+
def __init__(
31+
self,
32+
checkpoint: str,
33+
n_features: int = 4,
34+
gpu_index: int = 0,
35+
) -> None:
36+
"""
37+
Load Checkpoint and Set Config Parameters
38+
39+
Parameters
40+
----------
41+
checkpoint: PTH file relative to share directory / data
42+
n_features: size of the HapticNet feature vectory (determined by checkpoint)
43+
gpu_index: which gpu to use for CUDA
44+
"""
45+
46+
# Init CUDA
47+
self.use_cuda = torch.cuda.is_available()
48+
if self.use_cuda:
49+
logger.info("Init HapticNet with CUDA")
50+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
51+
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
52+
53+
# Init HapticNet
54+
self.config = HapticNetConfig(n_output=n_features)
55+
self.hapticnet = HapticNet(self.config)
56+
57+
# Load Checkpoint
58+
ckpt_file = os.path.join(
59+
get_package_share_directory("ada_feeding_action_select"), "data", checkpoint
60+
)
61+
ckpt = torch.load(ckpt_file)
62+
self.hapticnet.load_state_dict(ckpt["state_dict"])
63+
self.hapticnet.eval()
64+
if self.use_cuda:
65+
self.hapticnet = self.hapticnet.cuda()
66+
67+
@property
68+
@override
69+
def dim(self) -> int:
70+
# Docstring copied from @override
71+
# No Bias: Haptic bias apparently made it much, much worse
72+
return self.config.n_output
73+
74+
@override
75+
def get_posthoc(self, data: npt.NDArray) -> npt.NDArray:
76+
# Docstring copied from @override
77+
78+
ret = np.array([])
79+
80+
# Run data through hapticnet
81+
input_data = self.hapticnet.preprocess(data)
82+
if input_data is None:
83+
return ret
84+
if self.use_cuda:
85+
input_data = input_data.cuda()
86+
87+
# Get HapticNet Features
88+
features = self.hapticnet(input_data)
89+
90+
# Flatten, no bias
91+
ret = features.cpu().detach().numpy().flatten()
92+
93+
assert ret.size == self.dim
94+
return ret
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""
2+
This package contains code that works with the SPANet featurizer.
3+
"""
4+
5+
# SPANet Model
6+
from .spanet import SPANet, SPANetConfig
7+
from .hapticnet import HapticNet, HapticNetConfig

0 commit comments

Comments
 (0)