diff --git a/moondream/torch/layers.py b/moondream/torch/layers.py index 8d57c521..f51f016f 100644 --- a/moondream/torch/layers.py +++ b/moondream/torch/layers.py @@ -101,7 +101,12 @@ def unpack(self): del self.weight, self.bias quantize_(self, int4_weight_only(group_size=128)) self.unpacked = True - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + torch.xpu.empty_cache() + elif hasattr(torch, "mps") and torch.mps.is_available(): + torch.mps.empty_cache() def forward(self, x: torch.Tensor) -> torch.Tensor: if not self.unpacked: diff --git a/moondream/torch/moondream.py b/moondream/torch/moondream.py index 9cfcc0dc..a8fa1b2e 100644 --- a/moondream/torch/moondream.py +++ b/moondream/torch/moondream.py @@ -93,7 +93,11 @@ def _mask_mod(b, h, q, kv): class MoondreamModel(nn.Module): def __init__( - self, config: MoondreamConfig, dtype=torch.bfloat16, setup_caches=True + self, + config: MoondreamConfig, + dtype=torch.bfloat16, + setup_caches=True, + use_flex_decoding=None, ): super().__init__() self.config = config @@ -139,7 +143,9 @@ def __init__( attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1 self.register_buffer("attn_mask", attn_mask, persistent=False) - self.use_flex_decoding = True + if use_flex_decoding is None: + use_flex_decoding = torch.cuda.is_available() + self.use_flex_decoding = use_flex_decoding self._causal_block_mask = None self._point_gen_indices = None diff --git a/moondream/torch/sample.py b/moondream/torch/sample.py index 2c35268f..0a2c7abb 100644 --- a/moondream/torch/sample.py +++ b/moondream/torch/sample.py @@ -38,9 +38,10 @@ model.to(device, dtype=torch.bfloat16) model.compile() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - torch.cuda.reset_accumulated_memory_stats() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_accumulated_memory_stats() # Encode image. image_path = args.image