
8-bit training #1

Open
yctam opened this issue Apr 1, 2024 · 1 comment

Comments


yctam commented Apr 1, 2024

Does the codebase support 8-bit training similar to the peft library?
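
(For reference, the peft-style 8-bit LoRA setup being referred to looks roughly like the following hypothetical snippet using the Hugging Face transformers and peft APIs; it is shown only to clarify the question and is not JORA code.)

```python
# Hypothetical illustration of peft-style 8-bit LoRA fine-tuning
# (transformers + peft + bitsandbytes), for comparison only.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    load_in_8bit=True,   # base weights quantized to int8 via bitsandbytes
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)  # freeze base weights, cast norms for stability
model = get_peft_model(model, LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"))
```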

I was trying to fine-tune llama2-7b on 24GB 4090 cards. Below is the error I got:
File "/home/nlp/JORA/examples/train.py", line 14, in
main()
File "/home/nlp/JORA/examples/train.py", line 10, in main
train_lora(config, dataset, 'checkpoints')
File "/home/nlp/JORA/jora/common.py", line 246, in train_lora
lora_params, opt_state, total_loss, loss, key = train_step_lora(lora_params, loraConfig, params, opt_state, total_loss, data_batch, key)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 16320586680 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 3.58GiB
constant allocation: 1.95MiB
maybe_live_out allocation: 264.00MiB
preallocated temp allocation: 15.20GiB
total allocation: 19.04GiB


aniquetahir (Owner) commented Apr 2, 2024

For now it's using bfloat16, the main reason being that there is no bitsandbytes equivalent for JAX yet. However, there is also some potential for adding 8-bit support through TransformerEngine.
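
A minimal sketch of the bfloat16 approach in JAX is below; the helper name and the usage line are assumptions for illustration, not JORA's actual implementation. Casting the parameter pytree to bfloat16 roughly halves parameter memory compared to float32.

```python
# Sketch: cast every floating-point leaf of a parameter pytree to bfloat16.
# This is an illustrative helper, not code from the JORA repository.
import jax
import jax.numpy as jnp

def to_bfloat16(params):
    """Cast floating-point leaves to bfloat16, leaving other dtypes untouched."""
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.bfloat16) if jnp.issubdtype(x.dtype, jnp.floating) else x,
        params,
    )

# Hypothetical usage before training:
# params = to_bfloat16(params)
```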
