Skip to content

Commit

Permalink
Merge branch 'karpathy:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
dagelf authored Jun 13, 2024
2 parents 64c1898 + 95cef79 commit aabb5c9
Show file tree
Hide file tree
Showing 29 changed files with 745 additions and 335 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,33 @@ jobs:
call "C:\\Program Files\\Microsoft Visual Studio\\2022\\Enterprise\\VC\\Auxiliary\\Build\\vcvars64.bat"
make-4.4.1\dist\make -j WIN_CI_BUILD=1 train_gpt2fp32cu test_gpt2fp32cu test_gpt2cu train_gpt2cu profile_gpt2cu
build-ubuntu20-04:
runs-on: ubuntu-20.04
container:
image: nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: System Info
run: |
nvcc --version
g++ --version
- name: Install cudnn frontend
run: |
apt-get update && apt-get install -y git
git clone https://github.com/NVIDIA/cudnn-frontend.git
- name: Build FP32 checkpoint
run: make train_gpt2fp32cu test_gpt2fp32cu

- name: Build FP32 precision
run: PRECISION=FP32 make train_gpt2cu test_gpt2cu profile_gpt2cu

- name: Build with CUDNN
run: PRECISION=BF16 USE_CUDNN=1 make train_gpt2cu test_gpt2cu profile_gpt2cu

build-cuda-fp32:
runs-on: ubuntu-latest
container:
Expand Down
112 changes: 112 additions & 0 deletions .github/workflows/ci_gpu.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
name: GPU Builds and Tests

on:
create:
workflow_dispatch:
push:
branches:
- master
pull_request:
branches:
- master

jobs:
build-and-test-gpu:
runs-on: ubicloud-gpu-standard-1-latest

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Install OpenMP
run: sudo apt-get update && sudo apt-get install -y libomp-dev

- name: Install dependencies
run: pip install -r requirements.txt

- name: Run preprocessing
run: python dev/data/tinyshakespeare.py

- name: Train model
run: python train_gpt2.py

- name: Compile training and testing program
run: make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu

- name: Train model (With OpenMP)
run: OMP_NUM_THREADS=8 ./train_gpt2cu

- name: Train model (FP32) with gpt2_124M.bin
run: |
PRECISION=FP32 make train_gpt2cu
./train_gpt2cu -b 4 -t 64 -l 1e-4 -v 200 -s 200 -a 1 -x 10 -e gpt2_124M.bin
- name: Build FP32 precision
run: PRECISION=FP32 make test_gpt2cu profile_gpt2cu

- name: Run default
run: ./test_gpt2cu

- name: Run no recompute GeLU
run: ./test_gpt2cu -r 0

- name: Run recompute LN
run: ./test_gpt2cu -r 2

- name: Build BF16 precision
run: PRECISION=BF16 make train_gpt2cu test_gpt2cu profile_gpt2cu

- name: Run default
run: ./test_gpt2cu

- name: Run no recompute GeLU
run: ./test_gpt2cu -r 0

- name: Run no master weights
run: ./test_gpt2cu -w 0

- name: Run recompute LN
run: ./test_gpt2cu -r 2

- name: Train model fp32 (With OpenMP)
run: OMP_NUM_THREADS=8 ./train_gpt2fp32cu

- name: Execute testing program (With OpenMP)
run: OMP_NUM_THREADS=8 ./test_gpt2cu

- name: Execute testing program fp32 (With OpenMP)
run: OMP_NUM_THREADS=8 ./test_gpt2fp32cu

- name: Compile training and testing program without OpenMP
run: NO_OMP=1 make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu

- name: Train model (No OpenMP)
run: NO_OMP=1 ./train_gpt2cu

- name: Train model fp32 (No OpenMP)
run: NO_OMP=1 ./train_gpt2fp32cu

- name: Execute testing program (No OpenMP)
run: ./test_gpt2cu -b 32

- name: Execute testing program fp32 (No OpenMP)
run: ./test_gpt2fp32cu

- name: Install cuDNN-frontend
run:
git clone https://github.com/NVIDIA/cudnn-frontend.git

- name: Build with cuDNN
run: USE_CUDNN=1 make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu

- name: Train model with cuDNN
run: ./train_gpt2cu

- name: Train model fp32 with cuDNN
run: ./train_gpt2fp32cu

- name: Execute testing program with cuDNN
run: ./test_gpt2cu

- name: Execute testing program fp32 with cuDNN
run: ./test_gpt2fp32cu
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ CUDA_OUTPUT_FILE = -o $@

# NVCC flags
# -t=0 is short for --threads, 0 = number of CPUs on the machine
NVCC_FLAGS = -O3 -t=0 --use_fast_math
NVCC_FLAGS = -O3 -t=0 --use_fast_math -std=c++17
NVCC_LDFLAGS = -lcublas -lcublasLt
NVCC_INCLUDES =
NVCC_LDLIBS =
Expand Down
2 changes: 1 addition & 1 deletion dev/cuda/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ endif

# Compiler flags
CFLAGS = -O3 --use_fast_math
NVCCFLAGS = -lcublas -lcublasLt
NVCCFLAGS = -lcublas -lcublasLt -std=c++17
MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib/

# Default rule for our CUDA files
Expand Down
2 changes: 2 additions & 0 deletions dev/cuda/matmul_backward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,14 @@ int main(int argc, char **argv) {
free(dout);
free(inp);
free(weight);
free(ones);
cudaCheck(cudaFree(d_dinp));
cudaCheck(cudaFree(d_dweight));
cudaCheck(cudaFree(d_dbias));
cudaCheck(cudaFree(d_dout));
cudaCheck(cudaFree(d_inp));
cudaCheck(cudaFree(d_weight));
cudaCheck(cudaFree(d_ones));
cublasCheck(cublasDestroy(cublas_handle));

return 0;
Expand Down
14 changes: 8 additions & 6 deletions dev/cuda/matmul_backward_bias.cu
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ __global__ void matmul_backward_bias_kernel1(floatX* dbias, const floatX* dout,
}
// write the final result (at thread 0) to global memory
if (tid == 0) {
dbias[o] = (float)dbias[o] + shared[0];
dbias[o] = (floatX)((float)dbias[o] + shared[0]);
}
}

Expand All @@ -116,7 +116,7 @@ __global__ void matmul_backward_bias_kernel2(floatX* dbias, const floatX* dout,
sum = cg::reduce(warp, sum, cg::plus<float>{});
// write the result to output (global memory)
if(warp.thread_rank() == 0) {
dbias[idx] += sum;
dbias[idx] += (floatX)sum;
}
}

Expand All @@ -132,12 +132,13 @@ __global__ void matmul_backward_bias_kernel3(floatX* dbias, const floatX* dout,
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
int idx = blockIdx.x; // simply one block per row
// round 1: thread coarsening to reduce the problem size from B*T to 32
// round 1: thread coarsening to reduce the problem size from B*T to block_size
float thread_sum = 0.0f;
for(int i = threadIdx.x; i < BT; i += blockDim.x) {
thread_sum += (float)dout[i * OC + idx];
}
// now do a warp-level reduce to get the sum across the 32 threads in each warp
// reduce the problem size from block_size to block_size/32 i.e. `num_warps`
float warp_sum = cg::reduce(warp, thread_sum, cg::plus<float>{});
// store the warp sum in shared memory (we could have lane_id == 0 guard but not needed)
shared_sum[warp_id] = warp_sum;
Expand Down Expand Up @@ -167,7 +168,7 @@ __global__ void matmul_backward_bias_kernel4(floatX* dbias, const floatX* dout,
const int vstep = blockDim.x / warpSize; // number of warps in a block, e.g. 4

// pointer to the start of the column for one lane of threads
// so e.g. 4 threads (of the same lane_id) will reduce this one column
// so e.g. 4 (`vstep`) threads (of the same lane_id) will reduce this one column
const floatX* dout_col = dout + tl + lane_id;

// column reductions by looping through the rows
Expand Down Expand Up @@ -503,7 +504,7 @@ void matmul_backward_bias7(floatX* dbias, const floatX* dout,

assert(block_size_y >= x128::size); // part of the kernel assumes this is large enough to avoid loops

cudaCheck(cudaMemsetAsync(dbias_buffer, 0, OC * sizeof(float)));
cudaCheck(cudaMemset(dbias_buffer, 0, OC * sizeof(float)));
matmul_backward_bias_kernel7<<<dim3(grid_size_x, grid_size_y),
dim3(block_size_x, block_size_y), OC_per_warp * sizeof(float)>>>(dbias_buffer, dout, B, T, OC, block_size);
cudaCheck(cudaGetLastError());
Expand All @@ -524,7 +525,7 @@ void matmul_backward_bias8(floatX* dbias, const floatX* dout,
matmul_backward_bias_kernel8<<<dim3(grid_size_x, grid_size_y), block_dim>>>(dbias, dout, B, T, OC, std::bool_constant<false>{});
cudaCheck(cudaGetLastError());
} else {
cudaCheck(cudaMemsetAsync(dbias_buffer, 0, OC * sizeof(float)));
cudaCheck(cudaMemset(dbias_buffer, 0, OC * sizeof(float)));
matmul_backward_bias_kernel8<<<dim3(grid_size_x, grid_size_y), block_dim>>>(dbias_buffer, dout, B, T, OC, std::bool_constant<true>{});
cudaCheck(cudaGetLastError());
cast_and_add_kernel<<<ceil_div(OC, 256), 256, 0>>>(dbias, dbias_buffer, OC);
Expand Down Expand Up @@ -661,6 +662,7 @@ int main(int argc, char **argv) {
// cleanups
free(dbias);
free(dout);
cudaCheck(cudaFree(dbias_buffer));
cudaCheck(cudaFree(d_dbias));
cudaCheck(cudaFree(d_dout));
Expand Down
102 changes: 101 additions & 1 deletion dev/cuda/matmul_forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,88 @@ __global__ void add_bias(float* out, const float* bias, int B, int T, int OC) {
}
}

// kernel 4: semi-efficient handwritten kernel
// see trimat_forward.cu for some intermediate development steps
__device__ float4 ld_vec(const float* address) {
return *reinterpret_cast<const float4*>(address);
}

__device__ void st_vec(float* address, float4 val) {
*reinterpret_cast<float4*>(address) = val;
}

__global__ void __launch_bounds__(16*16) matmul_forward_kernel4(float* out,
const float* inp, const float* weight, const float* bias,
int C, int OC) {
// out is (B,T,OC). OC is short for "output channels", e.g. OC = 4 * C
// inp is (B,T,C), weight is (OC, C), bias is (OC)
// each thread handles 8x8 elements; each block 128 by 128 elements.
int oc = 8*(blockIdx.y * blockDim.y + threadIdx.y);

// buffers to cache chunks of the input matrices
__shared__ float lhs_s[128][32];
__shared__ float rhs_s[128][32];

// adjust our pointers for the current block
inp += 128 * blockIdx.x * C;
weight += 128 * blockIdx.y * C;
out += 128 * blockIdx.x * OC + 128 * blockIdx.y;

float vals[8][8] = {};
if(bias != NULL) {
for (int i = 0; i < 8; i++) {
for (int j = 0; j < 8; j += 4) {
float4 b = ld_vec(bias + oc + j);
vals[i][j+0] = b.x;
vals[i][j+1] = b.y;
vals[i][j+2] = b.z;
vals[i][j+3] = b.w;
}
}
}

int si_start = 4*(16 * threadIdx.y + threadIdx.x);
for (int so = 0; so < C; so += 32) {
__syncthreads();
int xmod8 = threadIdx.x % 8;
int xby8 = threadIdx.x / 8;
int xo = 4 * xmod8;
for(int y = 2 * threadIdx.y + xby8; y < 128; y += 32) {
st_vec(&lhs_s[y][xo], ld_vec(inp + y * C + so + xo));
st_vec(&rhs_s[y][xo], ld_vec(weight + y * C + so + xo));
}
__syncthreads();

for (int si = si_start; si < si_start + 32; si += 4) {
float4 rhs[8];
for (int u = 0; u < 8; ++u) {
rhs[u] = ld_vec(&rhs_s[u + 8 * threadIdx.y][si % 32]);
}

for (int ii = 0; ii < 8; ++ii) {
float4 lhs = ld_vec(&lhs_s[ii + 8 * threadIdx.x][si % 32]);
for (int ji = 0; ji < 8; ++ji) {
vals[ii][ji] += lhs.x * rhs[ji].x;
vals[ii][ji] += lhs.y * rhs[ji].y;
vals[ii][ji] += lhs.z * rhs[ji].z;
vals[ii][ji] += lhs.w * rhs[ji].w;
}
}
}
}

for (int i = 0; i < 8; ++i) {
for (int j = 0; j < 8; j += 4) {
float4 result;
result.x = vals[i][j + 0];
result.y = vals[i][j + 1];
result.z = vals[i][j + 2];
result.w = vals[i][j + 3];
st_vec(out + (8*threadIdx.x+i) * OC + 8*threadIdx.y + j, result);
}
}
}

// ----------------------------------------------------------------------------
// kernel launcher

Expand Down Expand Up @@ -218,6 +300,21 @@ void matmul_forward3(float* out,
cublasCheck(cublasLtMatrixLayoutDestroy(biasLayout));
}

// handwritten, relatively efficient non-tensorcore matmul kernel
void matmul_forward4(float* out,
const float* inp, const float* weight, const float* bias,
int B, int T, int C, int OC,
int sqrt_block_size) {
// out is (B,T,OC). OC is short for "output channels", e.g. OC = 4 * C
// inp is (B,T,C), weight is (OC, C), bias is (OC)
sqrt_block_size = 16;

dim3 gridDim(ceil_div(B * T, 8*sqrt_block_size), ceil_div(OC, 8*sqrt_block_size));
dim3 blockDim(sqrt_block_size, sqrt_block_size);
matmul_forward_kernel4<<<gridDim, blockDim>>>(out, inp, weight, bias, C, OC);
cudaCheck(cudaGetLastError());
}

// kernel version dispatch
void matmul_forward(int kernel_num,
float* out,
Expand All @@ -234,6 +331,9 @@ void matmul_forward(int kernel_num,
case 3:
matmul_forward3(out, inp, weight, bias, B, T, C, OC);
break;
case 4:
matmul_forward4(out, inp, weight, bias, B, T, C, OC, sqrt_block_size);
break;
default:
printf("Invalid kernel number\n");
exit(1);
Expand All @@ -245,7 +345,7 @@ void matmul_forward(int kernel_num,
int main(int argc, char **argv) {
srand(0);

int B = 8;
int B = 32;
int T = 1024;
int C = 768;
int OC = 768 * 4; // expansion of 4, e.g. in the MLP
Expand Down
2 changes: 1 addition & 1 deletion dev/cuda/residual_forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ int main(int argc, char **argv) {
float* out = (float*)malloc(B * T * C * sizeof(float));
float* inp1 = make_random_float(B * T * C);
float* inp2 = make_random_float(B * T * C);

// move to GPU
floatX* d_out;
floatX* d_inp1;
Expand Down
Loading

0 comments on commit aabb5c9

Please sign in to comment.