182 changes: 29 additions & 153 deletions examples/pytorch/continuous_batching.py
@@ -2,7 +2,6 @@

import datasets
import torch
from tokenizers.decoders import DecodeStream

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
@@ -18,7 +17,7 @@

# --- Common Setup ---
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3b-Instruct", attn_implementation="sdpa", torch_dtype=torch.float16, device_map="auto"
    "meta-llama/Llama-3.2-3b-Instruct", attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.2-3b-Instruct", torch_dtype=torch.float16, padding_side="left"
@@ -33,7 +32,7 @@

# Configure generation parameters
generation_config = GenerationConfig(
    max_new_tokens=50,
    max_new_tokens=16,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    # Add other parameters like temperature, top_k etc. if needed
@@ -61,6 +60,28 @@ def tokenize_function(examples):
tokenized_test_prompts = tokenizer(_TEST_PROMPTS, truncation=True, max_length=512)
simple_batch_inputs = list(tokenized_test_prompts["input_ids"])

model.config.attn_implementation = "sdpa"
start_time_simple = time.time()
outputs = model.generate(
    input_ids=tokenizer(
        _TEST_PROMPTS, truncation=True, max_length=512, padding=True, return_tensors="pt"
    ).input_ids.to(model.device),
    generation_config=GenerationConfig(
        max_new_tokens=25, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id
    ),
    # You can pass request-specific overrides here, e.g., max_new_tokens=100
)
end_time_simple = time.time()
print(f"\nSimple batch generation took: {end_time_simple - start_time_simple:.2f} seconds")

print("\nResults from simple generate_batch:")
for i, request in enumerate(outputs):
output_text = tokenizer.decode(request, skip_special_tokens=False)
print("-" * 20)
print(f"Result for Request {request}:")
print(f" Output: {output_text}")
print("-" * 20)
print("--- Finished Simple Batch Generation Example ---\n\n")

# --- Example 1: Simple Version using generate_batch ---
print("--- Running Simple Batch Generation Example ---")
@@ -81,157 +102,12 @@ def tokenize_function(examples):

# Decode and print results
print("\nResults from simple generate_batch:")
for i, output_ids in enumerate(batch_outputs):
    input_text = tokenizer.decode(simple_batch_inputs[i], skip_special_tokens=False)
    output_text = tokenizer.decode(output_ids, skip_special_tokens=False)
for request in batch_outputs:
    input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
    output_text = tokenizer.decode(batch_outputs[request].static_outputs, skip_special_tokens=False)
    print("-" * 20)
    print(f"Result for Request {i}:")
    # print(f" Input: {input_text}")
    print(f"Result for Request {request}:")
    print(f" Input: {input_text}")
    print(f" Output: {output_text}")
    print("-" * 20)
print("--- Finished Simple Batch Generation Example ---\n\n")

outputs = []

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3b-Instruct", attn_implementation="sdpa", torch_dtype=torch.float16, device_map="auto"
)

print("--- Running Simple Generation for comparison ---")
# tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
# simple_inputs = [torch.tensor(item["input_ids"], device=device) for item in tokenized_test_prompts]
tokenized_test_prompts = tokenizer(_TEST_PROMPTS, truncation=True, max_length=512)
simple_inputs = [torch.tensor(item, device=device) for item in tokenized_test_prompts["input_ids"]]


padded_inputs = tokenizer.pad({"input_ids": list(tokenized_test_prompts["input_ids"])}, return_tensors="pt").to(device)


start_time_simple = time.time()
outputs = model.generate(
    **padded_inputs,
    generation_config=generation_config,
    do_sample=False,
)
end_time_simple = time.time()

print(f"generation config: {generation_config}")

print(f"\nSimple generation took: {end_time_simple - start_time_simple:.2f} seconds")

print("\nResults from simple generate:")
for i, output_ids in enumerate(outputs):
    input_text = tokenizer.decode(simple_inputs[i], skip_special_tokens=False)
    # The output_ids from batch generation include the input tokens, skip them for decoding
    # We need to know the length of the input to slice the output correctly
    input_length = len(simple_inputs[i])
    # Slice the output ids to get only the generated part
    generated_ids = output_ids[input_length:]
    output_text = tokenizer.decode(generated_ids, skip_special_tokens=False)
    print("-" * 20)
    print(f"Result for Request {i}:")
    # print(f" Input: {input_text}")
    print(f" Output: {output_text}")
    print("-" * 20)

print("--- Finished Simple Generation Example ---\n\n")

# --- Example 2: Streaming Version using ContinuousBatchingManager ---
print("--- Running Streaming Continuous Batching Example ---")

manager = model.init_continuous_batching(generation_config=generation_config, streaming=True)

manager.start()

# Doing it here with one request to avoid interleaving outputs
req_id = manager.add_request(simple_batch_inputs[0])

request_streams = {}

first_token = True
for output in manager:
    req_id = output["request_id"]
    if req_id is None:
        continue
    if first_token:
        print(f"Request {req_id} started")
        first_token = False
    if req_id not in request_streams:
        request_streams[req_id] = DecodeStream(skip_special_tokens=False)
    next_token = request_streams[req_id].step(tokenizer._tokenizer, output["next_token"])
    print(f"{next_token}", end="")
    if output["status"] in ["finished", "failed"]:
        print(f"\nRequest {req_id} {output['status']}")
        del request_streams[req_id]
        break

manager.stop(block=True, timeout=10)

# --- Example 3: Involved Performant Version using ContinuousBatchingManager ---
print("--- Running Involved Continuous Batching Example ---")

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3b-Instruct", attn_implementation="sdpa", torch_dtype=torch.float16, device_map="auto"
)

# Prepare data for the involved example (using a larger dataset)
involved_dataset = datasets.load_dataset("imdb", split="test")
involved_dataset = involved_dataset.select(range(100)) # Use 100 examples
tokenized_involved_datasets = involved_dataset.map(tokenize_function, batched=True)
# Extract input_ids
requests_data = [{"input_ids": item["input_ids"]} for item in tokenized_involved_datasets]

# 1. Initialize the manager
manager = model.init_continuous_batching(generation_config=generation_config)

# Optional: Provide initial shapes to help cache calculation (if using many similar length prompts)
# manager.add_initial_prompts([req['input_ids'] for req in requests_data[:5]])

# 2. Start the background generation thread
manager.start()

submitted_requests = {}
results = {}
start_time_involved = time.time() # Start timing before submission

for i, req_data in enumerate(requests_data):
    try:
        req_id = manager.add_request(req_data["input_ids"], request_id=f"req_{i}")
        submitted_requests[req_id] = {"input": tokenizer.decode(req_data["input_ids"])}
        print(f"Submitted request {req_id}")
    except Exception as e:
        print(f"Failed to submit request {i}: {e}")


# 3. Retrieve results
finished_count = 0
while finished_count < len(submitted_requests) and manager.is_running():
    result = manager.get_result(timeout=1.0)
    if result:
        req_id = result["request_id"]
        finished_count += 1
        results[req_id] = result
        output_text = tokenizer.decode(result["output_ids"], skip_special_tokens=True)
        print("-" * 20)
        print(f"Result for {req_id} (Status: {result['status']}):")
        # print(f" Input: {submitted_requests[req_id]['input'][:100]}...") # Optional: print input
        print(f" Output: {output_text}")
        print("-" * 20)

end_time_involved = time.time() # End timing after retrieval loop
print(
    f"\nInvolved continuous batching took: {end_time_involved - start_time_involved:.2f} seconds (includes submission delays)"
)

print(f"Total submitted: {len(submitted_requests)}")
print(f"Total results received: {len(results)}")

print("Stopping the manager...")
manager.stop(block=True, timeout=10)

print("Manager stopped.")

# You can now process the `results` dictionary which contains
# {"request_id": ..., "output_ids": ..., "status": ...} for each finished request.

print("--- Finished Advanced Continuous Batching Example ---")
3 changes: 3 additions & 0 deletions setup.py
@@ -198,6 +198,9 @@
"pytest-rich",
"libcst",
"rich",
"opentelemetry-api",
"opentelemetry-sdk",
"opentelemetry-exporter-otlp",
]


3 changes: 3 additions & 0 deletions src/transformers/dependency_versions_table.py
@@ -103,4 +103,7 @@
"pytest-rich": "pytest-rich",
"libcst": "libcst",
"rich": "rich",
"opentelemetry-api": "opentelemetry-api",
"opentelemetry-sdk": "opentelemetry-sdk",
"opentelemetry-exporter-otlp": "opentelemetry-exporter-otlp",
}
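
The new opentelemetry-api, opentelemetry-sdk, and opentelemetry-exporter-otlp entries suggest the continuous-batching work is instrumented with OpenTelemetry tracing. How transformers wires the tracer internally is not shown in this diff; the snippet below is only a generic sketch of configuring these three packages against an OTLP endpoint, and the endpoint and service name are placeholders.

from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

# Export spans to a local OTLP collector (placeholder endpoint and service name).
provider = TracerProvider(resource=Resource.create({"service.name": "continuous-batching-demo"}))
provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter(endpoint="http://localhost:4317", insecure=True)))
trace.set_tracer_provider(provider)

tracer = trace.get_tracer("continuous_batching.example")
with tracer.start_as_current_span("generate_batch"):
    pass  # run generation here; spans opened inside this block are exported via OTLP
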
6 changes: 4 additions & 2 deletions src/transformers/generation/__init__.py
@@ -97,7 +97,9 @@
"validate_stopping_criteria",
"StopStringCriteria",
]
_import_structure["continuous_batching"] = [ "ContinuousMixin",]
_import_structure["continuous_batching"] = [
"ContinuousMixin",
]
_import_structure["utils"] = [
"GenerationMixin",
"GreedySearchEncoderDecoderOutput",
@@ -214,6 +216,7 @@
EarlyExitCandidateGenerator,
PromptLookupCandidateGenerator,
)
from .continuous_batching import ContinuousMixin
from .logits_process import (
AlternatingCodebooksLogitsProcessor,
ClassifierFreeGuidanceLogitsProcessor,
@@ -275,7 +278,6 @@
SampleDecoderOnlyOutput,
SampleEncoderDecoderOutput,
)
from .continuous_batching import ContinuousMixin
from .watermarking import (
BayesianDetectorConfig,
BayesianDetectorModel,
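
The __init__.py change registers ContinuousMixin in _import_structure and moves its TYPE_CHECKING import next to the other generation imports. As context, here is a simplified sketch of the lazy-import pattern these transformers __init__.py files follow; the module layout below is illustrative, not the full file.

import sys
from typing import TYPE_CHECKING

from ..utils import _LazyModule

# Names registered here are importable at runtime but only loaded on first access.
_import_structure = {
    "continuous_batching": ["ContinuousMixin"],
}

if TYPE_CHECKING:
    # Static type checkers see the real imports.
    from .continuous_batching import ContinuousMixin
else:
    # At runtime the module is replaced by a lazy proxy that resolves names on demand.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
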