
Commit e88922f

committed
Fixed issue with DynamicCache
Signed-off-by: Amit Raj <[email protected]>
1 parent e4503c5 commit e88922f

File tree

9 files changed (+24, −20 lines)

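This commit swaps transformers' DynamicCache for QEfficient's QEffDynamicCache by importing it directly in each model file, instead of monkey-patching DynamicCache.update inside KVCacheTransform (see the pytorch_transforms.py hunk below). The sketch below is a minimal illustration of the two substitution patterns the hunks apply, not part of the diff itself; the tensor shapes are made up, and it assumes QEffDynamicCache keeps the standard DynamicCache interface (key_cache/value_cache lists, from_legacy_cache, update), as the changed lines show.

    import torch

    from QEfficient.transformers.cache_utils import QEffDynamicCache

    # A legacy cache: one (key, value) tuple per layer, as older callers pass it.
    # Shapes are illustrative only: (batch, num_heads, seq_len, head_dim).
    legacy = ((torch.zeros(1, 8, 16, 64), torch.zeros(1, 8, 16, 64)),)

    # Pattern 1 (falcon/gemma/gemma2/llama/mistral/qwen2): wrap the whole legacy
    # cache at the top of forward(), where DynamicCache.from_legacy_cache was used.
    past_key_values = QEffDynamicCache.from_legacy_cache(legacy)

    # Pattern 2 (codegen/mpt): build the cache layer by layer, then call
    # pkv.update(key, value, 0, cache_kwargs) with position_ids/batch_index kwargs.
    pkv = QEffDynamicCache()
    pkv.key_cache.append(legacy[0][0])
    pkv.value_cache.append(legacy[0][1])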

QEfficient/transformers/models/codegen/modeling_codegen.py

Lines changed: 3 additions & 2 deletions

@@ -12,7 +12,7 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from transformers.cache_utils import Cache, DynamicCache
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.codegen.modeling_codegen import (
     CodeGenAttention,
@@ -22,6 +22,7 @@
     apply_rotary_pos_emb,
 )

+from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask


@@ -131,7 +132,7 @@ def forward(
                 "position_ids": position_ids,
                 "batch_index": batch_index,
             }
-            pkv = DynamicCache()
+            pkv = QEffDynamicCache()
             pkv.key_cache.append(past_key_value[0])
             pkv.value_cache.append(past_key_value[1])
             key, value = pkv.update(key, value, 0, cache_kwargs)

QEfficient/transformers/models/falcon/modeling_falcon.py

Lines changed: 3 additions & 2 deletions

@@ -13,7 +13,7 @@
 import torch
 import torch.utils.checkpoint
 from torch.nn import functional as F
-from transformers.cache_utils import Cache, DynamicCache
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -29,6 +29,7 @@
     rotate_half,
 )

+from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask


@@ -284,7 +285,7 @@ def forward(
         return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):
             return_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)

         past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
         batch_size, seq_length, _ = inputs_embeds.shape

QEfficient/transformers/models/gemma/modeling_gemma.py

Lines changed: 3 additions & 2 deletions

@@ -9,7 +9,7 @@

 import torch
 from torch import nn
-from transformers.cache_utils import Cache, DynamicCache
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -25,6 +25,7 @@
     rotate_half,
 )

+from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask


@@ -290,7 +291,7 @@ def forward(
         return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)

         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0

QEfficient/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 4 additions & 2 deletions

@@ -9,7 +9,7 @@

 import torch
 from torch import nn
-from transformers.cache_utils import Cache, DynamicCache
+from transformers.cache_utils import Cache
 from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -26,6 +26,8 @@
     rotate_half,
 )

+from QEfficient.transformers.cache_utils import QEffDynamicCache
+
 # from transformers.utils import is_torchdynamo_compiling
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask

@@ -298,7 +300,7 @@ def forward(
         return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)

         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0

QEfficient/transformers/models/llama/modeling_llama.py

Lines changed: 3 additions & 2 deletions

@@ -9,7 +9,7 @@

 import torch
 from torch import nn
-from transformers.cache_utils import Cache, DynamicCache
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -25,6 +25,7 @@
     rotate_half,
 )

+from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask


@@ -266,7 +267,7 @@ def forward(
         return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):
             return_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)

         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0

QEfficient/transformers/models/mistral/modeling_mistral.py

Lines changed: 3 additions & 2 deletions

@@ -12,7 +12,7 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from transformers.cache_utils import Cache, DynamicCache
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -29,6 +29,7 @@
     rotate_half,
 )

+from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask


@@ -292,7 +293,7 @@ def forward(

         return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache) and not self.training:
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)
             return_legacy_cache = True
             logger.warning_once(
                 "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "

QEfficient/transformers/models/mpt/modeling_mpt.py

Lines changed: 2 additions & 2 deletions

@@ -12,14 +12,14 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from transformers.cache_utils import DynamicCache
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
 )
 from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel

+from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask


@@ -51,7 +51,7 @@ def forward(
         if past_key_value is not None:
             if len(past_key_value) != 0:
                 cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
-                pkv = DynamicCache()
+                pkv = QEffDynamicCache()
                 pkv.key_cache.append(past_key_value[0])
                 pkv.value_cache.append(past_key_value[1])
                 key_states, value_states = pkv.update(key_states, value_states, 0, cache_kwargs)

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 0 additions & 4 deletions

@@ -8,7 +8,6 @@
 from types import MethodType
 from typing import Tuple

-import transformers
 from torch import nn
 from transformers.models.codegen.modeling_codegen import (
     CodeGenAttention,
@@ -121,7 +120,6 @@

 from QEfficient.base.pytorch_transforms import ModuleMappingTransform, ModuleMethodMapperTransform
 from QEfficient.customop import CustomRMSNormAIC, GemmaCustomRMSNormAIC
-from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.models.codegen.modeling_codegen import (
     QEffCodeGenAttention,
     QeffCodeGenBlock,
@@ -370,8 +368,6 @@ class KVCacheTransform(ModuleMappingTransform):
     @classmethod
     def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
         model, transformed = super().apply(model)
-        # FIXME: see if we can merge into _module_mapping dict
-        transformers.cache_utils.DynamicCache.update = QEffDynamicCache.update
         return model, transformed


QEfficient/transformers/models/qwen2/modeling_qwen2.py

Lines changed: 3 additions & 2 deletions

@@ -12,7 +12,7 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from transformers.cache_utils import Cache, DynamicCache
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -28,6 +28,7 @@
     rotate_half,
 )

+from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask


@@ -308,7 +309,7 @@ def forward(
         return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):
             return_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
