
Commit 137d0c4
Merge pull request EleutherAI#579 from SONG-WONHO/master
Add - 4bit-related args
haileyschoelkopf authored Jun 14, 2023
2 parents 9d06c95 + 990bc54 commit 137d0c4
Showing 1 changed file with 16 additions and 0 deletions.
lm_eval/models/huggingface.py: 16 additions & 0 deletions
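For context: lm_eval/models/huggingface.py backs the harness's Hugging Face model types, and key=value pairs in --model_args are forwarded to the __init__ below, so the two new arguments become settable from the command line. A minimal usage sketch, assuming the v0.3-era entry point (main.py) and the hf-causal-experimental model type this file implements; the checkpoint name is a placeholder:

    python main.py \
        --model hf-causal-experimental \
        --model_args pretrained=huggyllama/llama-7b,load_in_4bit=True,bnb_4bit_quant_type=nf4,bnb_4bit_compute_dtype=bfloat16 \
        --tasks hellaswag

Note that bnb_4bit_compute_dtype is passed as a string name (bfloat16), since the loader resolves it with getattr(torch, ...) in the last hunk below.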
@@ -91,6 +91,8 @@ def __init__(
         load_in_4bit: Optional[bool] = False,
         trust_remote_code: Optional[bool] = False,
         gptq_use_triton: Optional[bool] = False,
+        bnb_4bit_quant_type: Optional[str] = None,
+        bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
     ):
         """Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation.
         Args:
@@ -152,6 +154,13 @@ def __init__(
                 If True, will trust the remote code when loading the model.
             gptq_use_triton (bool, optional, defaults to False):
                 Use Triton for GPTQ inference.
+            bnb_4bit_quant_type (str, optional, defaults to None):
+                The quantization type to use for BnB 4bit quantization. See:
+                https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py#L77
+            bnb_4bit_compute_dtype (Union[str, torch.dtype], optional, defaults to None):
+                The compute dtype to use for BnB 4bit quantization. See:
+                https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py#L74
         """
         super().__init__()

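The two options documented above mirror fields of transformers.BitsAndBytesConfig, introduced in transformers 4.30. For reference, a minimal sketch of the equivalent direct transformers usage, assuming bitsandbytes is installed and a CUDA device is available; the checkpoint name is a placeholder:

    import torch
    import transformers

    # Mirrors bnb_4bit_quant_type="nf4" and bnb_4bit_compute_dtype="bfloat16":
    # weights are stored as 4-bit NF4, while matmuls run in bfloat16.
    quant_config = transformers.BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",  # "fp4" is the library default
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = transformers.AutoModelForCausalLM.from_pretrained(
        "huggyllama/llama-7b",  # placeholder checkpoint
        quantization_config=quant_config,
    )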
@@ -215,6 +224,8 @@ def __init__(
             gptq_use_triton=gptq_use_triton,
             load_in_8bit=load_in_8bit,
             load_in_4bit=load_in_4bit,
+            bnb_4bit_quant_type=bnb_4bit_quant_type,
+            bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
             **model_kwargs,
         )
         # note: peft_path can be different than pretrained model path
@@ -256,6 +267,8 @@ def _create_auto_model(
         trust_remote_code: Optional[bool] = False,
         torch_dtype: Optional[Union[str, torch.dtype]] = None,
         gptq_use_triton: Optional[bool] = False,
+        bnb_4bit_quant_type: Optional[str] = None,
+        bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
     ) -> transformers.AutoModel:
         """Returns a pre-trained pytorch model from a pre-trained model configuration."""
         if not quantized:
@@ -264,6 +277,9 @@ def _create_auto_model(
             model_kwargs = {}
             if transformers.__version__ >= "4.30.0":
                 model_kwargs["load_in_4bit"] = load_in_4bit
+                if load_in_4bit:
+                    model_kwargs["bnb_4bit_quant_type"] = bnb_4bit_quant_type
+                    model_kwargs["bnb_4bit_compute_dtype"] = getattr(torch, bnb_4bit_compute_dtype)
             model = self.AUTO_MODEL_CLASS.from_pretrained(
                 pretrained,
                 revision=revision + ("/" + subfolder if subfolder is not None else ""),
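Two hedged notes on this last hunk. First, getattr(torch, bnb_4bit_compute_dtype) assumes the value is a string name such as "bfloat16"; with the parameter's default of None (or an actual torch.dtype, which the type hint also permits) the call raises a TypeError, so 4-bit callers effectively must supply a string. Second, transformers.__version__ >= "4.30.0" is a lexicographic string comparison, which misorders releases such as "4.9.0". A defensive sketch of both points; resolve_compute_dtype is an illustrative helper, not part of this codebase:

    import torch
    import transformers
    from packaging import version

    def resolve_compute_dtype(dtype):
        # Illustrative helper: accept None, a torch.dtype, or a name like "bfloat16".
        if dtype is None or isinstance(dtype, torch.dtype):
            return dtype
        return getattr(torch, dtype)  # AttributeError for unknown names

    model_kwargs = {}
    load_in_4bit = True
    bnb_4bit_quant_type = "nf4"          # example values
    bnb_4bit_compute_dtype = "bfloat16"
    if version.parse(transformers.__version__) >= version.parse("4.30.0"):
        model_kwargs["load_in_4bit"] = load_in_4bit
        if load_in_4bit:
            model_kwargs["bnb_4bit_quant_type"] = bnb_4bit_quant_type
            dtype = resolve_compute_dtype(bnb_4bit_compute_dtype)
            if dtype is not None:
                model_kwargs["bnb_4bit_compute_dtype"] = dtype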
