Export T5 (encoder-decoder) to ExecuTorch #36486

Open · wants to merge 1 commit into main
144 changes: 144 additions & 0 deletions src/transformers/integrations/executorch.py
@@ -216,3 +216,147 @@ def convert_and_export_with_cache(
strict=True,
)
return exported_program

class T5EncoderExportableModule(torch.nn.Module):
"""
A wrapper module designed to make a T5EncoderModel exportable with `torch.export`.
This module ensures that the exported encoder model is compatible with ExecuTorch.
"""
def __init__(self, encoder_model):
super().__init__()
self.encoder = encoder_model

def forward(self, input_ids):
return self.encoder(input_ids=input_ids).last_hidden_state

class T5DecoderExportableModuleWithStaticCache(torch.nn.Module):
"""
A wrapper module designed to make a T5 decoder exportable with `torch.export`,
specifically for use with static caching. This module ensures the exported decoder
is compatible with ExecuTorch.
"""
def __init__(self, model, max_cache_length, batch_size):
super().__init__()

# Get the decoder component
self.decoder = model.get_decoder()
self.lm_head = model.lm_head
self.config = model.config

# Initialize static cache
self.static_cache = StaticCache(
config=self.config,
batch_size=batch_size,
max_cache_len=max_cache_length,
device="cpu",
dtype=torch.float32,
)

# Register cache buffers to make them exportable
for i in range(len(self.static_cache.key_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)

def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
# Get outputs from decoder
outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_hidden_states,
past_key_values=self.static_cache,
use_cache=True,
cache_position=cache_position,
)

# Apply language model head
lm_logits = self.lm_head(outputs[0])

return lm_logits

class T5ExportableModule(torch.nn.Module):
def __init__(self, model, batch_size=1, max_hidden_seq_length=4096, max_cache_length=1024):
super().__init__()

self.full_model = model
self.encoder = model.get_encoder()
self.config = model.config
self.max_hidden_seq_length = max_hidden_seq_length
self.max_cache_length = max_cache_length
self.batch_size = batch_size
self.exported_encoder = None
self.exported_decoder = None

def _export_encoder(self, encoder_input_ids):
wrapped_encoder = T5EncoderExportableModule(self.encoder).to("cpu").eval()

# Define dynamic sequence length for encoder
seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length)

# Export the encoder
with torch.no_grad():
exported_encoder = torch.export.export(
wrapped_encoder,
(encoder_input_ids,),
dynamic_shapes={"input_ids": {1: seq_len_dim}},
strict=True
)

return exported_encoder

def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
wrapped_decoder = T5DecoderExportableModuleWithStaticCache(self.full_model, self.max_cache_length, self.batch_size).to("cpu").eval()

# Define dynamic dimension for encoder output sequence length
encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)

# Export the decoder
with torch.no_grad():
exported_decoder = torch.export.export(
wrapped_decoder,
(decoder_input_ids, encoder_hidden_states, cache_position),
dynamic_shapes={
"decoder_input_ids": None,
"encoder_hidden_states": {1: encoder_seq_len_dim},
"cache_position": None,
},
strict=True
)

return exported_decoder

def export(self, encoder_input_ids, decoder_input_ids, encoder_hidden_states, cache_position):
self.exported_encoder = self._export_encoder(encoder_input_ids)
self.exported_decoder = self._export_decoder(decoder_input_ids, encoder_hidden_states, cache_position)

# Return self to allow chaining
return self

def generate(self, prompt_token_ids, max_new_tokens):
with torch.no_grad():
# Run encoder
encoder_output = self.exported_encoder.module()(prompt_token_ids)

# Initialize with the decoder start token (T5's decoder_start_token_id is 0)
decoder_input_ids = torch.tensor([[0]], dtype=torch.long)
generated_ids = [0]

# Generate tokens one by one
for i in range(max_new_tokens - 1):
# Run decoder for next token prediction
logits = self.exported_decoder.module()(
decoder_input_ids,
encoder_output,
torch.tensor([i], dtype=torch.long)
)

# Get next token
next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
generated_ids.append(next_token)

# Update input for next iteration
decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long)

# Check if EOS token
if next_token == self.config.eos_token_id:
break

return generated_ids
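
For reviewers, here is a minimal usage sketch of the composed wrapper, mirroring the new test_export_t5_summarization test below; the prompt text and max_new_tokens value are illustrative only.

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers.integrations.executorch import T5ExportableModule

model_id = "google-t5/t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_id)
model = T5ForConditionalGeneration.from_pretrained(model_id).eval()
wrapped = T5ExportableModule(model, batch_size=1, max_hidden_seq_length=4096, max_cache_length=1024)

# Example inputs are only used to trace the encoder and decoder graphs during export
encoder_input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you.", return_tensors="pt").input_ids
decoder_input_ids = torch.tensor([[0]], dtype=torch.long)  # T5 decoder start token (pad token id 0)
encoder_hidden_states = torch.zeros((1, 10, model.config.d_model), dtype=torch.float32)
cache_position = torch.tensor([0], dtype=torch.long)

exported = wrapped.export(
    encoder_input_ids=encoder_input_ids,
    decoder_input_ids=decoder_input_ids,
    encoder_hidden_states=encoder_hidden_states,
    cache_position=cache_position,
)

token_ids = exported.generate(prompt_token_ids=encoder_input_ids, max_new_tokens=64)
print(tokenizer.decode(token_ids, skip_special_tokens=True))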
152 changes: 152 additions & 0 deletions tests/models/t5/test_modeling_t5.py
@@ -22,6 +22,7 @@

from transformers import T5Config, is_torch_available
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
from transformers.testing_utils import (
require_accelerate,
require_sentencepiece,
@@ -1689,6 +1690,157 @@ def test_compile_static_cache_encoder(self):
logits_compiled = model(**inputs)
torch.testing.assert_close(logits[0][:, -3:, -3], logits_compiled[0][:, -3:, -3], rtol=1e-5, atol=1e-5)

@slow
def test_export_encoder(self):
"""Test exporting T5EncoderModel to torch export format."""
if not is_torch_greater_or_equal_than_2_4:
self.skipTest("This test requires torch >= 2.4 to run.")

from transformers.integrations.executorch import T5EncoderExportableModule

model_id = "google-t5/t5-small"
device = "cpu"
tokenizer = T5Tokenizer.from_pretrained(model_id)
# Create test input
input_text = "Studies have shown that owning a dog is good for you."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

# Load model
model = T5EncoderModel.from_pretrained(model_id).to(device=device).eval()

# Get original output for comparison
with torch.no_grad():
original_output = model(input_ids=input_ids).last_hidden_state

encoder_model = T5EncoderExportableModule(model)

# Export the encoder_model
with torch.no_grad():
seq_len_dim = torch.export.Dim("sequence_length", max=4096)

exported_program = torch.export.export(
encoder_model,
(input_ids,),
dynamic_shapes={"input_ids": {1: seq_len_dim}},
strict=True
)

# Test the exported model
with torch.no_grad():
exported_output = exported_program.module()(input_ids)

# Verify outputs are close enough
self.assertTrue(torch.allclose(original_output, exported_output, atol=1e-5))

@slow
def test_export_decoder(self):
"""Test exporting T5 decoder with static cache to torch export format."""
if not is_torch_greater_or_equal_than_2_4:
self.skipTest("This test requires torch >= 2.4 to run.")

from transformers.integrations.executorch import T5DecoderExportableModuleWithStaticCache

model_id = "google-t5/t5-small"

# Configuration for static cache
batch_size = 1
max_cache_len = 123
device = "cpu"

full_model = T5ForConditionalGeneration.from_pretrained(model_id).to(device)
decoder_model = T5DecoderExportableModuleWithStaticCache(full_model, max_cache_len, batch_size).to(device).eval()

# Prepare test inputs
example_decoder_input_ids = torch.tensor([[0]], dtype=torch.long) # Start token
example_cache_position = torch.tensor([0], dtype=torch.long)

# For T5-small, hidden size is 512
example_encoder_hidden_states = torch.zeros((batch_size, 10, 512), dtype=torch.float32)

# Export the model
with torch.no_grad():
encoder_sequence_length_dim = torch.export.Dim("encoder_sequence_length", max=4096)

exported_program = torch.export.export(
decoder_model,
(example_decoder_input_ids, example_encoder_hidden_states, example_cache_position),
dynamic_shapes={
"decoder_input_ids": None,
"encoder_hidden_states": {1: encoder_sequence_length_dim},
"cache_position": None,
},
strict=True
)

# We don't verify outputs directly here, since comparing them is complicated by the cache state;
# instead, check that the export itself succeeded
self.assertIsNotNone(exported_program)

# Verify cache buffers existence and shapes
cache_buffers = [(name, buffer) for name, buffer in exported_program.named_buffers()
if name.startswith("key_cache_") or name.startswith("value_cache_")]

# Verify cache buffers
self.assertTrue(len(cache_buffers) > 0, "No cache buffers found in exported model")
for name, buffer in cache_buffers:
# Each static cache buffer is 4-D (batch, heads, max_cache_len, head_dim); check the cache-length dim
self.assertEqual(buffer.shape[2], max_cache_len)

@slow
def test_export_t5_summarization(self):
"""Test composing exported T5 encoder and decoder for summarization."""
if not is_torch_greater_or_equal_than_2_4:
self.skipTest("This test requires torch >= 2.4 to run.")

from transformers.integrations.executorch import T5ExportableModule

batch_size = 1
max_cache_length = 1234
max_hidden_seq_length = 5678
model_id = "google-t5/t5-small"

tokenizer = T5Tokenizer.from_pretrained(model_id)
full_model = T5ForConditionalGeneration.from_pretrained(model_id).eval()
wrapped_model = T5ExportableModule(full_model, batch_size=batch_size, max_hidden_seq_length=max_hidden_seq_length, max_cache_length=max_cache_length)

# Prepare example inputs
example_encoder_input_ids = tokenizer("Test input", return_tensors="pt").input_ids
example_decoder_input_ids = torch.tensor([[0]], dtype=torch.long) # Start token
example_cache_position = torch.tensor([0], dtype=torch.long)

# Get expected hidden size from config
hidden_size = full_model.config.d_model # 512 for T5-small
example_encoder_hidden_states = torch.zeros((batch_size, 10, hidden_size), dtype=torch.float32)

exported_t5 = wrapped_model.export(
encoder_input_ids=example_encoder_input_ids,
decoder_input_ids=example_decoder_input_ids,
encoder_hidden_states=example_encoder_hidden_states,
cache_position=example_cache_position,
)

# Test Summarization with Composed Models
prompts = [
"summarize: Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
"reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
"theory of relativity is not hard to grasp."
]
input_ids = tokenizer(prompts, return_tensors="pt").input_ids

generated_ids = exported_t5.generate(prompt_token_ids=input_ids, max_new_tokens=max_cache_length)
generated_summary = tokenizer.decode(generated_ids, skip_special_tokens=True)
print(f"Generated summary: {generated_summary}")

# Also run original model for comparison
original_model = T5ForConditionalGeneration.from_pretrained(model_id).eval()
with torch.no_grad():
original_outputs = original_model.generate(input_ids, max_length=50, num_beams=1)
original_summary = tokenizer.decode(original_outputs[0], skip_special_tokens=True)
print(f"Original model summary: {original_summary}")

# Verify the exported model reproduces the original model's greedy summary
self.assertEqual(generated_summary, original_summary)


@require_torch
class TestAsymmetricT5(unittest.TestCase):
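
If the goal is to run these programs with the ExecuTorch runtime, the two exported programs can be lowered and serialized separately. The snippet below is a rough sketch and is not part of this diff; it assumes the standard executorch.exir.to_edge / to_executorch lowering APIs, and the output file names are hypothetical.

from executorch.exir import to_edge

# `exported` is the T5ExportableModule returned by .export(...) above (illustrative).
# Lower each torch.export program to an ExecuTorch program and write it as a .pte file.
for name, program in (("encoder", exported.exported_encoder), ("decoder", exported.exported_decoder)):
    et_program = to_edge(program).to_executorch()
    with open(f"t5_{name}.pte", "wb") as f:
        f.write(et_program.buffer)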