
Commit

Export T5 (encoder-decoder) to ExecuTorch
Guang Yang committed Mar 1, 2025
1 parent 92c5ca9 commit 932354c
Showing 2 changed files with 296 additions and 0 deletions.
144 changes: 144 additions & 0 deletions src/transformers/integrations/executorch.py
@@ -216,3 +216,147 @@ def convert_and_export_with_cache(
strict=True,
)
return exported_program

class T5EncoderExportableModule(torch.nn.Module):
    """
    A wrapper module designed to make a T5EncoderModel exportable with `torch.export`.
    This module ensures that the exported encoder model is compatible with ExecuTorch.
    """

    def __init__(self, encoder_model):
        super().__init__()
        self.encoder = encoder_model

    def forward(self, input_ids):
        return self.encoder(input_ids=input_ids).last_hidden_state


class T5DecoderExportableModuleWithStaticCache(torch.nn.Module):
    """
    A wrapper module designed to make a T5 decoder exportable with `torch.export`,
    specifically for use with static caching. This module ensures the exported decoder
    is compatible with ExecuTorch.
    """

    def __init__(self, model, max_cache_length, batch_size):
        super().__init__()

        # Get the decoder component
        self.decoder = model.get_decoder()
        self.lm_head = model.lm_head
        self.config = model.config

        # Initialize static cache
        self.static_cache = StaticCache(
            config=self.config,
            batch_size=batch_size,
            max_cache_len=max_cache_length,
            device="cpu",
            dtype=torch.float32,
        )

        # Register cache buffers to make them exportable
        for i in range(len(self.static_cache.key_cache)):
            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)

    def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
        # Get outputs from decoder
        outputs = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=self.static_cache,
            use_cache=True,
            cache_position=cache_position,
        )

        # Apply language model head
        lm_logits = self.lm_head(outputs[0])

        return lm_logits
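
For reference, a minimal usage sketch of the decoder wrapper above, mirroring the unit test added further down. The model id, placeholder encoder output, and variable names are illustrative only, not part of this commit:

# Hypothetical single-step decoding with the static-cache decoder wrapper (sketch only).
import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small").eval()
decoder = T5DecoderExportableModuleWithStaticCache(model, max_cache_length=1024, batch_size=1)

decoder_input_ids = torch.tensor([[0]], dtype=torch.long)  # T5 decoder start token (pad id 0)
encoder_hidden_states = torch.zeros((1, 10, model.config.d_model))  # placeholder encoder output
cache_position = torch.tensor([0], dtype=torch.long)  # write into cache slot 0

with torch.no_grad():
    logits = decoder(decoder_input_ids, encoder_hidden_states, cache_position)
next_token_id = torch.argmax(logits[:, -1, :], dim=-1).item()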

class T5ExportableModule(torch.nn.Module):
    def __init__(self, model, batch_size=1, max_hidden_seq_length=4096, max_cache_length=1024):
        super().__init__()

        self.full_model = model
        self.encoder = model.get_encoder()
        self.config = model.config
        self.max_hidden_seq_length = max_hidden_seq_length
        self.max_cache_length = max_cache_length
        self.batch_size = batch_size
        self.exported_encoder = None
        self.exported_decoder = None

    def _export_encoder(self, encoder_input_ids):
        wrapped_encoder = T5EncoderExportableModule(self.encoder).to("cpu").eval()

        # Define dynamic sequence length for encoder
        seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length)

        # Export the encoder
        with torch.no_grad():
            exported_encoder = torch.export.export(
                wrapped_encoder,
                (encoder_input_ids,),
                dynamic_shapes={"input_ids": {1: seq_len_dim}},
                strict=True,
            )

        return exported_encoder

    def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
        wrapped_decoder = (
            T5DecoderExportableModuleWithStaticCache(self.full_model, self.max_cache_length, self.batch_size)
            .to("cpu")
            .eval()
        )

        # Define dynamic dimension for encoder output sequence length
        encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)

        # Export the decoder
        with torch.no_grad():
            exported_decoder = torch.export.export(
                wrapped_decoder,
                (decoder_input_ids, encoder_hidden_states, cache_position),
                dynamic_shapes={
                    "decoder_input_ids": None,
                    "encoder_hidden_states": {1: encoder_seq_len_dim},
                    "cache_position": None,
                },
                strict=True,
            )

        return exported_decoder

    def export(self, encoder_input_ids, decoder_input_ids, encoder_hidden_states, cache_position):
        self.exported_encoder = self._export_encoder(encoder_input_ids)
        self.exported_decoder = self._export_decoder(decoder_input_ids, encoder_hidden_states, cache_position)

        # Return self to allow chaining
        return self

    def generate(self, prompt_token_ids, max_new_tokens):
        with torch.no_grad():
            # Run encoder
            encoder_output = self.exported_encoder.module()(prompt_token_ids)

            # Initialize with start token (0 for T5)
            decoder_input_ids = torch.tensor([[0]], dtype=torch.long)
            generated_ids = [0]

            # Generate tokens one by one
            for i in range(max_new_tokens - 1):
                # Run decoder for next token prediction
                logits = self.exported_decoder.module()(
                    decoder_input_ids,
                    encoder_output,
                    torch.tensor([i], dtype=torch.long),
                )

                # Get next token
                next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
                generated_ids.append(next_token)

                # Update input for next iteration
                decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long)

                # Check if EOS token
                if next_token == self.config.eos_token_id:
                    break

        return generated_ids
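
Once both programs are exported, they could in principle be lowered to an ExecuTorch program. The following is only a rough sketch under stated assumptions: it presumes the `executorch` pip package and its `exir` API (`to_edge` / `to_executorch`), and the output file name is arbitrary; the lowering step itself is not part of this commit.

# Hypothetical lowering sketch (assumes the `executorch` package is installed; not part of this commit).
import torch
from executorch.exir import to_edge
from transformers import T5ForConditionalGeneration, T5Tokenizer

model_id = "google-t5/t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_id)
model = T5ForConditionalGeneration.from_pretrained(model_id).eval()
wrapped = T5ExportableModule(model)

# Example inputs mirroring the unit test below.
encoder_input_ids = tokenizer("Test input", return_tensors="pt").input_ids
decoder_input_ids = torch.tensor([[0]], dtype=torch.long)
encoder_hidden_states = torch.zeros((1, 10, model.config.d_model))
cache_position = torch.tensor([0], dtype=torch.long)

wrapped.export(encoder_input_ids, decoder_input_ids, encoder_hidden_states, cache_position)

# Bundle both exported programs into a single ExecuTorch .pte file.
edge = to_edge({"encoder": wrapped.exported_encoder, "decoder": wrapped.exported_decoder})
with open("t5_encoder_decoder.pte", "wb") as f:
    f.write(edge.to_executorch().buffer)

Keeping the encoder and decoder as two named entry points lets a runtime run the encoder once per prompt and then call the decoder repeatedly against the static KV cache, which is how the `generate` method above drives the two exported programs.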
152 changes: 152 additions & 0 deletions tests/models/t5/test_modeling_t5.py
@@ -22,6 +22,7 @@

from transformers import T5Config, is_torch_available
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
from transformers.testing_utils import (
require_accelerate,
require_sentencepiece,
@@ -1689,6 +1690,157 @@ def test_compile_static_cache_encoder(self):
logits_compiled = model(**inputs)
torch.testing.assert_close(logits[0][:, -3:, -3], logits_compiled[0][:, -3:, -3], rtol=1e-5, atol=1e-5)

    @slow
    def test_export_encoder(self):
        """Test exporting T5EncoderModel to torch export format."""
        if not is_torch_greater_or_equal_than_2_4:
            self.skipTest("This test requires torch >= 2.4 to run.")

        from transformers.integrations.executorch import T5EncoderExportableModule

        model_id = "google-t5/t5-small"
        device = "cpu"
        tokenizer = T5Tokenizer.from_pretrained(model_id)
        # Create test input
        input_text = "Studies have shown that owning a dog is good for you."
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

        # Load model
        model = T5EncoderModel.from_pretrained(model_id).to(device=device).eval()

        # Get original output for comparison
        with torch.no_grad():
            original_output = model(input_ids=input_ids).last_hidden_state

        encoder_model = T5EncoderExportableModule(model)

        # Export the encoder_model
        with torch.no_grad():
            seq_len_dim = torch.export.Dim("sequence_length", max=4096)

            exported_program = torch.export.export(
                encoder_model,
                (input_ids,),
                dynamic_shapes={"input_ids": {1: seq_len_dim}},
                strict=True,
            )

        # Test the exported model
        with torch.no_grad():
            exported_output = exported_program.module()(input_ids)

        # Verify outputs are close enough
        self.assertTrue(torch.allclose(original_output, exported_output, atol=1e-5))

    @slow
    def test_export_decoder(self):
        """Test exporting a T5 decoder with static cache to torch export format."""
        if not is_torch_greater_or_equal_than_2_4:
            self.skipTest("This test requires torch >= 2.4 to run.")

        from transformers.integrations.executorch import T5DecoderExportableModuleWithStaticCache

        model_id = "google-t5/t5-small"

        # Configuration for static cache
        batch_size = 1
        max_cache_len = 123
        device = "cpu"

        full_model = T5ForConditionalGeneration.from_pretrained(model_id).to(device)
        decoder_model = T5DecoderExportableModuleWithStaticCache(full_model, max_cache_len, batch_size).to(device).eval()

        # Prepare test inputs
        example_decoder_input_ids = torch.tensor([[0]], dtype=torch.long)  # Start token
        example_cache_position = torch.tensor([0], dtype=torch.long)

        # For T5-small, hidden size is 512
        example_encoder_hidden_states = torch.zeros((batch_size, 10, 512), dtype=torch.float32)

        # Export the model
        with torch.no_grad():
            encoder_sequence_length_dim = torch.export.Dim("encoder_sequence_length", max=4096)

            exported_program = torch.export.export(
                decoder_model,
                (example_decoder_input_ids, example_encoder_hidden_states, example_cache_position),
                dynamic_shapes={
                    "decoder_input_ids": None,
                    "encoder_hidden_states": {1: encoder_sequence_length_dim},
                    "cache_position": None,
                },
                strict=True,
            )

        # We don't verify outputs directly here, since that is complicated by caching;
        # instead we check that the export itself succeeded.
        self.assertIsNotNone(exported_program)

        # Verify that the cache buffers exist and have the expected shape
        cache_buffers = [
            (name, buffer)
            for name, buffer in exported_program.named_buffers()
            if name.startswith("key_cache_") or name.startswith("value_cache_")
        ]

        self.assertTrue(len(cache_buffers) > 0, "No cache buffers found in exported model")
        for name, buffer in cache_buffers:
            # The cache length dimension (dim 2) must match the configured max_cache_len
            self.assertEqual(buffer.shape[2], max_cache_len)

    @slow
    def test_export_t5_summarization(self):
        """Test composing exported T5 encoder and decoder for summarization."""
        if not is_torch_greater_or_equal_than_2_4:
            self.skipTest("This test requires torch >= 2.4 to run.")

        from transformers.integrations.executorch import T5ExportableModule

        batch_size = 1
        max_cache_length = 1234
        max_hidden_seq_length = 5678
        model_id = "google-t5/t5-small"

        tokenizer = T5Tokenizer.from_pretrained(model_id)
        full_model = T5ForConditionalGeneration.from_pretrained(model_id).eval()
        wrapped_model = T5ExportableModule(
            full_model,
            batch_size=batch_size,
            max_hidden_seq_length=max_hidden_seq_length,
            max_cache_length=max_cache_length,
        )

        # Prepare example inputs
        example_encoder_input_ids = tokenizer("Test input", return_tensors="pt").input_ids
        example_decoder_input_ids = torch.tensor([[0]], dtype=torch.long)  # Start token
        example_cache_position = torch.tensor([0], dtype=torch.long)

        # Get expected hidden size from config
        hidden_size = full_model.config.d_model  # 512 for T5-small
        example_encoder_hidden_states = torch.zeros((batch_size, 10, hidden_size), dtype=torch.float32)

        exported_t5 = wrapped_model.export(
            encoder_input_ids=example_encoder_input_ids,
            decoder_input_ids=example_decoder_input_ids,
            encoder_hidden_states=example_encoder_hidden_states,
            cache_position=example_cache_position,
        )

        # Test summarization with the composed models
        prompts = [
            "summarize: Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
            "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
            "theory of relativity is not hard to grasp."
        ]
        input_ids = tokenizer(prompts, return_tensors="pt").input_ids

        generated_ids = exported_t5.generate(prompt_token_ids=input_ids, max_new_tokens=max_cache_length)
        generated_summary = tokenizer.decode(generated_ids, skip_special_tokens=True)
        print(f"Generated summary: {generated_summary}")

        # Also run the original model for comparison
        original_model = T5ForConditionalGeneration.from_pretrained(model_id).eval()
        with torch.no_grad():
            original_outputs = original_model.generate(input_ids, max_length=50, num_beams=1)
        original_summary = tokenizer.decode(original_outputs[0], skip_special_tokens=True)
        print(f"Original model summary: {original_summary}")

        # Basic verification: greedy decoding of the exported model should match the original model
        self.assertEqual(generated_summary, original_summary)


@require_torch
class TestAsymmetricT5(unittest.TestCase):
