diff --git a/moondream/torch/hf_moondream.py b/moondream/torch/hf_moondream.py index 3432d869..6bffee51 100644 --- a/moondream/torch/hf_moondream.py +++ b/moondream/torch/hf_moondream.py @@ -29,9 +29,10 @@ class HfConfig(PretrainedConfig): _auto_class = "AutoConfig" model_type = "moondream3" - def __init__(self, **kwargs): + def __init__(self, use_flex_decoding=True, **kwargs): super().__init__(**kwargs) self.config = {"skills": ["query", "caption", "detect", "point"]} + self.use_flex_decoding = use_flex_decoding class HfMoondream(PreTrainedModel): @@ -40,8 +41,9 @@ class HfMoondream(PreTrainedModel): def __init__(self, config): super().__init__(config) + use_flex_decoding = getattr(config, 'use_flex_decoding', True) self.model = MoondreamModel( - MoondreamConfig.from_dict(config.config), setup_caches=False + MoondreamConfig.from_dict(config.config), setup_caches=False, use_flex_decoding=use_flex_decoding ) self._is_kv_cache_setup = False diff --git a/moondream/torch/moondream.py b/moondream/torch/moondream.py index 9cfcc0dc..f2b04012 100644 --- a/moondream/torch/moondream.py +++ b/moondream/torch/moondream.py @@ -93,7 +93,7 @@ def _mask_mod(b, h, q, kv): class MoondreamModel(nn.Module): def __init__( - self, config: MoondreamConfig, dtype=torch.bfloat16, setup_caches=True + self, config: MoondreamConfig, dtype=torch.bfloat16, setup_caches=True, use_flex_decoding=True ): super().__init__() self.config = config @@ -139,7 +139,7 @@ def __init__( attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1 self.register_buffer("attn_mask", attn_mask, persistent=False) - self.use_flex_decoding = True + self.use_flex_decoding = use_flex_decoding self._causal_block_mask = None self._point_gen_indices = None