diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt
index a25c4494b6..bd95726bee 100644
--- a/.ci/docker/requirements.txt
+++ b/.ci/docker/requirements.txt
@@ -29,8 +29,8 @@ tensorboard
 jinja2==3.1.3
 pytorch-lightning
 torchx
-torchrl==0.7.2
-tensordict==0.7.2
+torchrl==0.9.2
+tensordict==0.9.1
 # For ax_multiobjective_nas_tutorial.py
 ax-platform>=0.4.0,<0.5.0
 nbformat>=5.9.2
diff --git a/intermediate_source/dqn_with_rnn_tutorial.py b/intermediate_source/dqn_with_rnn_tutorial.py
index bcc484f0a0..462415dcc7 100644
--- a/intermediate_source/dqn_with_rnn_tutorial.py
+++ b/intermediate_source/dqn_with_rnn_tutorial.py
@@ -342,7 +342,9 @@
 # will return a new instance of the LSTM (with shared weights) that will
 # assume that the input data is sequential in nature.
 #
-policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval)
+from torchrl.modules import set_recurrent_mode
+
+policy = Seq(feature, lstm, mlp, qval)
 
 ######################################################################
 # Because we still have a couple of uninitialized parameters we should
@@ -389,7 +391,10 @@
 # For the sake of efficiency, we're only running a few thousands iterations
 # here. In a real setting, the total number of frames should be set to 1M.
 #
-collector = SyncDataCollector(env, stoch_policy, frames_per_batch=50, total_frames=200, device=device)
+
+collector = SyncDataCollector(
+    env, stoch_policy, frames_per_batch=50, total_frames=200, device=device
+)
 rb = TensorDictReplayBuffer(
     storage=LazyMemmapStorage(20_000), batch_size=4, prefetch=10
 )
@@ -422,7 +427,8 @@
     rb.extend(data.unsqueeze(0).to_tensordict().cpu())
     for _ in range(utd):
         s = rb.sample().to(device, non_blocking=True)
-        loss_vals = loss_fn(s)
+        with set_recurrent_mode(True):
+            loss_vals = loss_fn(s)
         loss_vals["loss"].backward()
         optim.step()
         optim.zero_grad()
@@ -464,5 +470,5 @@
 #
 # Further Reading
 # ---------------
-# 
+#
 # - The TorchRL documentation can be found `here <https://pytorch.org/rl>`_.
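
Context for the tutorial hunks above: by torchrl 0.9.2, `lstm.set_recurrent_mode(True)` (a method that returned a recurrent clone of the module to bake into the policy graph) is no longer available, and the replacement is the `torchrl.modules.set_recurrent_mode` context manager, which the updated training loop now wraps around the loss call. Below is a minimal sketch of the new pattern; the sizes and tensordict keys are hypothetical toy values, and it assumes the post-0.7 behavior where the module's `recurrent_mode` property reflects the active context manager.

# Minimal sketch of the recurrent-mode pattern the updated tutorial uses.
# Sizes and keys below are illustrative toy values, not the tutorial's.
from torchrl.modules import LSTMModule, set_recurrent_mode

lstm = LSTMModule(
    input_size=8,      # toy value
    hidden_size=16,    # toy value
    in_key="embed",
    out_key="embed",
)

# Old (torchrl<=0.7): a recurrent copy was baked into the policy graph:
#     policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval)
# New: the plain module is shared between collection and training, and
# whole-sequence (recurrent) execution is toggled around the loss call.
assert not lstm.recurrent_mode   # default: step-by-step mode (collection)
with set_recurrent_mode(True):
    assert lstm.recurrent_mode   # sequence mode for replayed trajectories
    # loss_vals = loss_fn(s)     # as in the updated training loop above

Since the mode is now a context rather than a baked-in module attribute, the same policy instance can serve both the collector (step mode) and the loss (sequence mode) without keeping two module copies in sync.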