Skip to content

Commit 156cd9c

Browse files
authored
Merge pull request octo-models#77 from rail-berkeley/dibya-fix-bridge-eval
Updates to Bridge Evaluation
2 parents 7bac65d + 81b29ac commit 156cd9c

File tree

2 files changed

+144
-82
lines changed

2 files changed

+144
-82
lines changed

experiments/homer/bridge/eval.py

Lines changed: 130 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,28 @@
11
#!/usr/bin/env python3
22

3+
from datetime import datetime
4+
from functools import partial
35
import json
46
import os
7+
from pathlib import Path, PurePath
58
import time
6-
from datetime import datetime
7-
from functools import partial
89

10+
from absl import app, flags, logging
11+
import click
912
import cv2
1013
import flax
1114
import imageio
1215
import jax
1316
import jax.numpy as jnp
1417
import numpy as np
1518
import tensorflow as tf
16-
from absl import app, flags, logging
1719

1820
# bridge_data_robot imports
19-
from widowx_envs.widowx_env_service import WidowXClient, WidowXStatus, WidowXConfigs
21+
from widowx_envs.widowx_env_service import WidowXClient, WidowXConfigs, WidowXStatus
22+
from widowx_wrapper import convert_obs, state_to_eep, wait_for_obs, WidowXGym
2023

24+
from orca.utils.gym_wrappers import HistoryWrapper, RHCWrapper, TemporalEnsembleWrapper
2125
from orca.utils.pretrained_utils import PretrainedModel
22-
from widowx_wrapper import WidowXGym, convert_obs, wait_for_obs, state_to_eep
23-
from orca.utils.gym_wrappers import (
24-
HistoryWrapper,
25-
RHCWrapper,
26-
TemporalEnsembleWrapper,
27-
)
2826

2927
np.set_printoptions(suppress=True)
3028

@@ -35,21 +33,19 @@
3533
flags.DEFINE_multi_string(
3634
"checkpoint_weights_path", None, "Path to checkpoint", required=True
3735
)
38-
flags.DEFINE_multi_string(
39-
"checkpoint_step", None, "Checkpoint step", required=True
40-
)
41-
flags.DEFINE_multi_string(
42-
"checkpoint_config_path", None, "Path to checkpoint config JSON", required=True
43-
)
44-
flags.DEFINE_multi_string(
45-
"checkpoint_metadata_path", None, "Path to checkpoint metadata JSON", required=True
36+
flags.DEFINE_multi_integer("checkpoint_step", None, "Checkpoint step", required=True)
37+
flags.DEFINE_bool("add_jaxrlm_baseline", False, "Also compare to jaxrl_m baseline")
38+
39+
40+
flags.DEFINE_string(
41+
"checkpoint_cache_dir",
42+
"/tmp/",
43+
"Where to cache checkpoints downloaded from GCS",
4644
)
47-
flags.DEFINE_multi_string(
48-
"checkpoint_example_batch_path",
49-
None,
50-
"Path to checkpoint metadata JSON",
51-
required=True,
45+
flags.DEFINE_string(
46+
"modality", "", "Either 'g', 'goal', 'l', 'language' (leave empty to prompt when running)"
5247
)
48+
5349
flags.DEFINE_integer("im_size", None, "Image size", required=True)
5450
flags.DEFINE_string("video_save_path", None, "Path to save video")
5551
flags.DEFINE_integer("num_timesteps", 120, "num timesteps")
@@ -67,8 +63,15 @@
6763
# show image flag
6864
flags.DEFINE_bool("show_image", False, "Show image")
6965

66+
7067
##############################################################################
7168

69+
STEP_DURATION_MESSAGE = """
70+
Bridge data was collected with non-blocking control and a step duration of 0.2s.
71+
However, we relabel the actions to make it look like the data was collected with blocking control and we evaluate with blocking control.
72+
We also use a step duration of 0.4s to reduce the jerkiness of the policy.
73+
Be sure to change the step duration back to 0.2 if evaluating with non-blocking control.
74+
"""
7275
STEP_DURATION = 0.4
7376
STICKY_GRIPPER_NUM_STEPS = 1
7477
WORKSPACE_BOUNDS = [[0.1, -0.15, -0.01, -1.57, 0], [0.45, 0.25, 0.25, 1.57, 0]]
@@ -82,6 +85,35 @@
8285
##############################################################################
8386

8487

88+
def maybe_download_checkpoint_from_gcs(cloud_path, step, save_path):
89+
if not cloud_path.startswith("gs://"):
90+
return cloud_path, step # Actually on the local filesystem
91+
92+
checkpoint_path = tf.io.gfile.join(cloud_path, f"{step}")
93+
norm_path = tf.io.gfile.join(cloud_path, "action_proprio*")
94+
config_path = tf.io.gfile.join(cloud_path, "config.json*")
95+
example_batch_path = tf.io.gfile.join(cloud_path, "example_batch.msgpack*")
96+
97+
run_name = Path(cloud_path).name
98+
save_path = os.path.join(save_path, run_name)
99+
100+
target_checkpoint_path = os.path.join(save_path, f"{step}")
101+
if os.path.exists(target_checkpoint_path):
102+
logging.warning(
103+
"Checkpoint already exists at %s, skipping download", target_checkpoint_path
104+
)
105+
return save_path, step
106+
os.makedirs(save_path, exist_ok=True)
107+
logging.warning("Downloading checkpoint and metadata to %s", save_path)
108+
109+
os.system(f"gsutil cp -r {checkpoint_path} {save_path}/")
110+
os.system(f"gsutil cp {norm_path} {save_path}/")
111+
os.system(f"gsutil cp {config_path} {save_path}/")
112+
os.system(f"gsutil cp {example_batch_path} {save_path}/")
113+
114+
return save_path, step
115+
116+
85117
def supply_rng(f, rng=jax.random.PRNGKey(0)):
86118
def wrapped(*args, **kwargs):
87119
nonlocal rng
@@ -120,11 +152,42 @@ def sample_actions(
120152
return actions[0] * std + mean
121153

122154

123-
def load_checkpoint(weights_path, config_path, metadata_path, example_batch_path, step):
124-
model = PretrainedModel.load_pretrained(
125-
weights_path, config_path, example_batch_path, step
126-
)
155+
def load_jaxrlm_checkpoint(
156+
weights_path="/mount/harddrive/homer/bridgev2_packaged/bridgev2policies/gcbc_256/checkpoint_300000/",
157+
config_path="/mount/harddrive/homer/bridgev2_packaged/bridgev2policies/gcbc_256/gcbc_256_config.json",
158+
code_path="/mount/harddrive/homer/bridgev2_packaged/bridgev2policies/bridge_data_v2.zip",
159+
):
160+
from codesave import UniqueCodebase
161+
162+
with UniqueCodebase(code_path) as cs:
163+
pretrained_utils = cs.import_module("jaxrl_m.pretrained_utils")
164+
loaded = pretrained_utils.load_checkpoint(
165+
weights_path, config_path, im_size=256
166+
)
167+
# loaded contains: {
168+
# "agent": jaxrlm Agent,
169+
# "policy_fn": callable taking in observation and goal inputs and outputs **unnormalized** actions,
170+
# "normalization_stats": {"action": {"mean": [7], "std": [7]}}
171+
# "obs_horizon": int
172+
# }
173+
174+
class Dummy:
175+
def create_tasks(self, goals):
176+
return goals.copy()
177+
178+
def new_policy_fn(observations, goals):
179+
observations = {"image": observations["image_0"]}
180+
goals = {"image": goals["image_0"]}
181+
return loaded["policy_fn"](observations, goals)
182+
183+
return new_policy_fn, Dummy()
127184

185+
186+
def load_checkpoint(weights_path, step):
187+
model = PretrainedModel.load_pretrained(weights_path, step=int(step))
188+
metadata_path = os.path.join(
189+
weights_path, "action_proprio_metadata_bridge_dataset.json"
190+
)
128191
with open(metadata_path, "r") as f:
129192
action_proprio_metadata = json.load(f)
130193
action_mean = jnp.array(action_proprio_metadata["action"]["mean"])
@@ -144,38 +207,31 @@ def load_checkpoint(weights_path, config_path, metadata_path, example_batch_path
144207

145208

146209
def main(_):
147-
assert (
148-
len(FLAGS.checkpoint_weights_path)
149-
== len(FLAGS.checkpoint_config_path)
150-
== len(FLAGS.checkpoint_metadata_path)
151-
== len(FLAGS.checkpoint_example_batch_path)
152-
== len(FLAGS.checkpoint_step)
153-
)
210+
assert len(FLAGS.checkpoint_weights_path) == len(FLAGS.checkpoint_step)
211+
FLAGS.modality = FLAGS.modality[:1]
212+
assert FLAGS.modality in ["g", "l", ""]
213+
if not FLAGS.blocking:
214+
assert STEP_DURATION == 0.2, STEP_DURATION_MESSAGE
154215

155216
# policies is a dict from run_name to policy function
156217
policies = {}
157-
for (
158-
checkpoint_weights_path,
159-
checkpoint_config_path,
160-
checkpoint_metadata_path,
161-
checkpoint_example_batch_path,
162-
checkpoint_step,
163-
) in zip(
218+
for (checkpoint_weights_path, checkpoint_step,) in zip(
164219
FLAGS.checkpoint_weights_path,
165-
FLAGS.checkpoint_config_path,
166-
FLAGS.checkpoint_metadata_path,
167-
FLAGS.checkpoint_example_batch_path,
168220
FLAGS.checkpoint_step,
169221
):
222+
checkpoint_weights_path, checkpoint_step = maybe_download_checkpoint_from_gcs(
223+
checkpoint_weights_path,
224+
checkpoint_step,
225+
FLAGS.checkpoint_cache_dir,
226+
)
170227
assert tf.io.gfile.exists(checkpoint_weights_path), checkpoint_weights_path
171-
run_name = checkpoint_config_path.split("/")[-2]
228+
run_name = checkpoint_weights_path.rpartition("/")[2]
172229
policies[f"{run_name}-{checkpoint_step}"] = load_checkpoint(
173230
checkpoint_weights_path,
174-
checkpoint_config_path,
175-
checkpoint_metadata_path,
176-
checkpoint_example_batch_path,
177-
checkpoint_step
231+
checkpoint_step,
178232
)
233+
if FLAGS.add_jaxrlm_baseline:
234+
policies["jaxrl_gcbc"] = load_jaxrlm_checkpoint()
179235

180236
if FLAGS.initial_eep is not None:
181237
assert isinstance(FLAGS.initial_eep, list)
@@ -197,9 +253,8 @@ def main(_):
197253
# env = TemporalEnsembleWrapper(env, FLAGS.pred_horizon)
198254
env = RHCWrapper(env, FLAGS.pred_horizon, FLAGS.exec_horizon)
199255

200-
task = {
201-
"image_0": jnp.zeros((FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8),
202-
}
256+
goal_image = jnp.zeros((FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8)
257+
goal_instruction = ""
203258

204259
# goal sampling loop
205260
while True:
@@ -211,21 +266,20 @@ def main(_):
211266
print("policies:")
212267
for i, name in enumerate(policies.keys()):
213268
print(f"{i}) {name}")
214-
policy_idx = int(input("select policy: "))
269+
policy_idx = click.prompt("Select policy", type=int)
215270

216271
policy_name = list(policies.keys())[policy_idx]
217272
policy_fn, model = policies[policy_name]
218273
model: PretrainedModel # type hinting
219274

220-
modality = input("Language or goal image? [l/g]")
275+
modality = FLAGS.modality
276+
if not modality:
277+
modality = click.prompt(
278+
"Language or goal image?", type=click.Choice(["l", "g"])
279+
)
280+
221281
if modality == "g":
222-
# ask for new goal
223-
if task["image_0"] is None:
224-
print("Taking a new goal...")
225-
ch = "y"
226-
else:
227-
ch = input("Take a new goal? [y/n]")
228-
if ch == "y":
282+
if click.confirm("Take a new goal?", default=True):
229283
assert isinstance(FLAGS.goal_eep, list)
230284
_eep = [float(e) for e in FLAGS.goal_eep]
231285
goal_eep = state_to_eep(_eep, 0)
@@ -237,17 +291,21 @@ def main(_):
237291

238292
input("Press [Enter] when ready for taking the goal image. ")
239293
obs = wait_for_obs(widowx_client)
240-
goals = jax.tree_map(lambda x: x[None], convert_obs(obs, FLAGS.im_size))
241-
task = model.create_tasks(goals=goals)
242-
else:
243-
# ask for new instruction
244-
if "language_instruction" not in task or ["language_instruction"] is None:
245-
ch = "y"
246-
else:
247-
ch = input("New instruction? [y/n]")
248-
if ch == "y":
294+
goal = jax.tree_map(lambda x: x[None], convert_obs(obs, FLAGS.im_size))
295+
296+
task = model.create_tasks(goals=goal)
297+
goal_image = goal["image_0"][0]
298+
goal_instruction = ""
299+
elif modality == "l":
300+
print("Current instruction: ", goal_instruction)
301+
if click.confirm("Take a new instruction?", default=True):
249302
text = input("Instruction?")
250-
task = model.create_tasks(text=[text])
303+
304+
task = model.create_tasks(text=[text])
305+
goal_instruction = text
306+
goal_image = jnp.zeros_like(goal_image)
307+
else:
308+
raise NotImplementedError()
251309

252310
input("Press [Enter] to start.")
253311

@@ -267,7 +325,7 @@ def main(_):
267325

268326
# save images
269327
images.append(obs["image_0"][-1])
270-
goals.append(task["image_0"][0])
328+
goals.append(goal_image)
271329

272330
if FLAGS.show_image:
273331
bgr_img = cv2.cvtColor(obs["full_image"][-1], cv2.COLOR_RGB2BGR)

experiments/homer/scripts/eval.sh

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1-
NAMES=(
2-
"gc_bridge_match_old_20231026_193653"
1+
PATHS=(
2+
"gs://rail-dibya-central2/experiment_output/oxe_sweep/bridge_vits_20231111_165439"
3+
"gs://rail-dibya-central2/experiment_output/oxe_sweep/bridge_baseline_20231112_025236"
4+
"gs://rail-dibya-central2/experiment_output/oxe_sweep/bridge_jaxrlm_baseline_20231112_073307"
35
)
46

57
STEPS=(
6-
"345000"
8+
"120000"
9+
"500000"
10+
"300000"
711
)
812

9-
VIDEO_DIR="11-3"
13+
CONDITIONING_MODE="goal"
14+
VIDEO_DIR="11-12"
1015

1116
TIMESTEPS="50"
1217

@@ -21,17 +26,16 @@ EXEC_HORIZON="1"
2126
CMD="python experiments/homer/bridge/eval.py \
2227
--num_timesteps $TIMESTEPS \
2328
--video_save_path /mount/harddrive/homer/videos/$VIDEO_DIR \
24-
$(for i in "${!NAMES[@]}"; do echo "--checkpoint_weights_path /mount/harddrive/homer/checkpoints/${NAMES[$i]} "; done) \
25-
$(for i in "${!NAMES[@]}"; do echo "--checkpoint_step /mount/harddrive/homer/checkpoints/${STEPS[$i]} "; done) \
26-
$(for i in "${!NAMES[@]}"; do echo "--checkpoint_config_path /mount/harddrive/homer/checkpoints/${NAMES[$i]}/config.json "; done) \
27-
$(for i in "${!NAMES[@]}"; do echo "--checkpoint_metadata_path /mount/harddrive/homer/checkpoints/${NAMES[$i]}/action_proprio_metadata_bridge_dataset.json "; done) \
28-
$(for i in "${!NAMES[@]}"; do echo "--checkpoint_example_batch_path /mount/harddrive/homer/checkpoints/${NAMES[$i]}/example_batch.msgpack "; done) \
29+
$(for i in "${!NAMES[@]}"; do echo "--checkpoint_weights_path ${NAMES[$i]} "; done) \
30+
$(for i in "${!NAMES[@]}"; do echo "--checkpoint_step ${STEPS[$i]} "; done) \
2931
--im_size 256 \
3032
--temperature $TEMPERATURE \
3133
--horizon $HORIZON \
3234
--pred_horizon $PRED_HORIZON \
3335
--exec_horizon $EXEC_HORIZON \
34-
--blocking
36+
--blocking \
37+
--modality $CONDITIONING_MODE \
38+
--checkpoint_cache_dir /mount/harddrive/homer/checkpoints/
3539
"
3640

3741
echo $CMD

0 commit comments

Comments
 (0)