From 2794386a74f3d22693d9f89076d1a2353bcabcb3 Mon Sep 17 00:00:00 2001
From: NKspartan
Date: Sat, 18 Nov 2023 23:25:51 -0600
Subject: [PATCH 1/3] Changed vectorize function for tile_vectorize_list in matmul

---
 llama2.mojo | 78 ++++++++++++++++++++++++++++++-----------------------
 1 file changed, 45 insertions(+), 33 deletions(-)

diff --git a/llama2.mojo b/llama2.mojo
index 1d2fd22..7b89189 100644
--- a/llama2.mojo
+++ b/llama2.mojo
@@ -1,8 +1,8 @@
 from algorithm import sum
-from algorithm import vectorize, parallelize, unroll
+from algorithm import vectorize, parallelize, unroll, tile
 from builtin import string
 from math import round
-from memory import memset_zero, memcpy
+from memory import memset_zero, memcpy, stack_allocation
 from memory.buffer import Buffer
 from memory.unsafe import DTypePointer
 from random import rand
@@ -19,7 +19,7 @@ import time
 
 var workers = 0
 
-alias nelts = (4*simdwidthof[DType.float32]())
+alias nelts = (4 * simdwidthof[DType.float32]())
 
 alias PointerString = Pointer[UInt8]
 alias BufferPtrType = DTypePointer[DType.uint8]
@@ -371,7 +371,6 @@ struct RunState:
     var key_cache: TensorF32  # (layer, seq_len, dim)
     var value_cache: TensorF32  # (layer, seq_len, dim)
 
-
     fn __init__(inout self, config: Config) raises:
         self.x = TensorF32(config.dim)
         self.xb = TensorF32(config.dim)
@@ -449,10 +448,10 @@ fn read_file(file_name: String, inout buf: FileBuf) raises:
     let cp_buf: BufferPtrType = BufferPtrType.alloc(cp_size)
     let data_ptr = data._as_ptr().bitcast[DType.uint8]()
-    
+
     for i in range(cp_size):
-        cp_buf.store(i,data_ptr.load(i))
-    
+        cp_buf.store(i, data_ptr.load(i))
+
     # don't free data
     _ = data
@@ -571,41 +570,42 @@ fn batch_matmul[
     rows: Int,
     cols: Int,
 ):
+    alias nelts_list = VariadicList(32, 16, 8, 4, 2, 1)
+    alias stack_size = 32
+
     @parameter
     fn compute_row(i: Int):
-        var tmp = StaticTuple[n, SIMD[DType.float32, nelts]]()
+        var tmp = StaticTuple[n, BufferPtrFloat32]()
+
         @parameter
         fn init[k: Int]():
-            tmp[k] = SIMD[DType.float32, nelts](0)
+            tmp[k] = stack_allocation[stack_size, DType.float32]()
+            memset_zero(tmp[k], stack_size)
+
         unroll[n, init]()
         let row_offset = i * cols
 
         @parameter
         fn dot[_nelts: Int](j: Int):
-            if _nelts < nelts:  # take care of tail array elements with length < nelts
-                let a = A.simd_load[_nelts](j)
+            # take care of tail array elements with length < nelts
+            let a = A.simd_load[_nelts](j)
 
-                @parameter
-                fn _multiply_tail[k: Int]():
-                    tmp[k][0] += (
-                        a * B[k].simd_load[_nelts](row_offset + j)
-                    ).reduce_add()
+            @parameter
+            fn _multiply_tail[k: Int]():
+                tmp[k].simd_store[_nelts](
+                    0,
+                    tmp[k].simd_load[_nelts](0)
+                    + a * B[k].simd_load[_nelts](row_offset + j),
+                )
 
-                unroll[n, _multiply_tail]()
-            else:
-                let a = A.simd_load[nelts](j)
+            unroll[n, _multiply_tail]()
 
-                @parameter
-                fn _multiply[k: Int]():
-                    tmp[k] += a * B[k].simd_load[nelts](row_offset + j)
-
-                unroll[n, _multiply]()
-
-        vectorize[nelts, dot](cols)
+        # vectorize[nelts, dot](cols)
+        tile[dot, nelts_list](0, cols)
 
         @parameter
         fn _reduce[k: Int]():
-            C[k].store(i, tmp[k].reduce_add())
+            C[k].store(i, tmp[k].simd_load[nelts](0).reduce_add())
 
         unroll[n, _reduce]()
@@ -643,7 +643,9 @@ fn matmul(C: TensorSlice, A: TensorF32, B: TensorSlice) raises:
     # B (d,n) @ A (n,) -> C (d,)
     matmul_dimension_checks(A.shape(), B.shape())
     batch_matmul[1](
-        StaticTuple[1, BufferPtrFloat32](C.data(),),
+        StaticTuple[1, BufferPtrFloat32](
+            C.data(),
+        ),
         A.data(),
         StaticTuple[1, BufferPtrFloat32](B.data()),
         B.dim(0),
@@ -671,8 +673,9 @@ fn rope_rotation_llama(
 ) -> None:
     # stories model, llama2
     let head_size = config.head_size
+
     @parameter
-    fn head_loop(i:Int):
+    fn head_loop(i: Int):
         # Simple vectorization with (head_size // 2) steps gave junk transformer output.
         # Maybe because the nelt ranges end up overlapping between the steps.
         for j in range(0, config.head_size, 2):
@@ -687,8 +690,8 @@ fn rope_rotation_llama(
             let k1 = state.k[i * head_size + j + 1]
             state.k[i * head_size + j] = k0 * fcr - k1 * fci
             state.k[i * head_size + j + 1] = k0 * fci + k1 * fcr
-    parallelize[head_loop](config.n_heads, workers)
+
+    parallelize[head_loop](config.n_heads, workers)
 
 
 @always_inline
@@ -755,7 +758,7 @@ fn transformer(
 
     # Multihead attention. Iterate over all heads in parallel.
     @parameter
-    fn loop_over_heads(h:Int):
+    fn loop_over_heads(h: Int):
         # Get the query vector for this head
         let q_offset = h * head_size
@@ -1020,8 +1023,17 @@ fn main() raises:
     var tok = Tokenizer(config.vocab_size, tbuf)
 
     # print the layers number and vocab size
-    print("checkpoint size: ", fbuf.size, "[", fbuf.size // 1024 // 1024, "MB ]",
-          "| n layers:", config.n_layers, "| vocab size:", tok.vocab_size)
+    print(
+        "checkpoint size: ",
+        fbuf.size,
+        "[",
+        fbuf.size // 1024 // 1024,
+        "MB ]",
+        "| n layers:",
+        config.n_layers,
+        "| vocab size:",
+        tok.vocab_size,
+    )
 
     # Create and initialize the application RunState
     var state = RunState(config)
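What the last batch_matmul hunk above changes: vectorize[nelts, dot](cols) always called dot at the single width nelts and handled the remainder through the _nelts < nelts tail branch, while tile[dot, nelts_list](0, cols) walks [0, cols) calling dot with each compile-time width from nelts_list in turn, so the widest width covers the bulk of the row and the narrower ones finish the tail down to width 1. The accumulators change accordingly, from a StaticTuple of SIMD registers to stack-allocated float buffers that can be read and stored at any of those widths. Below is a minimal standalone sketch of the same pattern; tiled_sum and all of its names are hypothetical rather than part of the patch, and it assumes the tile and simd_load APIs behave exactly as the diff above uses them.

    from algorithm import tile
    from memory.unsafe import DTypePointer

    # Hypothetical toy example: sum `size` floats with the same
    # decreasing-width tiling that batch_matmul now uses for its dot product.
    fn tiled_sum(data: DTypePointer[DType.float32], size: Int) -> Float32:
        alias widths = VariadicList(16, 8, 4, 2, 1)
        var total: Float32 = 0

        @parameter
        fn step[width: Int](offset: Int):
            # tile invokes step with the widest width in `widths` that still
            # fits before the end of the range, so no masked tail is needed.
            total += data.simd_load[width](offset).reduce_add()

        tile[step, widths](0, size)
        return total

For size = 19 and these widths, step runs at offset 0 with width 16, offset 16 with width 2, and offset 18 with width 1; the next patch reproduces exactly this schedule with an explicit loop.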
From b461665a921c339b544c5895b43b51dbf106e3db Mon Sep 17 00:00:00 2001
From: NKspartan
Date: Sun, 19 Nov 2023 19:37:19 -0600
Subject: [PATCH 2/3] Changed tile function in batch_matmul

---
 llama2.mojo | 58 ++++++++++++++++++++++++++++++++++++------------------
 1 file changed, 37 insertions(+), 21 deletions(-)

diff --git a/llama2.mojo b/llama2.mojo
index 7b89189..3d29b3c 100644
--- a/llama2.mojo
+++ b/llama2.mojo
@@ -1,7 +1,7 @@
 from algorithm import sum
 from algorithm import vectorize, parallelize, unroll, tile
 from builtin import string
-from math import round
+from math import round, log2
 from memory import memset_zero, memcpy, stack_allocation
 from memory.buffer import Buffer
 from memory.unsafe import DTypePointer
@@ -570,8 +570,21 @@ fn batch_matmul[
     rows: Int,
     cols: Int,
 ):
-    alias nelts_list = VariadicList(32, 16, 8, 4, 2, 1)
-    alias stack_size = 32
+    alias nelts_list_size = log2[DType.float64, 1](nelts).to_int() + 1
+
+    @parameter
+    fn create_nelts_list() -> StaticTuple[nelts_list_size, Int]:
+        var nelts_list = StaticTuple[nelts_list_size, Int]()
+
+        @parameter
+        fn create_nelts_list_helper[i: Int]():
+            nelts_list[i] = nelts >> i
+
+        unroll[nelts_list_size, create_nelts_list_helper]()
+
+        return nelts_list
+
+    alias nelts_list = create_nelts_list() # we want the list to only contain nelts values that are 4 * simdwidth or less; if we use bigger values we get undefined behavior
 
     @parameter
     fn compute_row(i: Int):
@@ -579,29 +592,32 @@ fn batch_matmul[
 
         @parameter
         fn init[k: Int]():
-            tmp[k] = stack_allocation[stack_size, DType.float32]()
-            memset_zero(tmp[k], stack_size)
+            tmp[k] = stack_allocation[nelts_list[0], DType.float32]()
+            memset_zero(tmp[k], nelts_list[0])
 
         unroll[n, init]()
         let row_offset = i * cols
 
+        var j = 0
+
         @parameter
-        fn dot[_nelts: Int](j: Int):
-            # take care of tail array elements with length < nelts
-            let a = A.simd_load[_nelts](j)
-
-            @parameter
-            fn _multiply_tail[k: Int]():
-                tmp[k].simd_store[_nelts](
-                    0,
-                    tmp[k].simd_load[_nelts](0)
-                    + a * B[k].simd_load[_nelts](row_offset + j),
-                )
-
-            unroll[n, _multiply_tail]()
-
-        # vectorize[nelts, dot](cols)
-        tile[dot, nelts_list](0, cols)
+        fn dot[z: Int]():
+            let range = cols - cols % nelts_list[z]
+            while j < range:
+                let a = A.simd_load[nelts_list[z]](j)
+
+                @parameter
+                fn _multiply_tail[k: Int]():
+                    tmp[k].simd_store[nelts_list[z]](
+                        0,
+                        tmp[k].simd_load[nelts_list[z]](0)
+                        + a * B[k].simd_load[nelts_list[z]](row_offset + j),
+                    )
+
+                unroll[n, _multiply_tail]()
+                j += nelts_list[z]
+
+        unroll[nelts_list_size, dot]()
 
         @parameter
         fn _reduce[k: Int]():
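This second patch derives the width list from nelts at compile time instead of hard-coding it: nelts_list_size = log2(nelts) + 1 entries of nelts >> i give the powers of two from nelts down to 1, so the list can no longer contain a width above 4 * simdwidth, the undefined-behavior case the comment warns about, and the stack buffers are now sized by nelts_list[0] (that is, by nelts) rather than by a fixed 32. Because the computed list is a StaticTuple rather than the VariadicList that the tile call in patch 1 took, the tile call is replaced by a hand-unrolled loop over the widths sharing a single running offset j: each width consumes what divides evenly, and narrower widths only mop up what the wider ones left. Here is a worked trace of that loop as a standalone sketch; the concrete values in it (a machine with nelts = 16, a row of cols = 19) are assumptions for illustration only.

    fn main():
        # nelts >> i for i in 0..4, assuming nelts == 16 on this machine.
        alias nelts_list = VariadicList(16, 8, 4, 2, 1)
        let cols = 19
        var j = 0
        for z in range(len(nelts_list)):
            # Widest-aligned end point for this width, as in the patch's dot.
            let stop = cols - cols % nelts_list[z]
            while j < stop:
                print("load of width", nelts_list[z], "at offset", j)
                j += nelts_list[z]
        # Prints: width 16 at offset 0, width 2 at offset 16, width 1 at
        # offset 18, so all 19 elements are covered without a masked tail.

(The end point is named stop rather than range here only so the builtin range() stays usable in the same scope.)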
From 5a6bf36855144be79fa00dabef25af06f4318fed Mon Sep 17 00:00:00 2001
From: NKspartan
Date: Sun, 19 Nov 2023 20:10:08 -0600
Subject: [PATCH 3/3] Simplified the change to tile in batch_matmul

---
 llama2.mojo | 55 +++++++++++++++++++++--------------------------------
 1 file changed, 22 insertions(+), 33 deletions(-)

diff --git a/llama2.mojo b/llama2.mojo
index 3d29b3c..229226d 100644
--- a/llama2.mojo
+++ b/llama2.mojo
@@ -570,21 +570,7 @@ fn batch_matmul[
     rows: Int,
     cols: Int,
 ):
-    alias nelts_list_size = log2[DType.float64, 1](nelts).to_int() + 1
-
-    @parameter
-    fn create_nelts_list() -> StaticTuple[nelts_list_size, Int]:
-        var nelts_list = StaticTuple[nelts_list_size, Int]()
-
-        @parameter
-        fn create_nelts_list_helper[i: Int]():
-            nelts_list[i] = nelts >> i
-
-        unroll[nelts_list_size, create_nelts_list_helper]()
-
-        return nelts_list
-
-    alias nelts_list = create_nelts_list() # we want the list to only contain nelts values that are 4 * simdwidth or less; if we use bigger values we get undefined behavior
+    alias nelts_list = VariadicList(128, 64, 32, 16, 8, 4, 2, 1)
 
     @parameter
     fn compute_row(i: Int):
@@ -592,8 +578,8 @@ fn batch_matmul[
 
         @parameter
         fn init[k: Int]():
-            tmp[k] = stack_allocation[nelts_list[0], DType.float32]()
-            memset_zero(tmp[k], nelts_list[0])
+            tmp[k] = stack_allocation[nelts, DType.float32]()
+            memset_zero(tmp[k], nelts)
 
         unroll[n, init]()
         let row_offset = i * cols
@@ -602,22 +588,25 @@ fn batch_matmul[
 
         @parameter
         fn dot[z: Int]():
-            let range = cols - cols % nelts_list[z]
-            while j < range:
-                let a = A.simd_load[nelts_list[z]](j)
-
-                @parameter
-                fn _multiply_tail[k: Int]():
-                    tmp[k].simd_store[nelts_list[z]](
-                        0,
-                        tmp[k].simd_load[nelts_list[z]](0)
-                        + a * B[k].simd_load[nelts_list[z]](row_offset + j),
-                    )
-
-                unroll[n, _multiply_tail]()
-                j += nelts_list[z]
-
-        unroll[nelts_list_size, dot]()
+            # we want the list to only contain nelts values that are 4 * simdwidth or less; if we use bigger values we get undefined behavior
+            @parameter
+            if nelts_list[z] <= nelts:
+                let range = cols - cols % nelts_list[z]
+                while j < range:
+                    let a = A.simd_load[nelts_list[z]](j)
+
+                    @parameter
+                    fn _multiply_tail[k: Int]():
+                        tmp[k].simd_store[nelts_list[z]](
+                            0,
+                            tmp[k].simd_load[nelts_list[z]](0)
+                            + a * B[k].simd_load[nelts_list[z]](row_offset + j),
+                        )
+
+                    unroll[n, _multiply_tail]()
+                    j += nelts_list[z]
+
+        unroll[len(nelts_list), dot]()
 
         @parameter
        fn _reduce[k: Int]():
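The final form drops the compile-time list construction in favour of a single fixed list, VariadicList(128, 64, 32, 16, 8, 4, 2, 1), and moves the safety check inside the unrolled body: the @parameter if is resolved during compilation, so any entry wider than the machine's nelts generates no code at all and only the safe widths survive. A minimal sketch of that pruning pattern in isolation follows; the names are hypothetical, with nelts_demo standing in for the machine-dependent nelts.

    from algorithm import unroll

    alias widths = VariadicList(128, 64, 32, 16, 8, 4, 2, 1)
    alias nelts_demo = 16  # stand-in for the machine-dependent nelts

    fn print_compiled_widths():
        @parameter
        fn check[z: Int]():
            # Resolved at compile time: for widths[z] > nelts_demo this
            # branch disappears entirely, mirroring the guard in the patch.
            @parameter
            if widths[z] <= nelts_demo:
                print("width", widths[z], "compiled in")

        unroll[len(widths), check]()

Under that assumption a call to print_compiled_widths() prints the widths 16, 8, 4, 2 and 1; the 128, 64 and 32 entries cost nothing at runtime, which is why one oversized list can serve every SIMD width.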