Skip to content

Commit eb403d9

Browse files
committed
[Tutorial] Beam search with GPT models
ghstack-source-id: 396baef4490d010cf55171280d6382257a25577f
Pull Request resolved: #2623
1 parent 2511c04 commit eb403d9

File tree

20 files changed

+945
-72
lines changed

20 files changed

+945
-72
lines changed

Diff for: docs/requirements.txt

+4
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,7 @@ vmas
2828
onnxscript
2929
onnxruntime
3030
onnx
31+
plotly
32+
igraph
33+
transformers
34+
datasets

Diff for: docs/source/_static/img/rollout-llm.png

318 KB
Loading

Diff for: docs/source/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ Intermediate
105105
tutorials/dqn_with_rnn
106106
tutorials/rb_tutorial
107107
tutorials/export
108+
tutorials/beam_search_with_gpt
108109

109110
Advanced
110111
--------

Diff for: docs/source/reference/envs.rst

+2
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,8 @@ TorchRL offers a series of custom built-in environments.
347347

348348
PendulumEnv
349349
TicTacToeEnv
350+
LLMHashingEnv
351+
350352

351353
Multi-agent environments
352354
------------------------

Diff for: docs/source/reference/trainers.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ Hooks can be split into 3 categories: **data processing** (``"batch_process"`` a
7979

8080
- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger
8181
some information retrieved from that data. Examples include the ``LogValidationReward`` hook, the reward
82-
logger (``LogScaler``) and such. Hooks should return a dictionary (or a None value) containing the
82+
logger (``LogScalar``) and such. Hooks should return a dictionary (or a None value) containing the
8383
data to log. The key ``"log_pbar"`` is reserved to boolean values indicating if the logged value
8484
should be displayed on the progression bar printed on the training log.
8585

@@ -174,7 +174,7 @@ Trainer and hooks
174174
BatchSubSampler
175175
ClearCudaCache
176176
CountFramesLog
177-
LogScaler
177+
LogScalar
178178
OptimizerHook
179179
LogValidationReward
180180
ReplayBufferTrainer

Diff for: test/mocking_classes.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1931,14 +1931,18 @@ def __init__(self):
19311931
tensor=Unbounded(3),
19321932
non_tensor=NonTensor(shape=()),
19331933
)
1934+
self._saved_obs_spec = self.observation_spec.clone()
19341935
self.state_spec = Composite(
19351936
non_tensor=NonTensor(shape=()),
19361937
)
1938+
self._saved_state_spec = self.state_spec.clone()
19371939
self.reward_spec = Unbounded(1)
1940+
self._saved_full_reward_spec = self.full_reward_spec.clone()
19381941
self.action_spec = Unbounded(1)
1942+
self._saved_full_action_spec = self.full_action_spec.clone()
19391943

19401944
def _reset(self, tensordict):
1941-
data = self.observation_spec.zero()
1945+
data = self._saved_obs_spec.zero()
19421946
data.set_non_tensor("non_tensor", 0)
19431947
data.update(self.full_done_spec.zero())
19441948
return data
@@ -1947,10 +1951,10 @@ def _step(
19471951
self,
19481952
tensordict: TensorDictBase,
19491953
) -> TensorDictBase:
1950-
data = self.observation_spec.zero()
1954+
data = self._saved_obs_spec.zero()
19511955
data.set_non_tensor("non_tensor", tensordict["non_tensor"] + 1)
19521956
data.update(self.full_done_spec.zero())
1953-
data.update(self.full_reward_spec.zero())
1957+
data.update(self._saved_full_reward_spec.zero())
19541958
return data
19551959

19561960
def _set_seed(self, seed: Optional[int]):

Diff for: test/test_env.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -3528,8 +3528,13 @@ def test_single_env_spec():
35283528
assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape))
35293529

35303530

3531-
def test_auto_spec():
3532-
env = CountingEnv()
3531+
@pytest.mark.parametrize("env_type", [CountingEnv, EnvWithMetadata])
3532+
def test_auto_spec(env_type):
3533+
if env_type is EnvWithMetadata:
3534+
obs_vals = ["tensor", "non_tensor"]
3535+
else:
3536+
obs_vals = "observation"
3537+
env = env_type()
35333538
td = env.reset()
35343539

35353540
policy = lambda td, action_spec=env.full_action_spec.clone(): td.update(
@@ -3552,7 +3557,7 @@ def test_auto_spec():
35523557
shape=env.full_state_spec.shape, device=env.full_state_spec.device
35533558
)
35543559
env._action_keys = ["action"]
3555-
env.auto_specs_(policy, tensordict=td.copy())
3560+
env.auto_specs_(policy, tensordict=td.copy(), observation_key=obs_vals)
35563561
env.check_env_specs(tensordict=td.copy())
35573562

35583563

Diff for: torchrl/_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,7 @@ def _can_be_pickled(obj):
829829
def _make_ordinal_device(device: torch.device):
830830
if device is None:
831831
return device
832+
device = torch.device(device)
832833
if device.type == "cuda" and device.index is None:
833834
return torch.device("cuda", index=torch.cuda.current_device())
834835
if device.type == "mps" and device.index is None:

Diff for: torchrl/data/map/hash.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def forward(self, features: torch.Tensor) -> torch.Tensor:
7575
class SipHash(Module):
7676
"""A Module to Compute SipHash values for given tensors.
7777
78-
A hash function module based on SipHash implementation in python.
78+
A hash function module based on a SipHash implementation in Python. Input tensors should have shape ``[batch_size, num_features]``
79+
and the output shape will be ``[batch_size]``.
7980
8081
Args:
8182
as_tensor (bool, optional): if ``True``, the bytes will be turned into integers

Diff for: torchrl/data/map/tdstorage.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def from_tensordict_pair(
177177
collate_fn: Callable[[Any], Any] | None = None,
178178
write_fn: Callable[[Any, Any], Any] | None = None,
179179
consolidated: bool | None = None,
180-
):
180+
) -> TensorDictMap:
181181
"""Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb.
182182
183183
Args:
@@ -308,7 +308,23 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
308308
if not self._has_lazy_out_keys():
309309
# TODO: make this work with pytrees and avoid calling select if keys match
310310
value = value.select(*self.out_keys, strict=False)
311+
item, value = self._maybe_add_batch(item, value)
312+
index = self._to_index(item, extend=True)
313+
if index.unique().numel() < index.numel():
314+
# If multiple values point to the same place in the storage, we cannot process them by batch
315+
# There could be a better way to deal with this, using unique ids.
316+
vals = []
317+
for it, val in zip(item.split(1), value.split(1)):
318+
self[it] = val
319+
vals.append(val)
320+
# __setitem__ may affect the content of the input data
321+
value.update(TensorDictBase.lazy_stack(vals))
322+
return
311323
if self.write_fn is not None:
324+
# We use this block in the following context: the value written in the storage is already present,
325+
# but it needs to be updated.
326+
# We first check if the value is already there using `contains`. If so, we pass the new value and the
327+
# previous one to write_fn. The values that are not present are passed alone.
312328
if len(self):
313329
modifiable = self.contains(item)
314330
if modifiable.any():
@@ -322,8 +338,6 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
322338
value = self.write_fn(value)
323339
else:
324340
value = self.write_fn(value)
325-
item, value = self._maybe_add_batch(item, value)
326-
index = self._to_index(item, extend=True)
327341
self.storage.set(index, value)
328342

329343
def __len__(self):

0 commit comments

Comments
 (0)