diff --git a/src/heartlib/heartcodec/models/flow_matching.py b/src/heartlib/heartcodec/models/flow_matching.py index c48d223..6f54e15 100644 --- a/src/heartlib/heartcodec/models/flow_matching.py +++ b/src/heartlib/heartcodec/models/flow_matching.py @@ -72,9 +72,14 @@ def inference_codes( batch_size = codes_bestrq_emb.shape[0] self.vq_embed.eval() - quantized_feature_emb = self.vq_embed.get_output_from_indices( - codes_bestrq_emb.transpose(1, 2) - ) + + # Clamp indices to valid codebook range to prevent CUDA out-of-bounds errors + # This is needed for compatibility with vector-quantize-pytorch >= 1.20 + indices_input = codes_bestrq_emb.transpose(1, 2) + codebook_size = getattr(self.vq_embed, 'codebook_size', 8192) + indices_input = indices_input.clamp(0, codebook_size - 1) + + quantized_feature_emb = self.vq_embed.get_output_from_indices(indices_input) quantized_feature_emb = self.cond_feature_emb(quantized_feature_emb) # b t 512 # assert 1==2 quantized_feature_emb = F.interpolate(