From ae848700da0f72c1981e0983d7cb763ebb56502e Mon Sep 17 00:00:00 2001 From: Eduardo Solanas Date: Thu, 7 May 2026 15:15:40 +0000 Subject: [PATCH] fix: make torch.cuda.empty_cache() and flex_attention device-agnostic MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - layers.py: Guard torch.cuda.empty_cache() with device checks (CUDA → XPU → MPS) Prevents RuntimeError on Intel Arc GPUs and Apple Silicon where CUDA is unavailable. - moondream.py: Add use_flex_decoding constructor parameter with auto-detection When left as None it defaults to torch.cuda.is_available(): True on CUDA (and on ROCm builds of PyTorch, which report through torch.cuda), False otherwise. Pass it explicitly to override the default on other devices (Intel XPU, Apple MPS, AMD ROCm). Fixes #316 and #335. - sample.py: Guard CUDA memory stats calls with is_available() check Prevents crashes when running the demo script on non-NVIDIA hardware. Closes #335 --- moondream/torch/layers.py | 7 ++++++- moondream/torch/moondream.py | 10 ++++++++-- moondream/torch/sample.py | 7 ++++--- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/moondream/torch/layers.py b/moondream/torch/layers.py index 8d57c521..f51f016f 100644 --- a/moondream/torch/layers.py +++ b/moondream/torch/layers.py @@ -101,7 +101,12 @@ def unpack(self): del self.weight, self.bias quantize_(self, int4_weight_only(group_size=128)) self.unpacked = True - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + torch.xpu.empty_cache() + elif hasattr(torch, "mps") and torch.mps.is_available(): + torch.mps.empty_cache() def forward(self, x: torch.Tensor) -> torch.Tensor: if not self.unpacked: diff --git a/moondream/torch/moondream.py b/moondream/torch/moondream.py index 9cfcc0dc..a8fa1b2e 100644 --- a/moondream/torch/moondream.py +++ b/moondream/torch/moondream.py @@ -93,7 +93,11 @@ def _mask_mod(b, h, q, kv): class MoondreamModel(nn.Module): def __init__( - self, config: MoondreamConfig, dtype=torch.bfloat16, setup_caches=True + 
self, + config: MoondreamConfig, + dtype=torch.bfloat16, + setup_caches=True, + use_flex_decoding=None, ): super().__init__() self.config = config @@ -139,7 +143,9 @@ def __init__( attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1 self.register_buffer("attn_mask", attn_mask, persistent=False) - self.use_flex_decoding = True + if use_flex_decoding is None: + use_flex_decoding = torch.cuda.is_available() + self.use_flex_decoding = use_flex_decoding self._causal_block_mask = None self._point_gen_indices = None diff --git a/moondream/torch/sample.py b/moondream/torch/sample.py index 2c35268f..0a2c7abb 100644 --- a/moondream/torch/sample.py +++ b/moondream/torch/sample.py @@ -38,9 +38,10 @@ model.to(device, dtype=torch.bfloat16) model.compile() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - torch.cuda.reset_accumulated_memory_stats() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_accumulated_memory_stats() # Encode image. image_path = args.image