beam search doesn't work with transformers_cfg #9
Thanks for raising this issue; support for beam search is still in progress. The error message is below.
Here, I describe how to integrate support for beam search with grammar-constrained decoding, in case a volunteer wants to contribute :) At present, our library relies on a logits processor. While effective for various decoding/sampling methods, that approach doesn't suit constrained beam search. The incompatibility of the constrained logits processor with beam search is complex and relates to the mechanics of beam search itself; however, this detail is not central to this feature, as our focus is on employing the Constraint abstraction instead (for contrast, the logits-processor usage that works today is sketched right after the outline below). Credit goes to @chanwkimlab for developing constrained beam search and providing a robust abstraction along with a comprehensive blog post: https://huggingface.co/blog/constrained-beam-search

The procedure involves implementing a Constraint subclass along the following lines:
from typing import List
from transformers.generation.beam_constraints import Constraint

class GrammarConstraint(Constraint):
    def __init__(self, token_ids: List[int]):
        super(Constraint, self).__init__()
        ...

    def advance(self):
        ...

    def does_advance(self, token_id: int):
        ...

    def update(self, token_id: int):
        ...

    def reset(self):
        self.completed = False
        self.fulfilled_idx = 0

    def remaining(self):
        # For grammar-constrained decoding, determining the exact number of remaining
        # tokens may be challenging, but that should not pose a significant issue.
        # Example implementations of remaining() can be found in the existing
        # constraints in transformers (e.g. PhrasalConstraint).
        ...

That's it!
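For contrast, the logits-processor path that works today (with greedy decoding and sampling) looks roughly like this; a minimal sketch following the README-style usage, so exact module paths and argument names may differ across versions:

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Any grammar in the library's EBNF format works here; "root" is the start rule.
grammar_str = 'root ::= "yes" | "no"'
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

input_ids = tokenizer(["Answer yes or no:"], return_tensors="pt")["input_ids"]
# This works for greedy decoding and sampling, but not together with num_beams > 1.
output = model.generate(input_ids, logits_processor=[grammar_processor], max_new_tokens=5)
print(tokenizer.batch_decode(output, skip_special_tokens=True))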
Hello! Is this still an active issue, or has a workaround been found? I can take a shot at coding the GrammarConstraint class.
Hey, the feature is not yet done. Go ahead and I will be happy to merge it :)
Hello, unfortunately I couldn't make it work; this constraint feature lacks documentation and it's difficult to understand how it works behind the scenes. When coding, I tried to follow the same format as the constraints found in the transformers library (transformers version: 4.44.0). Here is my best attempt:

from transformers.generation.beam_constraints import Constraint
from transformers_cfg.grammar_utils import IncrementalTokenRecognizer, IncrementalGrammarConstraint
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
class GrammarConstraint(Constraint):
    def __init__(self, token_recognizer: IncrementalTokenRecognizer):
        super(Constraint, self).__init__()
        self.token_recognizer = token_recognizer
        self.current_state = self.token_recognizer.string_recognizer.get_initial_parsing_state()
        self.valid_tokens = self.token_recognizer.get_next_token_acceptance(self.current_state, device='cpu')
        self.completed = False
        self.seqlen = float('inf')
        self.tokens = []

    @property
    def text(self):
        return self.token_recognizer.tokenizer.decode(self.tokens)

    def advance(self):
        # Return the next set of tokens that would be accepted by the current grammar state
        if self.completed:
            return []
        acceptance = self.valid_tokens
        return acceptance.nonzero(as_tuple=False).squeeze(-1).tolist()

    def does_advance(self, token_id: int):
        # Check if the given token_id is accepted by the current grammar state
        acceptance = self.valid_tokens
        return bool(acceptance[token_id])

    def update(self, token_id: int):
        # Update the state with the given token_id and return the progress indicators
        if self.does_advance(token_id):
            new_state = self.token_recognizer._update_state_with_token_id(token_id, self.current_state)
            self.current_state = new_state
            stepped = True
            completed = not bool(new_state.stacks)  # If stacks are empty, the constraint is completed
            self.tokens.append(token_id)
            if not completed:
                self.valid_tokens = self.token_recognizer.get_next_token_acceptance(self.current_state, device='cpu')
            reset = False
        else:
            # The token_id was not accepted, reset the state
            self.reset()
            stepped = False
            completed = False
            reset = True
        self.completed = completed
        return stepped, completed, reset

    def reset(self):
        # Reset the state of this constraint to its initialization
        self.current_state = self.token_recognizer.string_recognizer.get_initial_parsing_state()
        self.valid_tokens = self.token_recognizer.get_next_token_acceptance(self.current_state, device='cpu')
        self.completed = False
        self.tokens.clear()

    def remaining(self):
        # Return the number of remaining steps; this is more complex for a grammar constraint
        # and might not be easily quantifiable. For simplicity, we return 1 if not completed.
        return 0 if self.completed else 1

    def copy(self, stateful=False):
        # Create a new instance of this constraint
        new_constraint = GrammarConstraint(self.token_recognizer)
        if stateful:
            new_constraint.current_state = self.current_state
            new_constraint.valid_tokens = self.valid_tokens
            new_constraint.completed = self.completed
            new_constraint.tokens = self.tokens.copy()
        return new_constraint
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
if __name__ == "__main__":
    # Detect if GPU is available, otherwise use CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model_id = "gpt2"

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    # Load the tuples grammar
    with open("tuples.ebnf", "r") as file:
        grammar_str = file.read()
    token_recognizer = IncrementalGrammarConstraint(grammar_str, "root", tokenizer, unicode=True)
    grammar = GrammarConstraint(token_recognizer)

    model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, torch_dtype=torch.bfloat16)  # Load model to defined device
    model.generation_config.pad_token_id = model.generation_config.eos_token_id

    # Generate
    prefix1 = "Tuples:"
    input_ids = tokenizer([prefix1], add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"].to(device)

    max_new_tokens = 50
    # grammar.seqlen = max_new_tokens
    output = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        constraints=[grammar],
        num_beams=3,
        do_sample=False,
    )

    # decode output
    generations = tokenizer.batch_decode(output, skip_special_tokens=True)
    print(generations)
Here is the content of tuples.ebnf:
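For readers unfamiliar with the format, a hypothetical grammar for comma-separated integer tuples, written in the GBNF-style EBNF that transformers_cfg accepts, could look like the following (this is an illustration only, not necessarily the exact file used above):

root ::= tuple (", " tuple)*
tuple ::= "(" number ", " number ")"
number ::= [0-9]+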
@HichemAK Thanks for your effort! After diving deeper into beam search, I found that the implementation of constrained beam search in HF is quite convoluted and too tightly coupled to the existing constraints, making it not general enough. Trying to implement it directly is indeed not the best way: the resulting code would be very ugly and inefficient. It might be better to avoid that approach and work directly with beam search, but that would require modifying the HF codebase. I’ve sketched out how I plan to implement this beam search; for those interested, feel free to check it out. I’ll likely start working on it myself in the next few days.
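To make the "work directly with beam search" idea concrete, here is a rough, simplified sketch of what grammar-aware beam search could look like (not the planned implementation): each beam carries its own parser state, and the grammar mask is applied per beam before candidates are ranked. The mask_fn and step_fn callables are hypothetical stand-ins for the incremental parser's API, and EOS handling, batching, and length penalties are omitted.

import torch

@torch.no_grad()
def grammar_beam_search(model, input_ids, mask_fn, step_fn, init_state,
                        num_beams=3, max_new_tokens=20):
    # Each beam is (tokens, cumulative log-prob, parser state); the parser state has
    # to travel with its beam because beams are reordered and pruned at every step.
    beams = [(input_ids[0], 0.0, init_state)]
    for _ in range(max_new_tokens):
        candidates = []
        for tokens, score, state in beams:
            logits = model(tokens.unsqueeze(0)).logits[0, -1]
            # Mask out tokens the grammar does not accept from this beam's state.
            logits = logits.masked_fill(~mask_fn(state), float("-inf"))
            logprobs = torch.log_softmax(logits, dim=-1)
            top = torch.topk(logprobs, num_beams)
            for lp, tok in zip(top.values.tolist(), top.indices.tolist()):
                if lp == float("-inf"):
                    continue
                new_tokens = torch.cat([tokens, tokens.new_tensor([tok])])
                candidates.append((new_tokens, score + lp, step_fn(state, tok)))
        if not candidates:
            break
        # Keep the best hypotheses; each keeps its own, already-advanced parser state.
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:num_beams]
    return beams[0][0]

The key difference from the logits-processor route is that the grammar state becomes part of each hypothesis, so reordering or pruning beams can never desynchronize the parser from the tokens that beam has actually generated.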