diff --git a/llama2.mojo b/llama2.mojo
index 1d2fd22..229226d 100644
--- a/llama2.mojo
+++ b/llama2.mojo
@@ -1,8 +1,8 @@
 from algorithm import sum
-from algorithm import vectorize, parallelize, unroll
+from algorithm import vectorize, parallelize, unroll, tile
 from builtin import string
-from math import round
-from memory import memset_zero, memcpy
+from math import round, log2
+from memory import memset_zero, memcpy, stack_allocation
 from memory.buffer import Buffer
 from memory.unsafe import DTypePointer
 from random import rand
@@ -19,7 +19,7 @@ import time
 
 var workers = 0
 
-alias nelts = (4*simdwidthof[DType.float32]())
+alias nelts = (4 * simdwidthof[DType.float32]())
 
 alias PointerString = Pointer[UInt8]
 alias BufferPtrType = DTypePointer[DType.uint8]
@@ -371,7 +371,6 @@ struct RunState:
     var key_cache: TensorF32 # (layer, seq_len, dim)
     var value_cache: TensorF32 # (layer, seq_len, dim)
 
-
     fn __init__(inout self, config: Config) raises:
         self.x = TensorF32(config.dim)
         self.xb = TensorF32(config.dim)
@@ -449,10 +448,10 @@ fn read_file(file_name: String, inout buf: FileBuf) raises:
 
     let cp_buf: BufferPtrType = BufferPtrType.alloc(cp_size)
     let data_ptr = data._as_ptr().bitcast[DType.uint8]()
-    
+
     for i in range(cp_size):
-        cp_buf.store(i,data_ptr.load(i))
-    
+        cp_buf.store(i, data_ptr.load(i))
+
     # don't free data
     _ = data
 
@@ -571,41 +570,47 @@ fn batch_matmul[
     rows: Int,
     cols: Int,
 ):
+    alias nelts_list = VariadicList(128, 64, 32, 16, 8, 4, 2, 1)
+
     @parameter
     fn compute_row(i: Int):
-        var tmp = StaticTuple[n, SIMD[DType.float32, nelts]]()
+        var tmp = StaticTuple[n, BufferPtrFloat32]()
+
         @parameter
         fn init[k: Int]():
-            tmp[k] = SIMD[DType.float32, nelts](0)
+            tmp[k] = stack_allocation[nelts, DType.float32]()
+            memset_zero(tmp[k], nelts)
+
         unroll[n, init]()
         let row_offset = i * cols
 
-        @parameter
-        fn dot[_nelts: Int](j: Int):
-            if _nelts < nelts: # take care of tail array elements with length < nelts
-                let a = A.simd_load[_nelts](j)
+        var j = 0
 
-                @parameter
-                fn _multiply_tail[k: Int]():
-                    tmp[k][0] += (
-                        a * B[k].simd_load[_nelts](row_offset + j)
-                    ).reduce_add()
-
-                unroll[n, _multiply_tail]()
-            else:
-                let a = A.simd_load[nelts](j)
-
-                @parameter
-                fn _multiply[k: Int]():
-                    tmp[k] += a * B[k].simd_load[nelts](row_offset + j)
-
-                unroll[n, _multiply]()
-
-        vectorize[nelts, dot](cols)
+        @parameter
+        fn dot[z: Int]():
+            # we want the list to contain only nelts values that are 4 * simdwidth or less; bigger values give undefined behavior
+            @parameter
+            if nelts_list[z] <= nelts:
+                let range = cols - cols % nelts_list[z]
+                while j < range:
+                    let a = A.simd_load[nelts_list[z]](j)
+
+                    @parameter
+                    fn _multiply_tail[k: Int]():
+                        tmp[k].simd_store[nelts_list[z]](
+                            0,
+                            tmp[k].simd_load[nelts_list[z]](0)
+                            + a * B[k].simd_load[nelts_list[z]](row_offset + j),
+                        )
+
+                    unroll[n, _multiply_tail]()
+                    j += nelts_list[z]
+
+        unroll[len(nelts_list), dot]()
 
         @parameter
         fn _reduce[k: Int]():
-            C[k].store(i, tmp[k].reduce_add())
+            C[k].store(i, tmp[k].simd_load[nelts](0).reduce_add())
 
         unroll[n, _reduce]()
 
@@ -643,7 +648,9 @@ fn matmul(C: TensorSlice, A: TensorF32, B: TensorSlice) raises:
     # B (d,n) @ A (n,) -> C (d,)
     matmul_dimension_checks(A.shape(), B.shape())
     batch_matmul[1](
-        StaticTuple[1, BufferPtrFloat32](C.data(),),
+        StaticTuple[1, BufferPtrFloat32](
+            C.data(),
+        ),
         A.data(),
         StaticTuple[1, BufferPtrFloat32](B.data()),
         B.dim(0),
@@ -671,8 +678,9 @@ fn rope_rotation_llama(
 ) -> None:
     # stories model, llama2
     let head_size = config.head_size
+
     @parameter
-    fn head_loop(i:Int):
+    fn head_loop(i: Int):
         # Simple vectorization with (head_size // 2) steps gave junk transformer output.
         # Maybe because the nelt ranges end up overlapping between the steps.
         for j in range(0, config.head_size, 2):
@@ -687,8 +695,8 @@ fn rope_rotation_llama(
             let k1 = state.k[i * head_size + j + 1]
             state.k[i * head_size + j] = k0 * fcr - k1 * fci
             state.k[i * head_size + j + 1] = k0 * fci + k1 * fcr
-    parallelize[head_loop](config.n_heads, workers)
 
+    parallelize[head_loop](config.n_heads, workers)
 
 
 @always_inline
@@ -755,7 +763,7 @@ fn transformer(
 
     # Multihead attention. Iterate over all heads in parallel.
     @parameter
-    fn loop_over_heads(h:Int):
+    fn loop_over_heads(h: Int):
         # Get the query vector for this head
         let q_offset = h * head_size
 
@@ -1020,8 +1028,17 @@ fn main() raises:
     var tok = Tokenizer(config.vocab_size, tbuf)
 
     # print the layers number and vocab size
-    print("checkpoint size: ", fbuf.size, "[", fbuf.size // 1024 // 1024, "MB ]",
-          "| n layers:", config.n_layers, "| vocab size:", tok.vocab_size)
+    print(
+        "checkpoint size: ",
+        fbuf.size,
+        "[",
+        fbuf.size // 1024 // 1024,
+        "MB ]",
+        "| n layers:",
+        config.n_layers,
+        "| vocab size:",
+        tok.vocab_size,
+    )
 
     # Create and initialize the application RunState
     var state = RunState(config)
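Note on the new batch_matmul inner loop: the VariadicList of descending tile widths replaces the old vectorize call. Each unrolled dot step consumes as many full SIMD blocks of its width as remain, so by the time the 1-wide step runs only the scalar tail is left and no masked tail loads are needed. Below is a minimal single-row sketch of that pattern, not part of the patch: the function tiled_dot, the name widths, and the variable stop (renamed from the patch's "range", which shadows the builtin) are illustrative, and it assumes the same pre-1.0 Mojo toolchain (let, @parameter closures) as the file above.

from algorithm import unroll
from memory import memset_zero, stack_allocation
from memory.unsafe import DTypePointer
from sys.info import simdwidthof

alias nelts = 4 * simdwidthof[DType.float32]()

fn tiled_dot(
    A: DTypePointer[DType.float32],
    B: DTypePointer[DType.float32],
    cols: Int,
) -> Float32:
    # Descending tile widths; anything above 4 * simdwidth is skipped below,
    # mirroring the nelts_list guard in the patched batch_matmul.
    alias widths = VariadicList(128, 64, 32, 16, 8, 4, 2, 1)
    # One shared stack accumulator: every width adds into its first lanes.
    let acc = stack_allocation[nelts, DType.float32]()
    memset_zero(acc, nelts)
    var j = 0

    @parameter
    fn step[z: Int]():
        @parameter
        if widths[z] <= nelts:
            # Last index that still admits a full block of this width.
            let stop = cols - cols % widths[z]
            while j < stop:
                acc.simd_store[widths[z]](
                    0,
                    acc.simd_load[widths[z]](0)
                    + A.simd_load[widths[z]](j) * B.simd_load[widths[z]](j),
                )
                j += widths[z]

    unroll[len(widths), step]()
    return acc.simd_load[nelts](0).reduce_add()

Accumulating every width into the leading lanes of one nelts-wide stack buffer lets a single reduce_add at the end sum all partial products, which is exactly what the patched batch_matmul does per output row in _reduce.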