Status: Closed
Labels: bug
Description
Describe the bug
Once again, a RoPE implementation uses torch.float64, which breaks when run on MPS. I assume it also breaks on NPU, because the last time I reported this, the fix covered both. This time the problem is in the PRX transformer code:
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_prx.py", line 278, in rope
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
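The error is raised as soon as a float64 tensor is created on the MPS device. One way to check a device before running the full pipeline is to try allocating a float64 tensor on it. This is a minimal sketch. supports_float64 is a hypothetical helper, not part of diffusers or PyTorch, and the exception types it catches are assumptions based on the TypeError shown above.

```python
import torch


def supports_float64(device: str) -> bool:
    """Hypothetical helper: True if a float64 tensor can be allocated on `device`."""
    try:
        torch.zeros(1, dtype=torch.float64, device=device)
        return True
    except (TypeError, RuntimeError):
        # MPS raised TypeError in the log above ("MPS framework doesn't support float64").
        # A backend that is missing from the PyTorch build typically raises RuntimeError.
        return False
```

Given the traceback above, supports_float64("mps") should return False on the Apple M3 machine listed under System Info, while supports_float64("cpu") returns True.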
The previous instance I reported was fixed by the following change:
https://github.com/huggingface/diffusers/pull/11316/files
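Below is a minimal sketch of the device-dependent dtype pattern, based on the description of the earlier fix above: float64 where the backend supports it, float32 on MPS and NPU. rope_freqs_dtype and rope_angles are hypothetical helpers, not the diffusers API. rope_angles also returns only the rotation angles (position times frequency). Turning those angles into cos/sin terms, as a full RoPE embedder does, is omitted.

```python
import torch


def rope_freqs_dtype(device: torch.device) -> torch.dtype:
    # Assumption: MPS and NPU lack float64, so use float32 there and float64 elsewhere.
    return torch.float32 if device.type in ("mps", "npu") else torch.float64


def rope_angles(pos: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper: RoPE rotation angles for integer positions of shape (batch, seq)."""
    assert dim % 2 == 0, "RoPE dimension must be even"
    dtype = rope_freqs_dtype(pos.device)
    # Same expression as the failing line, but with a device-dependent dtype.
    scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
    omega = 1.0 / (theta ** scale)  # (dim // 2,) inverse frequencies
    # Outer product: one angle per (position, frequency) pair -> (batch, seq, dim // 2).
    angles = pos.to(dtype).unsqueeze(-1) * omega
    # Return float32 regardless of the internal precision.
    return angles.float()


pos = torch.arange(4).unsqueeze(0)  # (1, 4) positions on CPU
angles = rope_angles(pos, dim=8)
print(angles.shape)  # torch.Size([1, 4, 4])
```

This design keeps the extra precision of float64 for the frequency table on CPU and CUDA. On MPS, the same code computes the table in float32 instead of raising an exception.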
Reproduction
import torch
from diffusers.pipelines.prx import PRXPipeline

pipe = PRXPipeline.from_pretrained(
    "Photoroom/prx-1024-t2i-beta",
    torch_dtype=torch.bfloat16
).to("mps")

prompt = "A front-facing portrait of a lion in the golden savanna at sunset"
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("lion.png")
Logs
$ python prx.txt
Multiple distributions found for package optimum. Picked distribution: optimum-quanto
WARNING:torchao.kernel.intmm:Warning: Detected no triton, on systems without Triton certain kernels will not work
W1113 17:58:53.139000 2226 lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/amp/autocast_mode.py:269: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling
warnings.warn(
model_index.json: 100%|█████████████████████████████████████████████████████████████████| 467/467 [00:00<00:00, 979kB/s]
config.json: 1.44kB [00:00, 3.16MB/s] | 0/14 [00:00<?, ?it/s]
model.safetensors.index.json: 22.7kB [00:00, 35.5MB/s] | 0.00/34.4M [00:00<?, ?B/s]
special_tokens_map.json: 100%|█████████████████████████████████████████████████████████| 636/636 [00:00<00:00, 12.2MB/s]
scheduler_config.json: 100%|███████████████████████████████████████████████████████████| 101/101 [00:00<00:00, 2.73MB/s]
tokenizer_config.json: 46.4kB [00:00, 336MB/s] | 2/14 [00:00<00:03, 4.00it/s]
config.json: 100%|█████████████████████████████████████████████████████████████████████| 252/252 [00:00<00:00, 7.19MB/s]
config.json: 100%|█████████████████████████████████████████████████████████████████████| 829/829 [00:00<00:00, 22.9MB/s]
tokenizer/tokenizer.json: 100%|████████████████████████████████████████████████████| 34.4M/34.4M [00:02<00:00, 16.7MB/s]
text_encoder/model-00003-of-00003.safete(…): 100%|███████████████████████████████████| 481M/481M [02:04<00:00, 3.87MB/s]
vae/diffusion_pytorch_model.safetensors: 100%|███████████████████████████████████████| 335M/335M [02:04<00:00, 2.70MB/s]
transformer/diffusion_pytorch_model.safe(…): 100%|█████████████████████████████████| 4.68G/4.68G [04:58<00:00, 15.7MB/s]
text_encoder/model-00002-of-00003.safete(…): 100%|█████████████████████████████████| 4.98G/4.98G [04:59<00:00, 16.6MB/s]
text_encoder/model-00001-of-00003.safete(…): 100%|█████████████████████████████████| 4.99G/4.99G [05:36<00:00, 14.8MB/s]
Fetching 14 files: 100%|████████████████████████████████████████████████████████████████| 14/14 [05:42<00:00, 24.48s/it]
Loading pipeline components...: 0%| | 0/5 [00:00<?, ?it/s]Guiders are currently an experimental feature under active development. The API is subject to breaking changes in future releases.
Loading pipeline components...: 40%|█████████████████████▏ | 2/5 [00:08<00:14, 4.99s/it]`torch_dtype` is deprecated! Use `dtype` instead!
You are using a model of type t5_gemma_module to instantiate a model of type t5gemma. This is not supported for all configurations of models and can yield errors.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 3/3 [00:22<00:00, 7.64s/it]
Loading pipeline components...: 100%|█████████████████████████████████████████████████████| 5/5 [00:32<00:00, 6.40s/it]
/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/loaders/lora_pipeline.py:5089: FutureWarning: `LoraLoaderMixin` is deprecated and will be removed in version 1.0.0. LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead.
deprecate("LoraLoaderMixin", "1.0.0", deprecation_message)
0%| | 0/28 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/Volumes/SSD2TB/AI/Diffusers/prx.txt", line 10, in <module>
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 122, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/pipelines/prx/pipeline_prx.py", line 718, in __call__
noise_pred = self.transformer(
^^^^^^^^^^^^^^^^^
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1780, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1791, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_prx.py", line 739, in forward
pe = self.pe_embedder(img_ids)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1780, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1791, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_prx.py", line 290, in forward
[self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)],
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_prx.py", line 290, in <listcomp>
[self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)],
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_prx.py", line 278, in rope
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
System Info
- 🤗 Diffusers version: 0.36.0.dev0
- Platform: macOS-26.1-arm64-arm-64bit
- Running on Google Colab?: No
- Python version: 3.11.13
- PyTorch version (GPU?): 2.10.0.dev20251023 (False)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.34.3
- Transformers version: 4.57.1
- Accelerate version: 1.7.0
- PEFT version: 0.17.0
- Bitsandbytes version: not installed
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: Apple M3
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No
Who can help?