Skip to content

fix: make torch.cuda.empty_cache() and flex_attention device-agnostic (Intel XPU, Apple MPS)#336

Open
EduardoSolanas wants to merge 1 commit into
m87-labs:mainfrom
EduardoSolanas:fix/intel-xpu-device-agnostic-cache-flex
Open

fix: make torch.cuda.empty_cache() and flex_attention device-agnostic (Intel XPU, Apple MPS)#336
EduardoSolanas wants to merge 1 commit into
m87-labs:mainfrom
EduardoSolanas:fix/intel-xpu-device-agnostic-cache-flex

Conversation

@EduardoSolanas
Copy link
Copy Markdown

Problem

Moondream3-preview crashes on non-CUDA devices (Intel Arc GPU, Apple MPS, AMD ROCm) due to two CUDA-specific code paths:

1. torch.cuda.empty_cache() crashes in production model code

QuantizedLinear.unpack() calls torch.cuda.empty_cache() unconditionally. On Intel XPU and Apple MPS, this raises RuntimeError or AttributeError.

2. flex_attention + create_block_mask is CUDA-only (fixes #316)

MoondreamModel.__init__ hardcodes self.use_flex_decoding = True, triggering create_block_mask and flex_attention imports that are CUDA-only APIs.

Changes

layers.py — Device-agnostic cache clearing

# Before:
torch.cuda.empty_cache()

# After:
if torch.cuda.is_available():
    torch.cuda.empty_cache()
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    torch.xpu.empty_cache()
elif hasattr(torch, "mps") and torch.mps.is_available():
    torch.mps.empty_cache()

moondream.py — Configurable flex_decoding with auto-detection

# New constructor parameter with smart default:
def __init__(self, config, ..., use_flex_decoding=None):
    if use_flex_decoding is None:
        use_flex_decoding = torch.cuda.is_available()  # auto-detect
    self.use_flex_decoding = use_flex_decoding
  • CUDA users: No change — flex_decoding auto-enabled ✓
  • Intel XPU / Apple MPS: flex_decoding auto-disabled, falls back to F.scaled_dot_product_attention
  • Power users: Can override: MoondreamModel(config, use_flex_decoding=False)

sample.py — Guard benchmark memory stats

CUDA-specific memory stat calls gated behind torch.cuda.is_available().

Why this approach (vs alternatives)

  • Auto-detection means the model "just works" on Intel, Apple, and AMD without manual configuration
  • Explicit override preserves flexibility for users who want to force flex_attention on platforms with experimental support
  • Three-backend cache clearing covers CUDA, XPU, and MPS — the three major PyTorch backends
  • Minimal diff (18 insertions, 6 deletions) — only touches the exact call sites that crash

Testing

Tested on Intel Arc A770 with PyTorch 2.6.0+xpu (Intel oneAPI 2025.0.2):

  • ✅ Model loads without CUDA errors
  • use_flex_decoding auto-set to False
  • QuantizedLinear.unpack() calls torch.xpu.empty_cache() correctly
  • ✅ Inference works via F.scaled_dot_product_attention fallback

Related

- layers.py: Guard torch.cuda.empty_cache() with device checks (CUDA → XPU → MPS)
  Prevents RuntimeError on Intel Arc GPUs and Apple Silicon where CUDA is unavailable.

- moondream.py: Add use_flex_decoding constructor parameter with auto-detection
  Defaults to True on CUDA, False otherwise. Allows explicit override for
  non-CUDA devices (Intel XPU, Apple MPS, AMD ROCm).
  Fixes m87-labs#316 and m87-labs#335.

- sample.py: Guard CUDA memory stats calls with is_available() check
  Prevents crashes when running the demo script on non-NVIDIA hardware.

Closes m87-labs#335
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

1 participant