enable fine tuning on HPU #552

Draft: wants to merge 2 commits into main

49 changes: 49 additions & 0 deletions src/instructlab/training/hpu_utils.py
@@ -0,0 +1,49 @@
import torch
from functools import lru_cache


@lru_cache(maxsize=None)
def is_torch_hpu_available() -> bool:
    try:
        import habana_frameworks.torch.core  # noqa: F401
    except ImportError:
        return False
    return True


def simple_bucket(length):
    """
    This bucketing algorithm relies only on the given number rather than on
    slicing the known (min, max) range, for several reasons:
    1) Because of the first-fit-decreasing (FFD) packing algorithm, the
       (min, max) sequence length of each rank is much smaller than the
       (min, max) sequence length of the dataset, so bucketing on the
       dataset-wide (min, max) sequence length is not practical.
    2) The (min, max) sequence length of a given rank is unknown until one
       epoch has finished, since the packing is done on the fly.
    3) Because of shuffling, the (min, max) sequence length of a given rank
       may vary; once it changes, the bucketing also needs adjustment.

    The algorithm is based on the most significant set bit of the input
    number. It first finds the most significant set bit, say bit "S", and
    then slices the range [2 ** S, 2 ** (S + 1)] into buckets of equal size.
    By default the range is divided into 16 buckets, so the bucket size is
    2 ** (S - 4). For example, 0b10001 is padded to 0b10010.
    This limits the overhead of bucketing (at most 1/16 of the input number)
    and also prevents recompilation caused by a bucket size that is too small.
    """
    # Count the number of bits in `length` (position of the MSB plus one).
    l = length
    msb = 0
    while l > 0:
        msb += 1
        l = l // 2

    # Round `length` up to the nearest multiple of 2 ** (msb - 4); very small
    # numbers are left unchanged.
    align = (1 << (msb - 4)) if msb >= 4 else 1

    return (length + align - 1) // align * align


def bucket(length):
    return simple_bucket(length)
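For context, a quick sketch (not part of the patch) of how the helper rounds a few example lengths; the vectorized form mirrors how the samplers in multipack_sampler.py and token_dataset.py apply it to the per-sample length arrays:

```python
# Illustration only (not part of the patch): how bucket() rounds example lengths.
import numpy as np

from instructlab.training.hpu_utils import bucket

for n in (17, 100, 1000, 4096):
    # 17 -> 18, 100 -> 104, 1000 -> 1024, 4096 -> 4096
    print(n, "->", bucket(n))

# The samplers below apply the same rounding to the whole lengths array:
lengths = np.array([17, 100, 1000, 4096])
bucketed = np.vectorize(bucket)(lengths)  # array([18, 104, 1024, 4096])
```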
82 changes: 72 additions & 10 deletions src/instructlab/training/main_ds.py
@@ -43,6 +43,14 @@
UserWarning,
)

from instructlab.training.hpu_utils import is_torch_hpu_available

if is_torch_hpu_available():
import habana_frameworks.torch.core as htcore
import habana_frameworks.torch.distributed.hccl
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
adapt_transformers_to_gaudi()

# Third Party
from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
from torch.utils.data import DataLoader
@@ -174,6 +182,13 @@ def setup_model(
else:
model = AutoModelForCausalLM.from_pretrained(**base_model_args)

if is_torch_hpu_available() and os.getenv("HPU_ENABLE_TORCH_COMPILE", False):
torch._dynamo.config.cache_size_limit = int(1e4)
torch._dynamo.config.accumulated_cache_size_limit = int(2e4)
model = torch.compile(model, backend="hpu_backend", dynamic=False)
for layer in model.model.layers:
layer.compile(backend="hpu_backend", dynamic=False)

# store the base model args so we can recall them later if saving a LoRA model
args.base_model_args = base_model_args

@@ -222,7 +237,22 @@ def setup_model(
)
model.config.eos_token_id = tokenizer.eos_token_id

if "ForCausalLM" not in model.__class__.__name__:
if not is_torch_hpu_available():
class_name = model.__class__.__name__
else:
class_name = model._orig_mod.__class__.__name__ if model.__class__.__name__ == 'OptimizedModule' else model.__class__.__name__

replace_no_split_modules = {
'GaudiLlamaForCausalLM': ['GaudiLlamaDecoderLayer',]
}

if class_name in replace_no_split_modules:
if model.__class__.__name__ == 'OptimizedModule':
model._orig_mod._no_split_modules = replace_no_split_modules[class_name]
else:
model._no_split_modules = replace_no_split_modules[class_name]

if "ForCausalLM" not in class_name:
raise ValueError(
f"Model class name: {model.__class__.__name__} is not supported."
)
@@ -272,6 +302,11 @@ def make_inputs_require_grad(module, input, output): # pylint: disable=unused-a
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

accelerator = setup_accelerator(args, model, grad_accum)

if is_torch_hpu_available():
accelerator.state.fsdp_plugin.use_orig_params = True
accelerator.state.fsdp_plugin.sync_module_states = True

if args.distributed_training_framework == DistributedBackend.FSDP.value:
model = accelerator.prepare(model)
optimizer = setup_optimizer(args, model)
@@ -414,10 +449,19 @@ def train(
total_length = float(torch.tensor([batch.pop("total_length")]))
if not args.use_dolomite:
for k in batch:
batch[k] = batch[k].to(local_rank)
batch[k] = batch[k].to('hpu' if is_torch_hpu_available() else local_rank)

hpu_args = {}
if is_torch_hpu_available():
hpu_args = {
"use_flash_attention": True,
"lazy_mode": False,
}

output = model(
**batch,
use_cache=False,
**hpu_args,
)
loss = output.loss
log_loss = loss.detach().item()
@@ -454,8 +498,14 @@ def train(
elapsed_time = time.time() - start
overall_throughput = args.samples_per_gpu * world_size / elapsed_time
current_lr = lr_scheduler.get_last_lr()[0]
cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]

if is_torch_hpu_available():
mem_allocated = torch.hpu.memory_allocated() / (1024**3)
malloc_retries = 0
else:
mem_allocated = torch.cuda.memory_allocated() / (1024**3)
malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]

global_grad_norm = (
model.get_global_grad_norm()
if hasattr(model, "get_global_grad_norm")
@@ -477,8 +527,8 @@ def train(
"rank": torch.distributed.get_rank(),
"overall_throughput": overall_throughput,
"lr": current_lr,
"cuda_mem_allocated": cuda_mem_allocated,
"cuda_malloc_retries": cuda_malloc_retries,
("hpu" if is_torch_hpu_available() else "cuda") + "_mem_allocated": mem_allocated,
("hpu" if is_torch_hpu_available() else "cuda") + "_malloc_retries": malloc_retries,
"num_loss_counted_tokens": int(num_loss_counted_tokens),
"num_tokens_rank0": int(total_length),
"batch_size": int(micro_batch_size),
@@ -519,7 +569,10 @@ def train(
global_step += 1
if local_rank == 0:
inner_pb.update(1)
torch.cuda.empty_cache()

if not is_torch_hpu_available():
torch.cuda.empty_cache()

if args.checkpoint_at_epoch:
base_logger.debug(f"Saving checkpoint at epoch {epoch}")
save_checkpoint(
@@ -595,18 +648,27 @@ def main(args):
args.model_type = model_conf.model_type

#### distributed init #####
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
if is_torch_hpu_available():
torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
else:
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

args.local_rank = int(os.environ["LOCAL_RANK"])

timeout = _get_collective_timeout()
init = functools.partial(torch.distributed.init_process_group, "nccl")
init = functools.partial(torch.distributed.init_process_group, "hccl" if is_torch_hpu_available() else "nccl")
if timeout is not None:
init(timeout=timeout)
else:
init()

args.global_rank = torch.distributed.get_rank()
tensor = torch.ByteTensor([False]).cuda()

if is_torch_hpu_available():
tensor = torch.ByteTensor([False]).to('hpu')
else:
tensor = torch.ByteTensor([False]).cuda()

torch.distributed.all_reduce(tensor)
torch.distributed.barrier()

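A side note on the pattern above: the HPU branches in main_ds.py consistently swap "cuda"/"nccl" for "hpu"/"hccl". A hypothetical helper (illustration only, not part of the patch) capturing the same selection might look like this:

```python
# Hypothetical helper (not in the patch) mirroring the inline device/backend
# selection in main_ds.py: Habana ranks use the "hpu" device and the "hccl"
# process-group backend, CUDA ranks keep "cuda"/"nccl".
import os

import torch

from instructlab.training.hpu_utils import is_torch_hpu_available


def select_device_and_backend(local_rank: int):
    if is_torch_hpu_available():
        torch.hpu.set_device(local_rank)
        return torch.device("hpu"), "hccl"
    torch.cuda.set_device(local_rank)
    return torch.device("cuda", local_rank), "nccl"


# Usage sketch:
# device, backend = select_device_and_backend(int(os.environ["LOCAL_RANK"]))
# torch.distributed.init_process_group(backend)
```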
9 changes: 8 additions & 1 deletion src/instructlab/training/multipack_sampler.py
@@ -34,6 +34,8 @@
import torch
import torch.distributed as dist

from instructlab.training.hpu_utils import is_torch_hpu_available, bucket


def find_max_pack_len_with_padding(
dataset,
@@ -68,9 +70,14 @@ def get_effective_samples_per_minibatch(num_tokens_per_gpu):

The function creates a sampler using the MultipackDistributedBatchSampler class, generates batches using the sampler, and then returns the ratio of the dataset size to the number of batches.
"""
lengths = dataset.get_lengths()
if is_torch_hpu_available():
bucket_v = np.vectorize(bucket)
lengths = bucket_v(lengths)

sampler = MultipackDistributedBatchSampler(
batch_max_length=num_tokens_per_gpu,
lengths=dataset.get_lengths(),
lengths=lengths,
num_replicas=torch.distributed.get_world_size(),
rank=torch.distributed.get_rank(),
seed=seed,
16 changes: 12 additions & 4 deletions src/instructlab/training/setup_accelerator.py
@@ -2,7 +2,6 @@
from functools import partial

# Third Party
from accelerate import Accelerator
from peft.utils.other import fsdp_auto_wrap_policy
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
@@ -12,6 +11,12 @@
# First Party
from instructlab.training.config import DeepSpeedOptions
from instructlab.training.utils import get_module_class_from_name, patch_target_module
from instructlab.training.hpu_utils import is_torch_hpu_available

if is_torch_hpu_available():
from optimum.habana.accelerate import GaudiAccelerator
else:
from accelerate import Accelerator


def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions):
@@ -51,7 +56,10 @@ def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOption

def get_fsdp_config(args, model: PreTrainedModel):
# Third Party
from accelerate.utils import FullyShardedDataParallelPlugin
if is_torch_hpu_available():
from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin
else:
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

is_lora = args.lora_r > 0
@@ -73,7 +81,7 @@ def get_fsdp_config(args, model: PreTrainedModel):
prefetch_policy = (
BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE
)
fsdp_plugin = FullyShardedDataParallelPlugin(
fsdp_plugin = (GaudiFullyShardedDataParallelPlugin if is_torch_hpu_available() else FullyShardedDataParallelPlugin)(
auto_wrap_policy=wrap_policy,
limit_all_gathers=True,
backward_prefetch=prefetch_policy,
@@ -128,7 +136,7 @@ def setup_accelerator(args, model: PreTrainedModel, grad_accum):
raise ValueError(
f"Unknown sharding framework: {args.distributed_training_framework}"
)
accelerator = Accelerator(
accelerator = (GaudiAccelerator if is_torch_hpu_available() else Accelerator)(
**accel_args,
)
accelerator.even_batches = False
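The same conditional-class pattern appears twice in this file. A condensed sketch (hypothetical helper, not in the patch) of how the Gaudi classes stand in for the stock accelerate ones:

```python
# Hypothetical helper (not in the patch): pick the Accelerator and FSDP plugin
# classes once, based on whether the Habana/optimum stack is importable.
from instructlab.training.hpu_utils import is_torch_hpu_available


def accelerator_classes():
    if is_torch_hpu_available():
        from optimum.habana.accelerate import GaudiAccelerator
        from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin

        return GaudiAccelerator, GaudiFullyShardedDataParallelPlugin

    from accelerate import Accelerator
    from accelerate.utils import FullyShardedDataParallelPlugin

    return Accelerator, FullyShardedDataParallelPlugin


# Usage sketch:
# AcceleratorCls, FsdpPluginCls = accelerator_classes()
# fsdp_plugin = FsdpPluginCls(auto_wrap_policy=wrap_policy, limit_all_gathers=True)
# accelerator = AcceleratorCls(fsdp_plugin=fsdp_plugin)
```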
5 changes: 5 additions & 0 deletions src/instructlab/training/token_dataset.py
@@ -13,6 +13,7 @@
from instructlab.training.multipack_sampler import MultipackDistributedBatchSampler
from instructlab.training.utils import log_rank_0, make_collate_fn

from instructlab.training.hpu_utils import is_torch_hpu_available, bucket

class TokenDataset(Dataset):
def __init__(self, data_path):
@@ -109,6 +110,10 @@ def setup_dataloader(

lengths = dataset.get_lengths()
if sampler == "multipack":
if is_torch_hpu_available():
bucket_v = np.vectorize(bucket)
lengths = bucket_v(lengths)

sampler = MultipackDistributedBatchSampler(
batch_max_length=packing_max_batch_len,
lengths=lengths,
13 changes: 12 additions & 1 deletion src/instructlab/training/utils.py
@@ -50,6 +50,7 @@
QuantizeDataType,
TrainingArgs,
)
from instructlab.training.hpu_utils import is_torch_hpu_available, bucket

logger = logging.getLogger("instructlab.training")

@@ -209,6 +210,9 @@ def listen(self):


def supports_flash_attention(device_id=0):
"""Check if a GPU supports FlashAttention."""
if is_torch_hpu_available():
return False

major, minor = torch.cuda.get_device_capability(device_id)
# Check if the GPU architecture is Ampere (SM 8.x) or newer (SM 9.0)
@@ -300,6 +304,9 @@ def pad_collate_fn(batch):
lens = np.array([len(item["input_ids"]) for item in batch])
max_len = max(lens)

if is_torch_hpu_available():
max_len = bucket(max_len)

input_ids = torch.stack(
[
F.pad(
@@ -411,6 +418,7 @@ def reduce_sum_forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**(_deprecated_arguments if is_torch_hpu_available() else {}),
)

return_dict = isinstance(output, dict)
@@ -1093,7 +1101,10 @@ def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if is_torch_hpu_available():
torch.hpu.manual_seed_all(seed)
else:
torch.cuda.manual_seed_all(seed)


def save_checkpoint(
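Finally, the pad_collate_fn change above pairs with the bucketing helper: on HPU the batch is padded to a bucketed maximum rather than the exact maximum, so tensor shapes repeat across batches. A simplified sketch (hypothetical, not the actual pad_collate_fn, assuming right-padding of input_ids only):

```python
# Simplified, hypothetical collate sketch: pad every sample to a bucketed max
# length so HPU graph compilation sees fewer distinct shapes.
import numpy as np
import torch
import torch.nn.functional as F

from instructlab.training.hpu_utils import bucket, is_torch_hpu_available


def pad_input_ids(batch, pad_token_id):
    lens = np.array([len(item["input_ids"]) for item in batch])
    max_len = int(max(lens))
    if is_torch_hpu_available():
        # e.g. a max length of 1000 becomes 1024; see hpu_utils.bucket()
        max_len = bucket(max_len)
    return torch.stack(
        [
            F.pad(
                item["input_ids"],
                (0, max_len - len(item["input_ids"])),
                value=pad_token_id,
            )
            for item in batch
        ]
    )
```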