Skip to content

Commit

Permalink
replace gpt2 by Mistral
Browse files Browse the repository at this point in the history
update version to 0.1.3
  • Loading branch information
Saibo-creator committed Feb 27, 2024
1 parent 75f15f6 commit 1f3e987
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 12 deletions.
18 changes: 13 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,26 @@ pip install git+https://github.com/epfl-dlab/transformers-CFG.git
The below example can be found in `examples/generate_json.py`

```python

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

if __name__ == "__main__":
# Detect if GPU is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model_id = "mistralai/Mistral-7B-v0.1"

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")

model = AutoModelForCausalLM.from_pretrained(model_id).to(
device
) # Load model to defined device
model.generation_config.pad_token_id = model.generation_config.eos_token_id

# Load json grammar
with open("examples/grammars/json.ebnf", "r") as file:
Expand All @@ -54,9 +64,7 @@ if __name__ == "__main__":

output = model.generate(
input_ids,
do_sample=False,
max_length=50,
num_beams=2,
logits_processor=[grammar_processor],
repetition_penalty=1.1,
num_return_sequences=1,
Expand Down
8 changes: 6 additions & 2 deletions examples/generate_cIE.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model_id = "mistralai/Mistral-7B-v0.1"

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2").to(

model = AutoModelForCausalLM.from_pretrained(model_id).to(
device
) # Load model to defined device
model.generation_config.pad_token_id = model.generation_config.eos_token_id

# Load grammar
with open("examples/grammars/cIE.ebnf", "r") as file:
Expand Down
8 changes: 6 additions & 2 deletions examples/generate_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model_id = "mistralai/Mistral-7B-v0.1"

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2").to(

model = AutoModelForCausalLM.from_pretrained(model_id).to(
device
) # Load model to defined device
model.generation_config.pad_token_id = model.generation_config.eos_token_id

# Load grammar
with open("examples/grammars/json.ebnf", "r") as file:
Expand Down
8 changes: 6 additions & 2 deletions examples/generate_json_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model_id = "mistralai/Mistral-7B-v0.1"

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2").to(

model = AutoModelForCausalLM.from_pretrained(model_id).to(
device
) # Load model to defined device
model.generation_config.pad_token_id = model.generation_config.eos_token_id

# Load grammar
with open("examples/grammars/json_arr.ebnf", "r") as file:
Expand Down
2 changes: 1 addition & 1 deletion transformers_cfg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

setup_logging()

__version__ = "0.1.2"
__version__ = "0.1.3"

0 comments on commit 1f3e987

Please sign in to comment.