Add --max_batch_size and --batch_size auto:N
gakada committed Jun 11, 2023
1 parent b21c8f3 commit 8cec82b
Showing 5 changed files with 68 additions and 59 deletions.
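For reference, one way to exercise the new options from the command line might be the following (model and task names are illustrative, not part of this commit):

python main.py --model hf-causal --model_args pretrained=gpt2 --tasks lambada_openai --batch_size auto:4 --max_batch_size 256

With --batch_size auto:N the largest workable batch size is re-detected roughly N times over the course of the run instead of once, and --max_batch_size caps the starting point of that search (the default stays at 512).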
78 changes: 41 additions & 37 deletions lm_eval/base.py
@@ -119,6 +119,12 @@ def set_cache_hook(self, cache_hook):


class BaseLM(LM):
def __init__(self):
super().__init__()
self.batch_schedule = 1
self.batch_sizes = {}
self.max_batch_size = 512

@property
@abstractmethod
def eot_token_id(self):
@@ -167,6 +173,26 @@ def _model_call(self, inps):
"""
pass

def _detect_batch_size(self, requests=None, pos=0):
if requests:
_, context_enc, continuation_enc = requests[pos]
max_length = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
else:
max_length = self.max_length

# if OOM, then halves batch_size and tries again
@find_executable_batch_size(starting_batch_size=self.max_batch_size)
def forward_batch(batch_size):
test_batch = torch.ones((batch_size, max_length), device=self.device).long()
for _ in range(5):
_ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
return batch_size

batch_size = forward_batch()
utils.clear_torch_cache()

return batch_size
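
The decorator used here comes from accelerate (its direct import is dropped from huggingface.py below). A minimal standalone sketch of the same retry-and-halve pattern, with illustrative tensor shapes rather than harness code:

import torch
from accelerate import find_executable_batch_size

@find_executable_batch_size(starting_batch_size=512)
def probe(batch_size):
    # The decorator first calls probe() with starting_batch_size; if the body
    # raises a CUDA out-of-memory error, it halves batch_size and retries.
    dummy = torch.ones((batch_size, 2048), device="cuda").long()
    # In _detect_batch_size the body runs a model forward pass, which is what
    # actually probes the memory limit; this sketch only allocates a tensor.
    del dummy
    return batch_size

largest = probe()  # first batch size that completes without OOM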

# subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
# TODO: enforce this somehow

Expand Down Expand Up @@ -202,19 +228,7 @@ def loglikelihood_rolling(self, requests):
if self.batch_size == "auto":
# using rolling window with maximum context
print("Passed argument batch_size = auto. Detecting largest batch size")

@find_executable_batch_size(
starting_batch_size=512
) # if OOM, then halves batch_size and tries again
def forward_batch(batch_size):
test_batch = torch.ones(
(batch_size, self.max_length), device=self.device
).long()
for _ in range(5):
_ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
return batch_size

batch_size = forward_batch()
batch_size = self._detect_batch_size()
print(f"Determined Largest batch size: {batch_size}")
adaptive_batch_size = batch_size

@@ -267,34 +281,24 @@ def _collate(x):

re_ord = utils.Reorderer(requests, _collate)

reordered_requests = re_ord.get_reordered()
n_reordered_requests = len(reordered_requests)

# automatic (variable) batch size detection for vectorization
# pull longest context sample from request
if len(re_ord.get_reordered()) > 0:
_, context_enc, continuation_enc = re_ord.get_reordered()[0]
max_context = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
if (self.batch_size == 'auto'):

if override_bs is None:
print('Passed argument batch_size = auto. Detecting largest batch size')
@find_executable_batch_size(starting_batch_size=512) # if OOM, then halves batch_size and tries again
def forward_batch(batch_size):
test_batch = torch.ones((batch_size, max_context), device=self.device).long()
for _ in range(5):
out = F.log_softmax(self._model_call(test_batch), dim = -1).cpu()
return batch_size

batch_size = forward_batch()
print(f"Determined largest batch size: {batch_size}")
adaptive_batch_size = batch_size

else:
adaptive_batch_size = override_bs
else:
adaptive_batch_size = 0 if override_bs is None else override_bs
def _batch_scheduler(pos):
sched = pos // int(n_reordered_requests / self.batch_schedule)
if sched in self.batch_sizes:
return self.batch_sizes[sched]
print(f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size")
self.batch_sizes[sched] = self._detect_batch_size(reordered_requests, pos)
print(f"Determined largest batch size: {self.batch_sizes[sched]}")
return self.batch_sizes[sched]

for chunk in utils.chunks(
tqdm(re_ord.get_reordered(), disable=disable_tqdm),
self.batch_size if self.batch_size != "auto" else adaptive_batch_size,
tqdm(reordered_requests, disable=disable_tqdm),
n=self.batch_size if self.batch_size != "auto" else override_bs if override_bs is not None else 0,
fn=_batch_scheduler if self.batch_size == "auto" and n_reordered_requests > 0 else None,
):
inps = []
cont_toks_list = []
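To make the schedule above concrete (numbers illustrative): with 1,000 reordered requests and --batch_size auto:4, batch_schedule is 4 and sched = pos // 250, so a fresh batch size is detected the first time positions 0, 250, 500 and 750 are reached. Because the requests are length-sorted, later buckets hold shorter inputs and typically admit larger batches than the first detection would allow.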
8 changes: 6 additions & 2 deletions lm_eval/evaluator.py
@@ -16,6 +16,7 @@ def simple_evaluate(
tasks=[],
num_fewshot=0,
batch_size=None,
max_batch_size=None,
device=None,
no_cache=False,
limit=None,
@@ -37,8 +38,10 @@
List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
:param num_fewshot: int
Number of examples in few-shot context
:param batch_size: int, optional
:param batch_size: int or str, optional
Batch size for model
:param max_batch_size: int, optional
Maximal batch size to try with automatic batch size detection
:param device: str, optional
PyTorch device (e.g. "cpu" or "cuda:0") for running models
:param no_cache: bool
@@ -67,7 +70,7 @@
if model_args is None:
model_args = ""
lm = lm_eval.models.get_model(model).create_from_arg_string(
model_args, {"batch_size": batch_size, "device": device}
model_args, {"batch_size": batch_size, "max_batch_size": max_batch_size, "device": device}
)
else:
assert isinstance(model, lm_eval.base.LM)
@@ -106,6 +109,7 @@
"model_args": model_args,
"num_fewshot": num_fewshot,
"batch_size": batch_size,
"batch_sizes": list(lm.batch_sizes.values()),
"device": device,
"no_cache": no_cache,
"limit": limit,
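On the Python side the new keyword threads straight through simple_evaluate; a hedged sketch of a call (model and task names are illustrative, not taken from this commit):

from lm_eval import evaluator

results = evaluator.simple_evaluate(
    model="hf-causal",
    model_args="pretrained=gpt2",
    tasks=["lambada_openai"],
    batch_size="auto:2",
    max_batch_size=256,
)
# results["config"]["batch_sizes"] then lists each batch size the detector settled on.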
23 changes: 7 additions & 16 deletions lm_eval/models/huggingface.py
@@ -9,7 +9,6 @@
from tqdm import tqdm

from transformers import BatchEncoding
from accelerate import find_executable_batch_size

from lm_eval import utils
from lm_eval.base import BaseLM
@@ -76,6 +75,7 @@ def __init__(
subfolder: Optional[str] = None,
revision: Optional[str] = "main",
batch_size: Optional[Union[int, str]] = 1,
max_batch_size: Optional[int] = 512,
max_gen_toks: Optional[int] = 256,
max_length: Optional[int] = None,
add_special_tokens: Optional[bool] = None,
@@ -172,10 +172,13 @@ def __init__(
), "Evaluating causal models with `add_special_tokens=True` is currently not supported."

# setup for automatic batch size detection
if batch_size == "auto":
self._batch_size = batch_size
if str(batch_size).startswith("auto"):
batch_size = batch_size.split(":")
self._batch_size = batch_size[0]
self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
else:
self._batch_size = int(batch_size)
self.max_batch_size = max_batch_size

self._max_gen_toks = max_gen_toks
self._max_length = max_length
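
Under this parsing, batch_size="auto:4" yields self._batch_size == "auto" with self.batch_schedule == 4.0, a bare "auto" keeps the default schedule of 1, and a plain integer string still falls through to int(batch_size) as before.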
@@ -411,19 +414,7 @@ def _collate(x):
if self.batch_size == "auto":
# using rolling window with maximum context
print("Passed argument batch_size = auto. Detecting largest batch size")

@find_executable_batch_size(
starting_batch_size=512
) # if OOM, then halves batch_size and tries again
def forward_batch(batch_size):
test_batch = torch.ones(
(batch_size, self.max_length), device=self.device
).long()
for _ in range(5):
_ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
return batch_size

batch_size = forward_batch()
batch_size = self._detect_batch_size()
print(f"Determined Largest batch size: {batch_size}")
adaptive_batch_size = batch_size

12 changes: 9 additions & 3 deletions lm_eval/utils.py
@@ -8,6 +8,7 @@
import fnmatch
from typing import List, Union

import gc
import torch

from omegaconf import OmegaConf
@@ -64,11 +65,11 @@ def join_iters(iters):
yield from iter


def chunks(iter, n):
def chunks(iter, n=0, fn=None):
arr = []
for x in iter:
for i, x in enumerate(iter):
arr.append(x)
if len(arr) == n:
if len(arr) == (fn(i) if fn else n):
yield arr
arr = []
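
A tiny illustration of the new fn hook, with arbitrary values (the real caller is _batch_scheduler above, which passes the element index so the chunk size can change mid-stream):

from lm_eval.utils import chunks

# fn receives the index of the element just appended and returns the target chunk size
list(chunks(range(10), fn=lambda i: 2 if i < 4 else 3))
# -> [[0, 1], [2, 3], [4, 5, 6], [7, 8, 9]]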

Expand Down Expand Up @@ -283,3 +284,8 @@ def run_task_tests(task_list: List[str]):
raise ValueError(
f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
)


def clear_torch_cache():
gc.collect()
torch.cuda.empty_cache()
6 changes: 5 additions & 1 deletion main.py
@@ -16,6 +16,8 @@ def parse_args():
parser.add_argument("--provide_description", action="store_true")
parser.add_argument("--num_fewshot", type=int, default=0)
parser.add_argument("--batch_size", type=str, default=None)
parser.add_argument("--max_batch_size", type=int, default=None,
help="Maximal batch size to try with --batch_size auto")
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--output_path", default=None)
parser.add_argument("--limit", type=float, default=None,
Expand Down Expand Up @@ -60,6 +62,7 @@ def main():
tasks=task_names,
num_fewshot=args.num_fewshot,
batch_size=args.batch_size,
max_batch_size=args.max_batch_size,
device=args.device,
no_cache=args.no_cache,
limit=args.limit,
@@ -78,9 +81,10 @@
with open(args.output_path, "w") as f:
f.write(dumped)

batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
print(
f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
)
print(evaluator.make_table(results))

