PRX transformers version of rope has a float64 reference, breaks MPS (and npu ?) #12653

@Vargol

Describe the bug

Once again there is a RoPE implementation that references torch.float64, which breaks when run on MPS. I assume it also breaks on NPU, since the last time I reported this it was fixed for both. This time it is in the PRX transformer code:

  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_prx.py", line 278, in rope
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

The last one I reported was fixed with the following change:
https://github.com/huggingface/diffusers/pull/11316/files
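
For illustration, here is a minimal sketch of the kind of change that would avoid the error: pick float32 for the frequency tensor on devices that cannot hold float64, and keep float64 everywhere else. The names `NO_FLOAT64_DEVICE_TYPES`, `rope_freq_dtype_name` and `rope_scale` are hypothetical and are not taken from diffusers or from the linked PR. Only the `torch.arange(...) / dim` expression comes from the traceback above, and the rest of `rope` is not reproduced here.

```python
# Assumption: device types without float64 support. "mps" comes from the error
# above; "npu" is this report's own guess and has not been verified here.
NO_FLOAT64_DEVICE_TYPES = frozenset({"mps", "npu"})


def rope_freq_dtype_name(device_type: str) -> str:
    """Return the name of the dtype to use for RoPE frequencies on a device type."""
    return "float32" if device_type in NO_FLOAT64_DEVICE_TYPES else "float64"


def rope_scale(pos, dim: int):
    """Hypothetical stand-in for the failing `scale = ...` line in `rope`."""
    import torch  # imported lazily so the helper above works without torch

    # Resolve the dtype from the device the position ids live on.
    dtype = getattr(torch, rope_freq_dtype_name(pos.device.type))
    return torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
```

With something like this, `rope_scale(pos, dim)` would build the tensor in float32 when `pos` is on MPS and stay in float64 on CPU or CUDA. Whether float32 is precise enough for PRX specifically is not something I have checked.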

Reproduction

import torch
from diffusers.pipelines.prx import PRXPipeline

pipe = PRXPipeline.from_pretrained(
    "Photoroom/prx-1024-t2i-beta",
    torch_dtype=torch.bfloat16
).to("mps")

prompt = "A front-facing portrait of a lion in the golden savanna at sunset"
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("lion.png")

Logs

$ python prx.txt 
Multiple distributions found for package optimum. Picked distribution: optimum-quanto
WARNING:torchao.kernel.intmm:Warning: Detected no triton, on systems without Triton certain kernels will not work
W1113 17:58:53.139000 2226 lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/amp/autocast_mode.py:269: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling
  warnings.warn(
model_index.json: 100%|█████████████████████████████████████████████████████████████████| 467/467 [00:00<00:00, 979kB/s]
config.json: 1.44kB [00:00, 3.16MB/s]                                                            | 0/14 [00:00<?, ?it/s]
model.safetensors.index.json: 22.7kB [00:00, 35.5MB/s]                                      | 0.00/34.4M [00:00<?, ?B/s]
special_tokens_map.json: 100%|█████████████████████████████████████████████████████████| 636/636 [00:00<00:00, 12.2MB/s]
scheduler_config.json: 100%|███████████████████████████████████████████████████████████| 101/101 [00:00<00:00, 2.73MB/s]
tokenizer_config.json: 46.4kB [00:00, 336MB/s]                                           | 2/14 [00:00<00:03,  4.00it/s]
config.json: 100%|█████████████████████████████████████████████████████████████████████| 252/252 [00:00<00:00, 7.19MB/s]
config.json: 100%|█████████████████████████████████████████████████████████████████████| 829/829 [00:00<00:00, 22.9MB/s]
tokenizer/tokenizer.json: 100%|████████████████████████████████████████████████████| 34.4M/34.4M [00:02<00:00, 16.7MB/s]
text_encoder/model-00003-of-00003.safete(…): 100%|███████████████████████████████████| 481M/481M [02:04<00:00, 3.87MB/s]
vae/diffusion_pytorch_model.safetensors: 100%|███████████████████████████████████████| 335M/335M [02:04<00:00, 2.70MB/s]
transformer/diffusion_pytorch_model.safe(…): 100%|█████████████████████████████████| 4.68G/4.68G [04:58<00:00, 15.7MB/s]
text_encoder/model-00002-of-00003.safete(…): 100%|█████████████████████████████████| 4.98G/4.98G [04:59<00:00, 16.6MB/s]
text_encoder/model-00001-of-00003.safete(…): 100%|█████████████████████████████████| 4.99G/4.99G [05:36<00:00, 14.8MB/s]
Fetching 14 files: 100%|████████████████████████████████████████████████████████████████| 14/14 [05:42<00:00, 24.48s/it]
Loading pipeline components...:   0%|                                                             | 0/5 [00:00<?, ?it/s]Guiders are currently an experimental feature under active development. The API is subject to breaking changes in future releases.
Loading pipeline components...:  40%|█████████████████████▏                               | 2/5 [00:08<00:14,  4.99s/it]`torch_dtype` is deprecated! Use `dtype` instead!
You are using a model of type t5_gemma_module to instantiate a model of type t5gemma. This is not supported for all configurations of models and can yield errors.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 3/3 [00:22<00:00,  7.64s/it]
Loading pipeline components...: 100%|█████████████████████████████████████████████████████| 5/5 [00:32<00:00,  6.40s/it]
/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/loaders/lora_pipeline.py:5089: FutureWarning: `LoraLoaderMixin` is deprecated and will be removed in version 1.0.0. LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead.
  deprecate("LoraLoaderMixin", "1.0.0", deprecation_message)
  0%|                                                                                            | 0/28 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/Volumes/SSD2TB/AI/Diffusers/prx.txt", line 10, in <module>
    image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 122, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/pipelines/prx/pipeline_prx.py", line 718, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1780, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1791, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_prx.py", line 739, in forward
    pe = self.pe_embedder(img_ids)
         ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1780, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1791, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_prx.py", line 290, in forward
    [self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)],
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_prx.py", line 290, in <listcomp>
    [self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)],
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_prx.py", line 278, in rope
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

System Info

  • 🤗 Diffusers version: 0.36.0.dev0
  • Platform: macOS-26.1-arm64-arm-64bit
  • Running on Google Colab?: No
  • Python version: 3.11.13
  • PyTorch version (GPU?): 2.10.0.dev20251023 (False)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.34.3
  • Transformers version: 4.57.1
  • Accelerate version: 1.7.0
  • PEFT version: 0.17.0
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: Apple M3
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@DN6 @yiyixuxu @sayakpaul

Labels: bug