Resolve Errors Preventing CLI from Running Correctly (#88)
* Check for 'bitsandbytes' library

Replaced `pass` with `import_module("bitsandbytes")` to actively check that the 'bitsandbytes' library is present. If it is not installed, an ImportError is now raised and handled by the surrounding exception block, which re-raises it with a clear install hint instead of letting the transformers library fail later with an internal error. A minimal standalone sketch of this check follows the bullet list below.

* Replace the old 'prefix_prompt' attribute with 'prompt' in the parsed args (see the usage sketch after the diff below)
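
A minimal standalone sketch of the dependency check described in the first bullet. It uses only the standard library `importlib` and reuses the error message from the diff below; the helper function name is illustrative, since the actual CLI performs this check inline inside `generate_text`:

from importlib import import_module

def require_bitsandbytes() -> None:
    # import_module raises ImportError if the package is missing, so the
    # except block can surface a clear install hint up front instead of
    # letting transformers fail later with an internal error.
    try:
        import_module("bitsandbytes")
    except ImportError:
        raise ImportError(
            "You need to install bitsandbytes to use 8-bit or 4-bit modes. "
            "Install it with `pip install bitsandbytes`."
        )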
gfesatidis authored Sep 8, 2024
1 parent 04d9b34 commit c77ae36
Showing 1 changed file with 5 additions and 4 deletions.
transformers_cfg/cli/cli_main.py: 5 additions & 4 deletions
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 
 import argparse
+from importlib import import_module
 from transformers_cfg.tokenization.utils import is_tokenizer_supported
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
@@ -108,7 +109,7 @@ def generate_text(args):
     # Load the model with bitsandbytes if 8bit or 4bit flag is set
     if args.use_8bit or args.use_4bit:
         try:
-            pass
+            import_module("bitsandbytes")
         except ImportError:
             raise ImportError(
                 "You need to install bitsandbytes to use 8-bit or 4-bit modes. Install it with `pip install bitsandbytes`."
@@ -130,7 +131,7 @@ def generate_text(args):
     model.generation_config.pad_token_id = tokenizer.pad_token_id
 
     inputs = tokenizer(
-        args.prefix_prompt, add_special_tokens=False, return_tensors="pt", padding=True
+        args.prompt, add_special_tokens=False, return_tensors="pt", padding=True
     )
     input_ids = inputs["input_ids"].to(args.device)
     attention_mask = inputs["attention_mask"].to(args.device)
@@ -160,10 +161,10 @@ def generate_text(args):
     )
 
     # print prompt first in color
-    print("\033[92m" + "Prompt:" + args.prefix_prompt + "\033[0m")
+    print("\033[92m" + "Prompt:" + args.prompt + "\033[0m")
 
     # Store results for optional file output
-    result = f"Prompt: {args.prefix_prompt}\n\n"
+    result = f"Prompt: {args.prompt}\n\n"
 
     # Generate without grammar constraints (if contrast mode is enabled)
     if not args.no_contrast_mode:
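
For context on the second change, a short sketch of how the renamed `args.prompt` value flows through tokenization and generation once a model is loaded. The tokenizer call mirrors the line changed in the diff above; the model id, the pad-token fallback, and the generation settings are assumptions made for the example and are not taken from the repository:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # placeholder; the CLI takes the model name from its arguments
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Fallback for models without a dedicated pad token (assumed; mirrors the
# pad_token_id assignment visible in the diff above).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = tokenizer.pad_token_id

prompt = "Generate a JSON object."  # stands in for args.prompt

# Same call as in the diff, now reading args.prompt instead of args.prefix_prompt.
inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt", padding=True)

output_ids = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=50,  # assumed value for the example
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))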