Training Improvements: MultipackV2, Statistics, Mock Data #483

Draft · wants to merge 4 commits into base: main
7 changes: 7 additions & 0 deletions src/instructlab/training/config.py
@@ -185,6 +185,7 @@ class TrainingArgs(BaseModel):

mock_data: Optional[bool] = False
mock_data_len: int = 0
mock_num_samples: int = 0

deepspeed_options: DeepSpeedOptions = Field(
default_factory=lambda: DeepSpeedOptions(
@@ -228,3 +229,9 @@ class TrainingArgs(BaseModel):
default=False,
description="Whether to use Liger kernels for training.",
)

# TODO(osilkin): Create a better API for this, should not merge into library this way
Contributor (review comment):
I know Fynn has been working on "SDK-ifying" the sampler specifically. @RobotSail, maybe we should sync on this with the training team.

use_multipack_v2: bool = Field(
default=False,
description="Use the MultipackV2 sampler which balances batches based on computational cost. Does not support Padding transformers.",
)
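
For illustration only (not part of this PR): roughly how the new fields added above could be toggled on an existing TrainingArgs instance. The pre-built train_args value and the pydantic model_copy(update=...) call are assumptions about the surrounding setup; the values mirror the argparse defaults further down in this diff.

# sketch: enabling the new mock-data and MultipackV2 options on an existing
# TrainingArgs instance (train_args is assumed to be constructed elsewhere)
train_args = train_args.model_copy(
    update={
        "mock_data": True,
        "mock_data_len": 2600,        # sequence length of each mocked sample
        "mock_num_samples": 92_000,   # number of mocked samples to generate
        "use_multipack_v2": True,     # opt in to the cost-balanced sampler
    }
)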
106 changes: 88 additions & 18 deletions src/instructlab/training/main_ds.py
@@ -57,6 +57,9 @@
from instructlab.training.multipack_sampler import (
find_packing_max_batch_len_and_grad_accum,
)
from instructlab.training.multipack_sampler_v2 import (
find_packing_max_batch_len_and_grad_accum as find_packing_max_batch_len_and_grad_accum_v2,
)
from instructlab.training.setup_accelerator import setup_accelerator
from instructlab.training.token_dataset import setup_dataloader, setup_dataset
from instructlab.training.tokenizer_utils import setup_tokenizer
@@ -379,9 +382,15 @@ def train(
else None
)

# variables for tracking statistics
global_grad_norm = None
stats_momentum = 0.999  # per-step weight of 1/1000, i.e. an effective averaging window of ~1000 steps
avg_throughput = 0.0
avg_time_per_step = 0.0
num_batches = None

for epoch in range(args.current_epoch, args.num_epochs):
if args.sampler in ("multipack"):
if args.sampler in ("multipack", "multipack_v2"):
train_loader.batch_sampler.set_epoch(epoch)
elif args.sampler in ("distributed"):
train_loader.sampler.set_epoch(epoch)
@@ -393,6 +402,9 @@ def train(

# blast through the batches in the train loader up to the last step within the epoch.
for batch in train_loader:
if not num_batches:
num_batches = len(train_loader)

if global_step <= args.last_step:
# in the case of resuming, last_step > 0
global_step += 1
@@ -445,6 +457,28 @@
if local_rank == 0:
elapsed_time = time.time() - start
overall_throughput = args.samples_per_gpu * world_size / elapsed_time

# moving averages
avg_throughput = (
stats_momentum * avg_throughput
+ (1 - stats_momentum) * overall_throughput
)
avg_time_per_step = (
stats_momentum * avg_time_per_step
+ (1 - stats_momentum) * elapsed_time
)

# bias correction so the initial values don't tend towards 0
corrected_avg_throughput = avg_throughput / (
1 - (stats_momentum**global_step)
)
corrected_avg_time_per_step = avg_time_per_step / (
1 - (stats_momentum**global_step)
)

# now we can estimate the epoch length
length_per_epoch = corrected_avg_time_per_step * num_batches

current_lr = lr_scheduler.get_last_lr()[0]
cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
@@ -468,6 +502,13 @@
"step": global_step,
"rank": torch.distributed.get_rank(),
"overall_throughput": overall_throughput,
"avg_overall_throughput": avg_throughput,
"corrected_avg_overall_throughput": corrected_avg_throughput,
"elapsed_time": elapsed_time,
"avg_elapsed_time": avg_time_per_step,
"corrected_avg_elapsed_time": corrected_avg_time_per_step,
"length_per_epoch": length_per_epoch / 3600,
"num_batches": num_batches,
"lr": current_lr,
"cuda_mem_allocated": cuda_mem_allocated,
"cuda_malloc_retries": cuda_malloc_retries,
@@ -580,19 +621,35 @@ def main(args):
args.data_path,
mock=args.mock_data,
mock_len=args.mock_len,
mock_num_samples=args.mock_num_samples,
)

try:
packing_max_batch_len, grad_accum = find_packing_max_batch_len_and_grad_accum(
num_gpus=torch.distributed.get_world_size(),
avg_sample_len=dataset.get_lengths().mean(),
effective_batch_size=args.effective_batch_size,
max_batch_len_per_gpu=args.max_batch_len,
is_padding=not (args.use_dolomite or flash_enabled),
dataset=dataset,
seed=args.seed,
)
args.sampler = "multipack"
if args.use_multipack_v2:
packing_max_batch_len, grad_accum = (
find_packing_max_batch_len_and_grad_accum_v2(
num_gpus=torch.distributed.get_world_size(),
avg_sample_len=dataset.get_lengths().mean(),
effective_batch_size=args.effective_batch_size,
max_batch_len_per_gpu=args.max_batch_len,
dataset=dataset,
seed=args.seed,
)
)
args.sampler = "multipack_v2"
else:
packing_max_batch_len, grad_accum = (
find_packing_max_batch_len_and_grad_accum(
num_gpus=torch.distributed.get_world_size(),
avg_sample_len=dataset.get_lengths().mean(),
effective_batch_size=args.effective_batch_size,
max_batch_len_per_gpu=args.max_batch_len,
is_padding=not (args.use_dolomite or flash_enabled),
dataset=dataset,
seed=args.seed,
)
)
args.sampler = "multipack"
except RuntimeError as e:
if os.environ["LOCAL_RANK"] == "0":
print(f"\033[38;5;120m{e}\033[0m")
@@ -640,6 +697,11 @@ def main(args):
seed=args.seed,
)

assert (
not args.use_multipack_v2
or args.sampler == "multipack_v2"
), "multipack_v2 was enabled but is not selected"

if args.local_rank == 0:
metric_logger.log_sync(
{
@@ -683,6 +745,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
"""
Wrapper around the main training job that calls torchrun.
"""
# TODO(osilkin): add a check here for MultipackV2 combined with padding transformers
check_valid_train_args(train_args)

# switch out generic tmpl for legacy tmpl if requested
@@ -746,8 +809,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:

if train_args.mock_data:
command.append("--mock_data")
if train_args.mock_len:
command.append(f"--mock_len={train_args.mock_len}")
if train_args.mock_data_len:
command.append(f"--mock_len={train_args.mock_data_len}")
if train_args.mock_num_samples:
command.append(f"--mock_num_samples={train_args.mock_num_samples}")

if train_args.use_dolomite:
command.append("--use_dolomite")
@@ -805,11 +870,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:

# FSDP Options
if train_args.fsdp_options.cpu_offload_params:
command.extend(
[
"--cpu_offload_params_fsdp",
]
)
command.append("--cpu_offload_params_fsdp")

if train_args.use_multipack_v2:
command += ["--use_multipack_v2"]

# specify the sharding strategy
command.append(
@@ -933,6 +997,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--mock_data", action="store_true")
parser.add_argument("--mock_len", type=int, default=2600)
parser.add_argument("--mock_num_samples", type=int, default=92_000)
parser.add_argument(
"--distributed_training_framework",
type=str,
@@ -1008,6 +1073,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
action="store_true",
help="Use Liger kernels for training.",
)
parser.add_argument(
"--use_multipack_v2",
action="store_true",
help="Use the MultipackV2 algorithm for packing batches. This is more optimal but does not support Transformers which require Padding.",
)
args = parser.parse_args()
set_random_seed(args.seed)
main(args)
1 change: 1 addition & 0 deletions src/instructlab/training/multipack_sampler.py
@@ -413,6 +413,7 @@ def generate_batches(self, set_stats=False):

return batches

# TODO(osilkin): cache the length here
def __iter__(self):
batches = self.generate_batches(set_stats=True)
return iter(batches)
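
Regarding the TODO above, one possible shape for the caching (a sketch, not part of this PR; the _cached_num_batches attribute and the __len__ override are assumptions):

# sketch: remember how many batches generate_batches() produced so that
# callers asking for the sampler's length don't trigger a second packing pass
def __iter__(self):
    batches = self.generate_batches(set_stats=True)
    self._cached_num_batches = len(batches)  # hypothetical cache attribute
    return iter(batches)

def __len__(self):
    if getattr(self, "_cached_num_batches", None) is None:
        self._cached_num_batches = len(self.generate_batches(set_stats=False))
    return self._cached_num_batches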