Commit 16b70be

[BugFix] Fix various test failures (#2994)
1 parent 78b6026 commit 16b70be

File tree: 7 files changed (+12, −15 lines)

sota-implementations/grpo/grpo_utils.py

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
-import os
 from typing import Any, Literal
 
 import torch

test/test_cost.py

Lines changed: 1 addition & 0 deletions
@@ -14328,6 +14328,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
 
 
 class TestValues:
+    @pytest.mark.skipif(not _has_gym, reason="requires gym")
     def test_gae_multi_done(self):
 
         # constants
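
The added decorator is the standard pytest pattern for gating a test on an optional dependency: when gym is not installed, the test is reported as skipped instead of failing at call time. A minimal, self-contained sketch of the pattern; the `_has_gym` flag and the test body below are illustrative, not taken from TorchRL's test suite.

import importlib.util

import pytest

# Detect the optional dependency once, at import time.
_has_gym = importlib.util.find_spec("gym") is not None


class TestValues:
    @pytest.mark.skipif(not _has_gym, reason="requires gym")
    def test_gae_multi_done(self):
        # Runs only when gym is available; otherwise pytest marks it as skipped.
        import gym

        env = gym.make("CartPole-v1")
        assert env.observation_space is not None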

torchrl/data/llm/chat.py

Lines changed: 1 addition & 1 deletion
@@ -430,7 +430,7 @@ def append(
             history = history.copy().clear_device_()
         else:
             history = history.to(self.device)
-        return torch.stack(list(self.unbind(dim)) + [history], dim=dim)
+        return lazy_stack(list(self.unbind(dim)) + [history], dim=dim)
 
     def extend(
         self, history: History, *, inplace: bool = True, dim: int = 0
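
Swapping torch.stack for tensordict's lazy_stack matters when the stacked items are heterogeneous, as chat turns typically are (different numbers of tokens per message): a lazy stack keeps each element as-is rather than requiring a dense, same-shape stack. A rough sketch of the difference using plain TensorDicts instead of History objects (illustrative only):

import torch
from tensordict import TensorDict, lazy_stack

# Two "turns" whose token tensors have different lengths: stacking the
# underlying tensors densely would fail, a lazy stack does not.
a = TensorDict({"tokens": torch.arange(3)}, batch_size=[])
b = TensorDict({"tokens": torch.arange(5)}, batch_size=[])

stacked = lazy_stack([a, b], dim=0)  # LazyStackedTensorDict: no copy, no shape check
print(stacked[0]["tokens"].shape)  # torch.Size([3])
print(stacked[1]["tokens"].shape)  # torch.Size([5])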

torchrl/envs/common.py

Lines changed: 3 additions & 2 deletions
@@ -16,7 +16,6 @@
 import torch.nn as nn
 from tensordict import (
     is_tensor_collection,
-    lazy_stack,
     LazyStackedTensorDict,
     TensorDictBase,
     unravel_key,
@@ -3326,7 +3325,9 @@ def rollout(
                )
                raise
            else:
                out_td = lazy_stack(tensordicts, len(batch_size), out=out)
+                out_td = LazyStackedTensorDict.maybe_dense_stack(
+                    tensordicts, len(batch_size), out=out
+                )
        if set_truncated:
            found_truncated = False
            for key in self.done_keys:
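
As I read tensordict's API, LazyStackedTensorDict.maybe_dense_stack stacks the per-step tensordicts densely when their structures match and only falls back to a lazy stack when they do not, so the usual rollout path keeps its dense result while heterogeneous steps (e.g. LLM envs) still work. A small sketch of that expected behaviour, separate from the rollout code itself:

import torch
from tensordict import LazyStackedTensorDict, TensorDict

uniform = [
    TensorDict({"obs": torch.zeros(4)}, batch_size=[]),
    TensorDict({"obs": torch.ones(4)}, batch_size=[]),
]
# Matching keys and shapes: expected to come back as a dense stack.
dense = LazyStackedTensorDict.maybe_dense_stack(uniform, 0)

ragged = [
    TensorDict({"obs": torch.zeros(4)}, batch_size=[]),
    TensorDict({"obs": torch.zeros(5)}, batch_size=[]),
]
# Mismatched shapes: expected to fall back to a LazyStackedTensorDict.
lazy = LazyStackedTensorDict.maybe_dense_stack(ragged, 0)

print(type(dense).__name__, type(lazy).__name__)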

torchrl/envs/llm/transforms/browser.py

Lines changed: 4 additions & 8 deletions
@@ -8,17 +8,11 @@
 from __future__ import annotations
 
 import asyncio
-import json
-import re
-import signal
-from contextlib import asynccontextmanager
-from typing import Any, Optional
+from typing import Any
 from urllib.parse import urlparse
 
-from playwright.async_api import async_playwright
 from tensordict import TensorDictBase
 
-from torchrl.data.llm import History
 from torchrl.envs.llm.transforms.tools import MCPToolTransform
 
 # Schema for the browser tool
@@ -147,6 +141,8 @@ def __init__(
 
     async def _init_browser(self):
         """Initialize the browser if not already initialized."""
+        from playwright.async_api import async_playwright
+
         if self.browser is None:
             playwright = await async_playwright().start()
             self.browser = await playwright.chromium.launch(headless=self.headless)
@@ -213,7 +209,7 @@ async def _scroll(self, amount: int) -> dict[str, Any]:
             return {"success": False, "error": str(e)}
 
     async def _extract(
-        self, selector: str, extract_type: str, attribute: Optional[str] = None
+        self, selector: str, extract_type: str, attribute: str | None = None
     ) -> dict[str, Any]:
         """Extract content from the page."""
         try:
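
Moving the playwright import from module level into _init_browser is the usual way to make a heavy or optional dependency lazy: the module stays importable without the package installed, and the import cost (or ImportError) is only paid when the browser feature is actually used. A generic sketch of the pattern; the class and method names below are illustrative, not the transform's real API.

class BrowserSession:
    """Wraps a headless browser; playwright is only needed once a session starts."""

    def __init__(self, headless: bool = True):
        self.headless = headless
        self.browser = None

    async def start(self):
        # Deferred import: importing this module never requires playwright,
        # only calling start() does.
        try:
            from playwright.async_api import async_playwright
        except ImportError as err:
            raise ImportError("playwright is required for the browser tool") from err
        playwright = await async_playwright().start()
        self.browser = await playwright.chromium.launch(headless=self.headless)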

torchrl/envs/llm/transforms/tools.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,6 @@
 import os
 import queue
 import re
-import signal
 import subprocess
 import tempfile
 import threading
@@ -310,6 +309,8 @@ def execute(self, prompt: str) -> dict[str, any]:
 
     def cleanup(self):
         """Clean up the persistent process."""
+        import signal
+
         if self.process:
             try:
                 self.process.send_signal(signal.SIGTERM)
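
The same deferred-import idea is applied to signal here; the cleanup itself is the usual terminate-then-kill sequence for a persistent worker process. A minimal sketch with an illustrative timeout (not the transform's actual cleanup code):

from __future__ import annotations

import subprocess


def cleanup(process: subprocess.Popen | None) -> None:
    """Terminate a persistent worker process, escalating to SIGKILL if needed."""
    import signal  # deferred, mirroring the change above

    if process is None:
        return
    try:
        process.send_signal(signal.SIGTERM)
        process.wait(timeout=5)  # give the process a chance to exit cleanly
    except subprocess.TimeoutExpired:
        process.kill()  # SIGKILL as a last resort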

tutorials/sphinx-tutorials/llm_browser.py

Lines changed: 1 addition & 2 deletions
@@ -63,7 +63,6 @@
 from __future__ import annotations
 
 import warnings
-from pprint import pprint
 
 import torch
 
@@ -210,7 +209,7 @@ def execute_tool_action(
     print(action)
     print("\nEnvironment Response:")
     print("--------------------")
-    pprint(s_["history"].apply_chat_template(tokenizer=env.tokenizer))
+    torchrl_logger.info(s_["history"].apply_chat_template(tokenizer=env.tokenizer))
 
     return s, s_
