We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ba8dc7d commit 54af3caCopy full SHA for 54af3ca
src/diffusers/utils/torch_utils.py
@@ -38,7 +38,7 @@ def maybe_allow_in_graph(cls):
38
def randn_tensor(
39
shape: Union[Tuple, List],
40
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
41
- device: Optional["torch.device"] = None,
+ device: Optional[Union[str, "torch.device"]] = None,
42
dtype: Optional["torch.dtype"] = None,
43
layout: Optional["torch.layout"] = None,
44
):
@@ -47,6 +47,8 @@ def randn_tensor(
47
is always created on the CPU.
48
"""
49
# device on which tensor is created defaults to device
50
+ if isinstance(device, str):
51
+ device = torch.device(device)
52
rand_device = device
53
batch_size = shape[0]
54
0 commit comments