11#!/usr/bin/env python3
22
3+ from datetime import datetime
4+ from functools import partial
35import json
46import os
7+ from pathlib import Path , PurePath
58import time
6- from datetime import datetime
7- from functools import partial
89
10+ from absl import app , flags , logging
11+ import click
912import cv2
1013import flax
1114import imageio
1215import jax
1316import jax .numpy as jnp
1417import numpy as np
1518import tensorflow as tf
16- from absl import app , flags , logging
1719
1820# bridge_data_robot imports
19- from widowx_envs .widowx_env_service import WidowXClient , WidowXStatus , WidowXConfigs
21+ from widowx_envs .widowx_env_service import WidowXClient , WidowXConfigs , WidowXStatus
22+ from widowx_wrapper import convert_obs , state_to_eep , wait_for_obs , WidowXGym
2023
24+ from orca .utils .gym_wrappers import HistoryWrapper , RHCWrapper , TemporalEnsembleWrapper
2125from orca .utils .pretrained_utils import PretrainedModel
22- from widowx_wrapper import WidowXGym , convert_obs , wait_for_obs , state_to_eep
23- from orca .utils .gym_wrappers import (
24- HistoryWrapper ,
25- RHCWrapper ,
26- TemporalEnsembleWrapper ,
27- )
2826
2927np .set_printoptions (suppress = True )
3028
3533flags .DEFINE_multi_string (
3634 "checkpoint_weights_path" , None , "Path to checkpoint" , required = True
3735)
38- flags .DEFINE_multi_string (
39- "checkpoint_step " , None , "Checkpoint step" , required = True
40- )
41- flags . DEFINE_multi_string (
42- "checkpoint_config_path" , None , "Path to checkpoint config JSON" , required = True
43- )
44- flags . DEFINE_multi_string (
45- "checkpoint_metadata_path" , None , "Path to checkpoint metadata JSON" , required = True
36+ flags .DEFINE_multi_integer ( "checkpoint_step" , None , "Checkpoint step" , required = True )
37+ flags . DEFINE_bool ( "add_jaxrlm_baseline " , False , "Also compare to jaxrl_m baseline" )
38+
39+
40+ flags . DEFINE_string (
41+ "checkpoint_cache_dir" ,
42+ "/tmp/" ,
43+ "Where to cache checkpoints downloaded from GCS" ,
4644)
47- flags .DEFINE_multi_string (
48- "checkpoint_example_batch_path" ,
49- None ,
50- "Path to checkpoint metadata JSON" ,
51- required = True ,
45+ flags .DEFINE_string (
46+ "modality" , "" , "Either 'g', 'goal', 'l', 'language' (leave empty to prompt when running)"
5247)
48+
5349flags .DEFINE_integer ("im_size" , None , "Image size" , required = True )
5450flags .DEFINE_string ("video_save_path" , None , "Path to save video" )
5551flags .DEFINE_integer ("num_timesteps" , 120 , "num timesteps" )
6763# show image flag
6864flags .DEFINE_bool ("show_image" , False , "Show image" )
6965
66+
7067##############################################################################
7168
69+ STEP_DURATION_MESSAGE = """
70+ Bridge data was collected with non-blocking control and a step duration of 0.2s.
71+ However, we relabel the actions to make it look like the data was collected with blocking control and we evaluate with blocking control.
72+ We also use a step duration of 0.4s to reduce the jerkiness of the policy.
73+ Be sure to change the step duration back to 0.2 if evaluating with non-blocking control.
74+ """
7275STEP_DURATION = 0.4
7376STICKY_GRIPPER_NUM_STEPS = 1
7477WORKSPACE_BOUNDS = [[0.1 , - 0.15 , - 0.01 , - 1.57 , 0 ], [0.45 , 0.25 , 0.25 , 1.57 , 0 ]]
8285##############################################################################
8386
8487
def maybe_download_checkpoint_from_gcs(cloud_path, step, save_path):
    """Mirror a GCS-hosted checkpoint into a local cache directory if needed.

    Args:
        cloud_path: Run directory. Paths that do not start with ``gs://`` are
            treated as already local and returned untouched.
        step: Checkpoint step; used as the name of the checkpoint
            sub-directory inside the run directory.
        save_path: Local cache root. Files are placed in
            ``<save_path>/<run_name>``, where ``run_name`` is the last
            component of ``cloud_path``.

    Returns:
        A ``(path, step)`` tuple, where ``path`` is the local directory to
        load from.

    Raises:
        subprocess.CalledProcessError: If copying the checkpoint directory
            itself fails.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    import subprocess

    if not cloud_path.startswith("gs://"):
        return cloud_path, step  # Actually on the local filesystem

    run_name = Path(cloud_path).name
    save_path = os.path.join(save_path, run_name)

    # Check the cache first so a cache hit does no remote path handling.
    target_checkpoint_path = os.path.join(save_path, f"{step}")
    if os.path.exists(target_checkpoint_path):
        logging.warning(
            "Checkpoint already exists at %s, skipping download", target_checkpoint_path
        )
        return save_path, step

    checkpoint_path = tf.io.gfile.join(cloud_path, f"{step}")
    # The trailing "*" is a gsutil wildcard; gsutil expands it itself, so no
    # shell is needed.
    metadata_patterns = [
        tf.io.gfile.join(cloud_path, "action_proprio*"),
        tf.io.gfile.join(cloud_path, "config.json*"),
        tf.io.gfile.join(cloud_path, "example_batch.msgpack*"),
    ]

    os.makedirs(save_path, exist_ok=True)
    logging.warning("Downloading checkpoint and metadata to %s", save_path)

    destination = f"{save_path}/"

    # Argument lists avoid shell interpolation of the user-supplied paths.
    # The checkpoint copy must succeed: a silently failed copy would leave a
    # cache directory that later runs could mistake for a valid download.
    subprocess.run(["gsutil", "cp", "-r", checkpoint_path, destination], check=True)

    # Metadata copies stay best-effort (as before), but failures are now
    # reported instead of being silently dropped.
    for pattern in metadata_patterns:
        result = subprocess.run(["gsutil", "cp", pattern, destination])
        if result.returncode != 0:
            logging.warning(
                "gsutil failed to copy %s (exit code %d)", pattern, result.returncode
            )

    return save_path, step
115+
116+
85117def supply_rng (f , rng = jax .random .PRNGKey (0 )):
86118 def wrapped (* args , ** kwargs ):
87119 nonlocal rng
@@ -120,11 +152,42 @@ def sample_actions(
120152 return actions [0 ] * std + mean
121153
122154
123- def load_checkpoint (weights_path , config_path , metadata_path , example_batch_path , step ):
124- model = PretrainedModel .load_pretrained (
125- weights_path , config_path , example_batch_path , step
126- )
def load_jaxrlm_checkpoint(
    weights_path="/mount/harddrive/homer/bridgev2_packaged/bridgev2policies/gcbc_256/checkpoint_300000/",
    config_path="/mount/harddrive/homer/bridgev2_packaged/bridgev2policies/gcbc_256/gcbc_256_config.json",
    code_path="/mount/harddrive/homer/bridgev2_packaged/bridgev2policies/bridge_data_v2.zip",
):
    """Load the jaxrl_m baseline and adapt it to this script's policy interface.

    Args:
        weights_path: Checkpoint location passed to jaxrl_m's
            ``load_checkpoint``.
        config_path: Config JSON passed to jaxrl_m's ``load_checkpoint``.
        code_path: Packaged codebase handed to ``UniqueCodebase``, from which
            ``jaxrl_m.pretrained_utils`` is imported.

    Returns:
        A ``(policy_fn, model)`` pair. ``policy_fn(observations, goals)``
        forwards the ``"image_0"`` entries to the baseline under the key
        ``"image"``. ``model`` is a stub exposing only ``create_tasks(goals)``.
    """
    # Imported lazily: only needed when the baseline is actually requested.
    from codesave import UniqueCodebase

    with UniqueCodebase(code_path) as codebase:
        jaxrlm_utils = codebase.import_module("jaxrl_m.pretrained_utils")
        loaded = jaxrlm_utils.load_checkpoint(weights_path, config_path, im_size=256)
    # loaded contains: {
    #   "agent": jaxrlm Agent,
    #   "policy_fn": callable taking in observation and goal inputs and outputs **unnormalized** actions,
    #   "normalization_stats": {"action": {"mean": [7], "std": [7]}}
    #   "obs_horizon": int
    # }

    # Minimal stand-in for the model object: it only knows how to turn goals
    # into a task, which it does by copying them.
    class Dummy:
        def create_tasks(self, goals):
            return goals.copy()

    def new_policy_fn(observations, goals):
        # The baseline reads a single "image" key, so remap from "image_0".
        baseline_observations = {"image": observations["image_0"]}
        baseline_goals = {"image": goals["image_0"]}
        return loaded["policy_fn"](baseline_observations, baseline_goals)

    return new_policy_fn, Dummy()
127184
185+
186+ def load_checkpoint (weights_path , step ):
187+ model = PretrainedModel .load_pretrained (weights_path , step = int (step ))
188+ metadata_path = os .path .join (
189+ weights_path , "action_proprio_metadata_bridge_dataset.json"
190+ )
128191 with open (metadata_path , "r" ) as f :
129192 action_proprio_metadata = json .load (f )
130193 action_mean = jnp .array (action_proprio_metadata ["action" ]["mean" ])
@@ -144,38 +207,31 @@ def load_checkpoint(weights_path, config_path, metadata_path, example_batch_path
144207
145208
146209def main (_ ):
147- assert (
148- len (FLAGS .checkpoint_weights_path )
149- == len (FLAGS .checkpoint_config_path )
150- == len (FLAGS .checkpoint_metadata_path )
151- == len (FLAGS .checkpoint_example_batch_path )
152- == len (FLAGS .checkpoint_step )
153- )
210+ assert len (FLAGS .checkpoint_weights_path ) == len (FLAGS .checkpoint_step )
211+ FLAGS .modality = FLAGS .modality [:1 ]
212+ assert FLAGS .modality in ["g" , "l" , "" ]
213+ if not FLAGS .blocking :
214+ assert STEP_DURATION == 0.2 , STEP_DURATION_MESSAGE
154215
155216 # policies is a dict from run_name to policy function
156217 policies = {}
157- for (
158- checkpoint_weights_path ,
159- checkpoint_config_path ,
160- checkpoint_metadata_path ,
161- checkpoint_example_batch_path ,
162- checkpoint_step ,
163- ) in zip (
218+ for (checkpoint_weights_path , checkpoint_step ,) in zip (
164219 FLAGS .checkpoint_weights_path ,
165- FLAGS .checkpoint_config_path ,
166- FLAGS .checkpoint_metadata_path ,
167- FLAGS .checkpoint_example_batch_path ,
168220 FLAGS .checkpoint_step ,
169221 ):
222+ checkpoint_weights_path , checkpoint_step = maybe_download_checkpoint_from_gcs (
223+ checkpoint_weights_path ,
224+ checkpoint_step ,
225+ FLAGS .checkpoint_cache_dir ,
226+ )
170227 assert tf .io .gfile .exists (checkpoint_weights_path ), checkpoint_weights_path
171- run_name = checkpoint_config_path . split ("/" )[- 2 ]
228+ run_name = checkpoint_weights_path . rpartition ("/" )[2 ]
172229 policies [f"{ run_name } -{ checkpoint_step } " ] = load_checkpoint (
173230 checkpoint_weights_path ,
174- checkpoint_config_path ,
175- checkpoint_metadata_path ,
176- checkpoint_example_batch_path ,
177- checkpoint_step
231+ checkpoint_step ,
178232 )
233+ if FLAGS .add_jaxrlm_baseline :
234+ policies ["jaxrl_gcbc" ] = load_jaxrlm_checkpoint ()
179235
180236 if FLAGS .initial_eep is not None :
181237 assert isinstance (FLAGS .initial_eep , list )
@@ -197,9 +253,8 @@ def main(_):
197253 # env = TemporalEnsembleWrapper(env, FLAGS.pred_horizon)
198254 env = RHCWrapper (env , FLAGS .pred_horizon , FLAGS .exec_horizon )
199255
200- task = {
201- "image_0" : jnp .zeros ((FLAGS .im_size , FLAGS .im_size , 3 ), dtype = np .uint8 ),
202- }
256+ goal_image = jnp .zeros ((FLAGS .im_size , FLAGS .im_size , 3 ), dtype = np .uint8 )
257+ goal_instruction = ""
203258
204259 # goal sampling loop
205260 while True :
@@ -211,21 +266,20 @@ def main(_):
211266 print ("policies:" )
212267 for i , name in enumerate (policies .keys ()):
213268 print (f"{ i } ) { name } " )
214- policy_idx = int ( input ( "select policy: " ) )
269+ policy_idx = click . prompt ( "Select policy" , type = int )
215270
216271 policy_name = list (policies .keys ())[policy_idx ]
217272 policy_fn , model = policies [policy_name ]
218273 model : PretrainedModel # type hinting
219274
220- modality = input ("Language or goal image? [l/g]" )
275+ modality = FLAGS .modality
276+ if not modality :
277+ modality = click .prompt (
278+ "Language or goal image?" , type = click .Choice (["l" , "g" ])
279+ )
280+
221281 if modality == "g" :
222- # ask for new goal
223- if task ["image_0" ] is None :
224- print ("Taking a new goal..." )
225- ch = "y"
226- else :
227- ch = input ("Take a new goal? [y/n]" )
228- if ch == "y" :
282+ if click .confirm ("Take a new goal?" , default = True ):
229283 assert isinstance (FLAGS .goal_eep , list )
230284 _eep = [float (e ) for e in FLAGS .goal_eep ]
231285 goal_eep = state_to_eep (_eep , 0 )
@@ -237,17 +291,21 @@ def main(_):
237291
238292 input ("Press [Enter] when ready for taking the goal image. " )
239293 obs = wait_for_obs (widowx_client )
240- goals = jax .tree_map (lambda x : x [None ], convert_obs (obs , FLAGS .im_size ))
241- task = model .create_tasks (goals = goals )
242- else :
243- # ask for new instruction
244- if "language_instruction" not in task or ["language_instruction" ] is None :
245- ch = "y"
246- else :
247- ch = input ("New instruction? [y/n]" )
248- if ch == "y" :
294+ goal = jax .tree_map (lambda x : x [None ], convert_obs (obs , FLAGS .im_size ))
295+
296+ task = model .create_tasks (goals = goal )
297+ goal_image = goal ["image_0" ][0 ]
298+ goal_instruction = ""
299+ elif modality == "l" :
300+ print ("Current instruction: " , goal_instruction )
301+ if click .confirm ("Take a new instruction?" , default = True ):
249302 text = input ("Instruction?" )
250- task = model .create_tasks (text = [text ])
303+
304+ task = model .create_tasks (text = [text ])
305+ goal_instruction = text
306+ goal_image = jnp .zeros_like (goal_image )
307+ else :
308+ raise NotImplementedError ()
251309
252310 input ("Press [Enter] to start." )
253311
@@ -267,7 +325,7 @@ def main(_):
267325
268326 # save images
269327 images .append (obs ["image_0" ][- 1 ])
270- goals .append (task [ "image_0" ][ 0 ] )
328+ goals .append (goal_image )
271329
272330 if FLAGS .show_image :
273331 bgr_img = cv2 .cvtColor (obs ["full_image" ][- 1 ], cv2 .COLOR_RGB2BGR )
0 commit comments