From c884f134eeff4fa4a4704b6dc2b7ee8856f12101 Mon Sep 17 00:00:00 2001
From: Charlie Doern <cdoern@redhat.com>
Date: Tue, 27 Aug 2024 10:36:44 -0400
Subject: [PATCH] add support for CPU and MPS

Do not use distributed training when it is not available; use CPU or MPS instead.

This entails a few changes:

- `--device` is now a valid flag to the library, since `ilab` can pass cpu or mps, or default to cuda
- when using CPU or MPS, do not initialize DeepSpeed; instead, place the model on the device and initialize the `Adafactor` optimizer, which is more memory-efficient than the Adam-based ones
- inside of `train`, gate the distributed calls on torch.cuda.is_available() and torch.distributed.is_initialized(), since we do not use distributed torch on consumer systems
- the train loop needs custom step and loss logic for a LlamaForCausalLM model; add that in (sketched below)
- when using CPU or MPS, world_size is always 1 and local_rank is always 0
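
For reference, here is a condensed sketch of the non-distributed path this patch adds
(simplified, not verbatim: `args`, `model`, and `train_loader` stand in for the objects
built in main_ds.py, and logging, checkpointing, gradient accumulation, and the LR
scheduler are omitted):

    import torch
    from transformers import Adafactor

    # pick the device requested via --device, falling back to CPU
    if (
        args.device == "mps"
        and torch.backends.mps.is_available()
        and torch.backends.mps.is_built()
    ):
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # no DeepSpeed engine here: place the model on the device and use Adafactor,
    # which needs far less memory than AdamW on consumer machines
    model = model.to(device)
    optimizer = Adafactor(
        model.parameters(), lr=1e-5, scale_parameter=True, relative_step=False
    )
    model.gradient_checkpointing_enable()

    # single-device training step: plain zero_grad / backward / step instead of
    # model_engine.backward(loss) and model_engine.step()
    model.train()
    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        loss = model(**batch).loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()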

Signed-off-by: Charlie Doern <cdoern@redhat.com>
---
 src/instructlab/training/__init__.py          |   6 +-
 src/instructlab/training/main_ds.py           | 251 ++++++++++++------
 src/instructlab/training/multipack_sampler.py |  10 +-
 src/instructlab/training/token_dataset.py     |   2 +
 src/instructlab/training/utils.py             |  36 +--
 5 files changed, 209 insertions(+), 96 deletions(-)

diff --git a/src/instructlab/training/__init__.py b/src/instructlab/training/__init__.py
index 151989a6..3537ba87 100644
--- a/src/instructlab/training/__init__.py
+++ b/src/instructlab/training/__init__.py
@@ -22,9 +22,11 @@
 
 
 # defer import of main_ds
-def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
+def run_training(
+    torch_args: TorchrunArgs, train_args: TrainingArgs, device: str = "cuda"
+) -> None:
     """Wrapper around the main training job that calls torchrun."""
     # Local
     from .main_ds import run_training
 
-    return run_training(torch_args=torch_args, train_args=train_args)
+    return run_training(torch_args=torch_args, train_args=train_args, device=device)
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index 07f33728..f47bfd16 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -16,11 +16,18 @@
 
 # pylint: disable=no-name-in-module
 from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
+from torch import nn
 from torch.distributed import ReduceOp, all_reduce
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, get_scheduler
+from transformers import (
+    Adafactor,
+    AutoModelForCausalLM,
+    LlamaForCausalLM,
+    get_scheduler,
+)
 import deepspeed
 import torch
+import torch.distributed
 
 # First Party
 from instructlab.training import config
@@ -83,7 +90,7 @@ def get_ds_config(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOption
     return ds_config
 
 
-def setup_model(args, tokenizer, train_loader, grad_accum):
+def setup_model(args, tokenizer, train_loader, grad_accum, device):
     bnb_config = None
     if args.lora_r > 0 and args.lora_quant_bits == 4:
         # Third Party
@@ -230,45 +237,56 @@ def make_inputs_require_grad(module, input, output):
             model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
 
     # need to use this only when the CPU offload optimizer is enabled
-    if args.cpu_offload_optimizer:
-        print(
-            "\033[33m!!! CPU offload optimizer enabled, using DeepSpeedCPUAdam !!!\033[0m"
-        )
-        optimizer = DeepSpeedCPUAdam(
-            model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95)
-        )
-    else:
-        optimizer = FusedAdam(
-            model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95)
-        )
+    optimizer = None
+    if device.type == "cuda":
+        if args.cpu_offload_optimizer:
+            print(
+                "\033[33m!!! CPU offload optimizer enabled, using DeepSpeedCPUAdam !!!\033[0m"
+            )
+            optimizer = DeepSpeedCPUAdam(
+                model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95)
+            )
+        else:
+            optimizer = FusedAdam(
+                model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95)
+            )
 
-    lr_scheduler = get_scheduler(
-        name=args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=args.num_warmup_steps,
-        num_training_steps=args.num_epochs * len(train_loader) // grad_accum,
-    )
+        lr_scheduler = get_scheduler(
+            name=args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.num_warmup_steps,
+            num_training_steps=args.num_epochs * len(train_loader) // grad_accum,
+        )
 
     # pylint: disable=unbalanced-tuple-unpacking
-    model, _, _, lr_scheduler = deepspeed.initialize(
-        model=model,
-        optimizer=optimizer,
-        config=get_ds_config(
-            world_size=torch.distributed.get_world_size(),
-            samples_per_gpu=args.samples_per_gpu,
-            grad_accum=grad_accum,
-            opts=DeepSpeedOptions(
-                cpu_offload_optimizer=args.cpu_offload_optimizer,
-                cpu_offload_optimizer_ratio=args.cpu_offload_optimizer_ratio,
-                cpu_offload_optimizer_pin_memory=args.cpu_offload_optimizer_pin_memory,
-                save_samples=args.save_samples_ds,
+    if device.type == "cuda":
+        model, _, _, lr_scheduler = deepspeed.initialize(
+            model=model,
+            optimizer=optimizer,
+            config=get_ds_config(
+                world_size=torch.distributed.get_world_size(),
+                samples_per_gpu=args.samples_per_gpu,
+                grad_accum=grad_accum,
+                opts=DeepSpeedOptions(
+                    cpu_offload_optimizer=args.cpu_offload_optimizer,
+                    cpu_offload_optimizer_ratio=args.cpu_offload_optimizer_ratio,
+                    cpu_offload_optimizer_pin_memory=args.cpu_offload_optimizer_pin_memory,
+                    save_samples=args.save_samples_ds,
+                ),
             ),
-        ),
-        lr_scheduler=lr_scheduler,
-        dist_init_required=True,
-    )
-    # model = torch.compile(model)
-    return model
+            lr_scheduler=lr_scheduler,
+            dist_init_required=True,
+        )
+    else:
+        # If we are using CPU or MPS, just place the model on that device.
+        # Also initialize Adafactor, a Transformers optimizer designed to use fewer resources;
+        # if we used AdamW here, most people would run out of RAM.
+        model = model.to(device)
+        optimizer = Adafactor(
+            model.parameters(), lr=1e-5, scale_parameter=True, relative_step=False
+        )
+        model.gradient_checkpointing_enable()
+    return model, optimizer
 
 
 # this function is to check if the checkpoint provided can be resumed
@@ -331,7 +349,9 @@ def maybe_resume_training(args, model):
     return model
 
 
-def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
+def train(
+    args, model, tokenizer, train_loader, grad_accum, metric_logger, device, optimizer
+):
     model.train()
 
     global_step = 1
@@ -359,7 +379,8 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
         )
 
     for epoch in range(args.num_epochs):
-        torch.distributed.barrier()
+        if torch.cuda.is_available():
+            torch.distributed.barrier()
         if args.sampler in ("multipack"):
             train_loader.batch_sampler.set_epoch(epoch)
         elif args.sampler in ("distributed"):
@@ -370,7 +391,10 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
         if local_rank == 0:
             inner_pb = tqdm(range(len(train_loader)), desc=f"Epoch {epoch}")
 
-        aggregated_values = torch.zeros(3, dtype=torch.float32).to(local_rank)
+        if not torch.cuda.is_available():
+            aggregated_values = torch.zeros(3, dtype=torch.float16).to(device)
+        else:
+            aggregated_values = torch.zeros(3, dtype=torch.float32).to(local_rank)
         for batch in train_loader:
             if global_step <= args.last_step:
                 # in the case of resuming, last_step > 0
@@ -384,7 +408,10 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
             aggregated_values[1] = len(batch["input_ids"])
             if not args.is_granite:
                 for k in batch:
-                    batch[k] = batch[k].to(local_rank)
+                    if torch.cuda.is_available():
+                        batch[k] = batch[k].to(local_rank)
+                    else:
+                        batch[k] = batch[k].to(device=device)
 
             output = model(
                 **batch,
@@ -394,7 +421,8 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
 
             aggregated_values[2] = loss.item()
 
-            all_reduce(aggregated_values, op=ReduceOp.SUM)
+            if torch.cuda.is_available() and torch.distributed.is_initialized():
+                all_reduce(aggregated_values, op=ReduceOp.SUM)
 
             num_loss_counted_tokens = aggregated_values[0]
             loss = (
@@ -404,32 +432,65 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
             print(
                 f"\033[93mPer-token loss scaled by world size: {(loss/num_loss_counted_tokens) * world_size}\033[0m"
             )
-            print(
-                f"Epoch: {epoch}, Step: {global_step}, Rank: {torch.distributed.get_rank()}, loss = {loss}"
-            )
-
-            model.backward(loss)
-            model.step()
+            if torch.cuda.is_available():
+                rank = torch.distributed.get_rank()
+            else:
+                rank = 0
+            print(f"Epoch: {epoch}, Step: {global_step}, Rank: {rank}, loss = {loss}")
+
+            # If using a plain LlamaForCausalLM model on a single device (CPU or MPS), we cannot use the DeepSpeed .backward and .step from the model_engine.
+            # Instead, call the Adafactor optimizer's zero_grad, then loss.backward(), and step the optimizer directly.
+            if torch.cuda.is_available():
+                model.backward(loss)
+                model.step()
+            else:
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
 
             if local_rank == 0:
                 elapsed_time = time.time() - start
                 overall_throughput = args.samples_per_gpu * world_size / elapsed_time
-                current_lr = model.lr_scheduler.get_last_lr()[0]
-                cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
-                cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
-                global_grad_norm = model.get_global_grad_norm()
+                cuda_malloc_retries = 0
+                cuda_mem_allocated = 0
+                if torch.cuda.is_available():
+                    cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
+                    cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+                norm = None
+                if not isinstance(model, LlamaForCausalLM):
+                    global_grad_norm = model.get_global_grad_norm()
+                    norm = model.optimizer.single_partition_of_fp32_groups[0].norm()
+                    current_lr = model.lr_scheduler.get_last_lr()[0]
+                else:
+                    global_grad_norm = nn.utils.clip_grad_norm_(
+                        model.parameters(), max_norm=float("inf")
+                    )
+                    lr_scheduler = torch.optim.lr_scheduler.StepLR(
+                        optimizer, step_size=10, gamma=0.1
+                    )
+                    fp32_params = [
+                        param.data
+                        for param in model.parameters()
+                        if param.requires_grad
+                    ]
+                    norm = torch.norm(fp32_params[0])
+                    # for name, param in model.named_parameters():
+                    #    if param.requires_grad:
+                    #        fp32_weights = param.data
+                    #        fp32_norm = torch.norm(fp32_weights)
+                    #        print(f"Norm of {name}: {fp32_norm.item()}")
+                    current_lr = lr_scheduler.get_last_lr()[0]
                 global_grad_norm = (
                     float(global_grad_norm) if global_grad_norm is not None else None
                 )
-                weight_norm = float(
-                    model.optimizer.single_partition_of_fp32_groups[0].norm()
-                )
+
+                weight_norm = float(norm)
 
                 metric_logger.log_sync(
                     {
                         "epoch": epoch,
                         "step": global_step,
-                        "rank": torch.distributed.get_rank(),
+                        "rank": rank,
                         "loss": loss.item(),
                         "overall_throughput": overall_throughput,
                         "lr": current_lr,
@@ -470,7 +531,8 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
             global_step += 1
             if local_rank == 0:
                 inner_pb.update(1)
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
 
         if args.checkpoint_at_epoch:
             save_hf_format_ds(
@@ -507,13 +569,30 @@ def main(args):
     # device = torch.device("cuda", args.local_rank)
 
     #### distributed init #####
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
-    args.local_rank = int(os.environ["LOCAL_RANK"])
-    deepspeed.init_distributed(timeout=timedelta(minutes=30))
-    args.global_rank = torch.distributed.get_rank()
-    tensor = torch.ByteTensor([False]).cuda()
-    torch.distributed.all_reduce(tensor)
-    torch.distributed.barrier()
+    world_size = 1
+    device = multiprocessing = None
+    if not torch.cuda.is_available():
+        if (
+            args.device == "mps"
+            and torch.backends.mps.is_available()
+            and torch.backends.mps.is_built()
+        ):
+            device = torch.device("mps")
+            multiprocessing = "fork"
+        else:
+            device = torch.device("cpu")
+        args.local_rank = 0
+        args.global_rank = 0
+    elif torch.distributed.is_available():
+        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+        args.local_rank = int(os.environ["LOCAL_RANK"])
+        device = torch.device("cuda", args.local_rank)
+        deepspeed.init_distributed(timeout=timedelta(minutes=10))
+        args.global_rank = torch.distributed.get_rank()
+        world_size = torch.distributed.get_world_size()
+        tensor = torch.ByteTensor([False]).cuda()
+        torch.distributed.all_reduce(tensor)
+        torch.distributed.barrier()
 
     dataset = setup_dataset(
         args.data_path,
@@ -523,7 +602,7 @@ def main(args):
 
     try:
         packing_max_batch_len, grad_accum = find_packing_max_batch_len_and_grad_accum(
-            num_gpus=torch.distributed.get_world_size(),
+            num_gpus=world_size,
             avg_sample_len=dataset.get_lengths().mean(),
             effective_batch_size=args.effective_batch_size,
             max_batch_len_per_gpu=args.max_batch_len,
@@ -542,12 +621,11 @@ def main(args):
         grad_accum = 1
         args.sampler = "distributed"
 
-    args.samples_per_gpu = (
-        args.effective_batch_size // grad_accum // torch.distributed.get_world_size()
-    )
+    args.samples_per_gpu = args.effective_batch_size // grad_accum // world_size
 
     train_loader = setup_dataloader(
         dataset,
+        multiprocessing,
         tokenizer.pad_token_id,
         num_workers=8,
         is_granite=args.is_granite,
@@ -567,6 +645,7 @@ def main(args):
         args.sampler = "distributed"
         train_loader = setup_dataloader(
             dataset,
+            multiprocessing,
             tokenizer.pad_token_id,
             num_workers=8,
             is_granite=args.is_granite,
@@ -580,7 +659,7 @@ def main(args):
     if args.local_rank == 0:
         metric_logger.log_sync(
             {
-                "num_gpus": torch.distributed.get_world_size(),
+                "num_gpus": world_size,
                 "avg_sample_len": dataset.get_lengths().mean(),
                 "effective_batch_size": args.effective_batch_size,
                 "max_batch_len_per_gpu": args.max_batch_len,
@@ -592,17 +671,30 @@ def main(args):
             }
         )
 
-    model = setup_model(args, tokenizer, train_loader, grad_accum)
-    model = maybe_resume_training(args, model)
-
-    train(args, model, tokenizer, train_loader, grad_accum, metric_logger)
+    model, optimizer = setup_model(args, tokenizer, train_loader, grad_accum, device)
+    if device.type == "cuda":
+        model = maybe_resume_training(args, model)
+
+    train(
+        args,
+        model,
+        tokenizer,
+        train_loader,
+        grad_accum,
+        metric_logger,
+        device,
+        optimizer,
+    )
 
-    torch.distributed.barrier()
-    torch.distributed.destroy_process_group()
+    if torch.cuda.is_available() and torch.distributed.is_initialized():
+        torch.distributed.barrier()
+        torch.distributed.destroy_process_group()
 
 
 # public API
-def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
+def run_training(
+    torch_args: TorchrunArgs, train_args: TrainingArgs, device: str = "cuda"
+) -> None:
     """
     Wrapper around the main training job that calls torchrun.
     """
@@ -705,6 +797,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
         if train_args.deepspeed_options.cpu_offload_optimizer_pin_memory:
             command.append("--cpu_offload_optimizer_pin_memory")
 
+    if torch_args.nproc_per_node == 1:
+        command.append("--standalone")
+
+    command.extend(["--device", device])
+
     print(f"\033[92mRunning command: {' '.join(command)}\033[0m")
     process = None
     try:
@@ -831,6 +928,8 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
         ),
     )
     parser.add_argument("--disable_flash_attn", action="store_true")
+    parser.add_argument("--standalone", action="store_true")
+    parser.add_argument("--device", type=str, default="cuda")
     args = parser.parse_args()
     set_random_seed(args.seed)
     main(args)
diff --git a/src/instructlab/training/multipack_sampler.py b/src/instructlab/training/multipack_sampler.py
index 71d1def2..1db8ab33 100644
--- a/src/instructlab/training/multipack_sampler.py
+++ b/src/instructlab/training/multipack_sampler.py
@@ -30,7 +30,6 @@
 from torch.utils.data import Sampler
 import numba
 import numpy as np
-import torch
 import torch.distributed as dist
 
 
@@ -67,11 +66,16 @@ def get_effective_samples_per_minibatch(num_tokens_per_gpu):
 
         The function creates a sampler using the MultipackDistributedBatchSampler class, generates batches using the sampler, and then returns the ratio of the dataset size to the number of batches.
         """
+        num_replicas = 1
+        rank = 0
+        if dist.is_initialized():
+            num_replicas = dist.get_world_size()
+            rank = dist.get_rank()
         sampler = MultipackDistributedBatchSampler(
             batch_max_length=num_tokens_per_gpu,
             lengths=dataset.get_lengths(),
-            num_replicas=torch.distributed.get_world_size(),
-            rank=torch.distributed.get_rank(),
+            num_replicas=num_replicas,
+            rank=rank,
             seed=seed,
             padding=True,
         )
diff --git a/src/instructlab/training/token_dataset.py b/src/instructlab/training/token_dataset.py
index 9d46607e..44802694 100644
--- a/src/instructlab/training/token_dataset.py
+++ b/src/instructlab/training/token_dataset.py
@@ -85,6 +85,7 @@ def setup_dataset(
 
 def setup_dataloader(
     dataset: Dataset,
+    multiprocessing: str,
     pad_token_id: int,
     num_workers: int = 8,
     is_granite=False,
@@ -128,6 +129,7 @@ def setup_dataloader(
     dataloader = DataLoader(
         dataset,
         **sampler,
+        multiprocessing_context=multiprocessing,
         num_workers=num_workers,
         collate_fn=collate_fn,
     )
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index 89fd3c24..be7287f5 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -28,8 +28,6 @@
 )
 from rich.logging import RichHandler
 from safetensors.torch import save_file
-from torch import distributed as dist
-from torch.distributed import get_rank, is_initialized
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     CheckpointImpl,
     apply_activation_checkpointing,
@@ -37,6 +35,7 @@
 )
 import numpy as np
 import torch
+import torch.distributed
 import torch.nn.functional as F
 
 
@@ -480,7 +479,7 @@ class UniversalCheckpointArgs:
         with open(latest_file, "w") as f:
             f.write(step_folder)
 
-    dist.barrier()
+    torch.distributed.barrier()
     log_rank_0(f"Preparing universal checkpoint took {time.time() - start} seconds")
 
 
@@ -508,26 +507,28 @@ def ensure_loadable_granite_checkpoint(
         # Assumption: tmpdir should be accessible by all ranks, even those
         # in different nodes
         tmpdir = Path(tmpdir) / f"tmp.{group_rank}"
-        if os.path.exists(tmpdir) and (not dist.is_initialized() or local_rank == 0):
-            # need to delete if it exists because import doesn't like it to
+        if os.path.exists(tmpdir) and (
+            not torch.distributed.is_initialized() or local_rank == 0
+        ):
+            # need to delete if it exists because import doesn't like it to
             shutil.rmtree(tmpdir, ignore_errors=True)
 
-        if not dist.is_initialized() or local_rank == 0:
+        if not torch.distributed.is_initialized() or local_rank == 0:
             import_from_huggingface(model_name_or_path, tmpdir)
 
-        if dist.is_initialized():
+        if torch.distributed.is_initialized():
             # the first barrier is to wait for local rank 0 to finish converting the model
             # and place into tmpdir
-            dist.barrier()
+            torch.distributed.barrier()
 
         # return tmpdir out for loading
         yield tmpdir
 
-        if dist.is_initialized():
+        if torch.distributed.is_initialized():
             # the second barrier is to wait for all the models to finish loading
-            dist.barrier()
+            torch.distributed.barrier()
 
-        if not dist.is_initialized() or local_rank == 0:
+        if not torch.distributed.is_initialized() or local_rank == 0:
             # at this point, we can be confident that the tmpdir is no longer needed
             shutil.rmtree(tmpdir, ignore_errors=True)
 
@@ -603,7 +604,7 @@ def get_caller(num_frames=1):
 
 def log_rank_0(msg, include_caller=False, rank=None, to_print=False):
     if rank is None:
-        rank = get_rank() if is_initialized() else 0
+        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
     if rank <= 0:
         if include_caller:
             msg = f"{get_caller(num_frames=2)}: {msg}"
@@ -632,7 +633,11 @@ def save_hf_format_ds(
     convert_granite=True,
     is_lora=False,
 ):
-    model_to_save = model.module
+    if torch.cuda.is_available():
+        model_to_save = model.module
+    else:
+        # if not using DeepSpeed, the model is a plain model, not a model_engine
+        model_to_save = model
     log_rank_0(
         f"\033[93mSaving model in huggingface format at samples_seen: {samples_seen}\033[0m",
         to_print=True,
@@ -647,7 +652,7 @@ def save_hf_format_ds(
     else:
         WEIGHTS_NAME = "pytorch_model.bin"
     output_dir = Path(args.output_dir) / "hf_format" / f"samples_{samples_seen}"
-    if torch.distributed.get_rank() == 0:
+    if not torch.cuda.is_available() or torch.distributed.get_rank() == 0:
         if is_lora:
             model_to_save.merge_adapter()
 
@@ -686,7 +691,8 @@ def save_hf_format_ds(
         if is_lora:
             model_to_save.unmerge_adapter()
 
-    dist.barrier()
+    if torch.cuda.is_available() and torch.distributed.is_initialized():
+        torch.distributed.barrier()
     log_rank_0(f"\033[93mModel saved in {output_dir}\033[0m", to_print=True)
     log_rank_0(f"saving took {time.time() - start} seconds")