
Commit dcf836c

Use float32 on mps or npu in transformer_hidream_image's rope (#11316)
Parent commit: 1cb73cb

File tree: 1 file changed (+6 −1 lines)


src/diffusers/models/transformers/transformer_hidream_image.py

@@ -95,7 +95,12 @@ def forward(self, latent):
 def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
     assert dim % 2 == 0, "The dimension must be even."
 
-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+    is_mps = pos.device.type == "mps"
+    is_npu = pos.device.type == "npu"
+
+    dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+
+    scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
     omega = 1.0 / (theta**scale)
 
     batch_size, seq_length = pos.shape
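
The change itself is small: torch.float64 is not supported on Apple's MPS backend, and float64 support on Ascend NPU is similarly limited, so the RoPE frequency scale is computed in float32 on those devices and kept in float64 everywhere else. Below is a minimal, self-contained sketch of the same device-dependent dtype selection; the `rope_frequencies` helper name and the einsum step after `omega` are illustrative assumptions, not part of this commit.

```python
# Minimal sketch of the device-dependent dtype selection introduced by this commit.
# Everything past the `omega` line is illustrative: the actual rope() in
# transformer_hidream_image.py continues differently.
import torch


def rope_frequencies(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    assert dim % 2 == 0, "The dimension must be even."

    # float64 is unavailable on mps and not reliably supported on npu,
    # so fall back to float32 on those backends.
    is_mps = pos.device.type == "mps"
    is_npu = pos.device.type == "npu"
    dtype = torch.float32 if (is_mps or is_npu) else torch.float64

    scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
    omega = 1.0 / (theta**scale)  # shape: (dim // 2,)

    # Outer product of positions and frequencies: one rotation angle per
    # (position, frequency) pair. (Illustrative step, not from the diff.)
    return torch.einsum("...n,d->...nd", pos.to(dtype), omega)


if __name__ == "__main__":
    positions = torch.arange(8, dtype=torch.float32).unsqueeze(0)  # (1, 8) token positions
    freqs = rope_frequencies(positions, dim=16, theta=10_000)
    print(freqs.shape, freqs.dtype)  # torch.Size([1, 8, 8]), float64 on CPU/CUDA
```

Keying the dtype off `pos.device.type` keeps the higher float64 precision on CPU and CUDA while only downgrading on backends where float64 would otherwise fail.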
