Simple float CUDA port #376
base: master
Conversation
builds & perf basically matches 'make run'
using float, not half, though. Trying for an apples-to-apples comparison.
- output is gibberish
- very, very slow
- doesn't trigger anything in compute-sanitizer
found a < that needed to be a <= and used some llama2.cu code. Still need to grok why.
still can't find issue with suspect code.
Note why multi_head_attention_kernel uses llama2.cu code instead
added some TODO items & notes
Thank you for the PR! I'm traveling right now, so I'll be a bit slower to reply, but I'm looking forward to taking a look.
Found a straightforward way for anyone to see the perf impact of CUDA via Google Colab:
1. Bring up a Google Colab notebook and select a GPU runtime.
2. wget the necessary bits from the repo + stories110M.
3. Build.
4. Run CUDA & run C.

I get matching output and about 320-350 tok/s for CUDA vs. 20-30 tok/s for C.
@rogerallen I don't have much RAM on my GPU (10 GB). For the llama_7b model I put in some debug output: the first cudaMalloc is ~7 GB and the second cudaMalloc is ~6 GB. I notice the comment:

// allocate & copy mmap data to the gpu first
// TODO: allocate & copy just a portion to the GPU if the weights are too big
// to fit in the GPU, then copy the data only as needed while running.

Can it be done? Will it work for my modest GPU?
I have not implemented this feature. I wrote that comment without really thinking through the situation. Considering that we have to push all the weights through the GPU every forward() pass to generate a token, this probably isn't that interesting to do. My suggestion would be to look at some of the other CUDA ports using FP16 & FP8 weights. Those should divide the weights by 2x or 4x, hopefully fit in your GPU, and provide interesting performance.
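For what it's worth, here is a minimal sketch of what that TODO could look like. The function names (`init_weight_staging`, `stage_layer_weights`) and the one-buffer-per-layer layout are my own illustration, not anything in run.cu:

```c
// A minimal sketch (not implemented in this PR) of the TODO above: keep one
// reusable device buffer and stage each layer's weights from the mmap'd
// host checkpoint right before that layer runs. All of the weights still
// cross PCIe on every forward() pass, which is why shrinking the weights
// (FP16/INT8) so they fit in GPU memory is usually the better answer.
#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>

static float *d_layer_weights = NULL; // one reusable device buffer
static size_t staging_bytes = 0;      // sized for the largest layer

void init_weight_staging(size_t max_layer_bytes) {
    staging_bytes = max_layer_bytes;
    if (cudaMalloc((void **)&d_layer_weights, staging_bytes) != cudaSuccess) {
        fprintf(stderr, "cudaMalloc of staging buffer failed\n");
        exit(EXIT_FAILURE);
    }
}

// Copy one layer's weights (a host pointer into the mmap'd file) to the GPU
// and hand back the device pointer for that layer's kernels to use.
float *stage_layer_weights(const float *h_layer_weights, size_t bytes) {
    cudaMemcpy(d_layer_weights, h_layer_weights, bytes, cudaMemcpyHostToDevice);
    return d_layer_weights;
}
```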
I tried run-q8.cu but it was just segfaulting, then tried using his …
This one works well! Pure CUDA inference for 4-bit AWQ quantized models.
This is my attempt at a simple port to CUDA. I'm hopeful this can serve as an example for anyone who wants to learn how to use CUDA for LLMs.

I was inspired by, and have used code from, https://github.com/ankan-ban/llama2.cu as a starting point. When I came upon that code, I noticed the upstream repository had progressed and their `llama2.cu` code was no longer working. So, I tried to restructure the CUDA code in a way that `run.cu` could be kept up-to-date via `diff`. So far, keeping up hasn't been too bad.

The `ankan-ban` repository also seems focused on making 16-bit float and 8-bit int work. That is very cool and I hope they continue, but I hoped there would be room for a straight float port. It is not my intent to step on any toes here, and I'm sorry if it comes across that way. I just think this might mesh with the existing code better.

Even with a simple port like this, I notice a significant performance increase. Using `stories110M.bin`, I am seeing a 5x perf increase vs. the `runomp` app on my 14-core i9 Intel laptop with an NVIDIA RTX 4050 running WSL2. My older Linux desktop, a 12-core i9 system with an NVIDIA RTX 3070, sees about 15x. Both Linux & Windows builds are working.
To make it easy to compare the C and CUDA code, I extracted each function inside the `forward` routine and wrapped it with a `USE_CUDA` define to allow easy comparison from C to CUDA.
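To illustrate the pattern (a simplified, hypothetical example; `scale` is not a real function in run.cu), each extracted step ends up with a C body and a CUDA body selected by the define:

```c
// Hypothetical example of the USE_CUDA split: the same helper has a plain C
// loop and a CUDA kernel, chosen at compile time.
#ifdef USE_CUDA
__global__ void scale_kernel(float *o, const float *x, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) o[i] = s * x[i];
}

void scale(float *o, const float *x, float s, int n) {
    scale_kernel<<<(n + 255) / 256, 256>>>(o, x, s, n);
}
#else
void scale(float *o, const float *x, float s, int n) {
    for (int i = 0; i < n; i++) o[i] = s * x[i];
}
#endif
```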
I used `cuBLAS` to leverage that library's expertise for the SGEMV function. It adds some startup-time overhead via `cublasCreate`, so I'm waffling on keeping that code at the moment. I might go back to the previous `matmul` kernel code from `ankan-ban`.
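As a concrete sketch of that cuBLAS path (parameter names are illustrative, not the exact code in the PR): llama2.c's matmul computes xout = W @ x with W stored row-major as d rows by n columns, and the whole thing reduces to one cublasSgemv call on a handle created once by cublasCreate:

```c
#include <cublas_v2.h>

// Sketch: xout = W @ x, where W is d x n and stored row-major, so cuBLAS
// (column-major) sees it as an n x d matrix and we request its transpose.
static cublasHandle_t g_handle; // created once at startup via cublasCreate()

void matmul_cublas(float *d_xout, const float *d_x, const float *d_w,
                   int n, int d) {
    const float alpha = 1.0f, beta = 0.0f;
    // y = alpha * A^T * x + beta * y; A is n x d column-major with lda = n,
    // x has length n, y has length d.
    cublasSgemv(g_handle, CUBLAS_OP_T, n, d, &alpha, d_w, n,
                d_x, 1, &beta, d_xout, 1);
}
```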
In the rest of the code, I tried to keep it mostly untouched. But since `nvcc` is a C++ compiler, there were a few times we had to cast values in order to avoid errors (a small example of the kind of cast involved is sketched below). To get the Windows build working, I worked around one bit of C syntax that made `cl.exe` unhappy.

I'm not very familiar with GitHub pull requests, so bear with me if I have anything wrong.
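As an illustration of the casting point (a generic example, not a specific line from this PR): C accepts the implicit void* conversion from malloc/calloc, while nvcc compiles the file as C++ and rejects it, so an explicit cast gets added:

```c
#include <stdlib.h>

int main(void) {
    int n = 1024;
    // In C this compiles without the cast: float *x = calloc(n, sizeof(float));
    // As C++ (nvcc), the implicit void* -> float* conversion is an error,
    // so the cast is required:
    float *x = (float *)calloc(n, sizeof(float));
    free(x);
    return 0;
}
```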