Idea for matmul tiling #67

Open · wants to merge 7 commits into base: master
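In outline: `batch_matmul` currently computes each output row as one long vectorized dot product, parallelized row by row. This PR tiles the work instead: a new `tile_parallel` helper hands tiles of `tile_i` rows to the workers, the columns are walked in tiles of `tile_j` with `vectorize_unroll`, and partial sums are kept in small stack-allocated accumulators that are reduced into `C` once per row tile. A minimal sketch of the loop structure, under simplifying assumptions (a single weight matrix, `rows` and `cols` exact multiples of the tile sizes, illustrative tile constants, invented function name):

```mojo
from sys.info import simdwidthof
from memory.unsafe import DTypePointer

alias nelts = simdwidthof[DType.float32]()
alias tile_i = 4          # rows per tile (illustrative)
alias tile_j = 4 * nelts  # columns per tile (illustrative)

fn matmul_tiled_sketch(
    C: DTypePointer[DType.float32],
    A: DTypePointer[DType.float32],
    B: DTypePointer[DType.float32],
    rows: Int,
    cols: Int,
):
    # C (rows,) = B (rows, cols) @ A (cols,), the same convention as matmul in the diff.
    for io in range(0, rows, tile_i):  # one tile of output rows
        for i in range(tile_i):
            var acc = SIMD[DType.float32, nelts](0)
            for jo in range(0, cols, tile_j):  # one tile of input columns
                for j in range(jo, jo + tile_j, nelts):
                    acc += A.simd_load[nelts](j) * B.simd_load[nelts]((io + i) * cols + j)
            C.store(io + i, acc.reduce_add())
```

The diff below additionally distributes the `io` loop across workers, unrolls the in-tile loops at compile time, batches `n` weight matrices per call, and handles ragged tile boundaries.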
123 changes: 84 additions & 39 deletions llama2.mojo
@@ -1,8 +1,9 @@
from algorithm import sum
-from algorithm import vectorize, parallelize, unroll
+from algorithm import vectorize, parallelize, unroll, vectorize_unroll
+from algorithm import Static1DTileUnitFunc as Tile1DFunc
from builtin import string
from math import round
-from memory import memset_zero, memcpy
+from memory import memset_zero, memcpy, stack_allocation
from memory.buffer import Buffer
from memory.unsafe import DTypePointer
from random import rand
@@ -19,7 +20,7 @@ import time

var workers = 0

-alias nelts = (4*simdwidthof[DType.float32]())
+alias nelts = (4 * simdwidthof[DType.float32]())

alias PointerString = Pointer[UInt8]
alias BufferPtrType = DTypePointer[DType.uint8]
@@ -371,7 +372,6 @@ struct RunState:
var key_cache: TensorF32 # (layer, seq_len, dim)
var value_cache: TensorF32 # (layer, seq_len, dim)
-

fn __init__(inout self, config: Config) raises:
self.x = TensorF32(config.dim)
self.xb = TensorF32(config.dim)
@@ -449,10 +449,10 @@ fn read_file(file_name: String, inout buf: FileBuf) raises:
let cp_buf: BufferPtrType = BufferPtrType.alloc(cp_size)

let data_ptr = data._as_ptr().bitcast[DType.uint8]()

    for i in range(cp_size):
-        cp_buf.store(i,data_ptr.load(i))
+        cp_buf.store(i, data_ptr.load(i))

# don't free data
_ = data

@@ -561,6 +561,18 @@ fn softmax(inout x: TensorF32, start: Int, end: Int):
vectorize[nelts, _norm](end - start)


+@always_inline
+fn tile_parallel[tiled_fn: Tile1DFunc, tile: Int](end: Int):
+    fn row(i: Int):
+        let io = i * tile
+        tiled_fn[tile](io)
+
+    parallelize[row](end // tile, workers)
+
+    if end % tile != 0:
+        tiled_fn[tile](end - tile)
+
+
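One subtlety in the new `tile_parallel` helper: the parallel pass covers only the `end // tile` full tiles, and any remainder is handled by re-running the last full tile at offset `end - tile` once `parallelize` has returned. Rows near the end can therefore be computed twice, which is harmless here because `calc_tiles_row` writes each output deterministically, though it does assume `end >= tile`. A toy trace of the visited offsets (`show_tiles` is invented for illustration, not part of the PR):

```mojo
fn show_tiles(end: Int, tile: Int):
    for i in range(end // tile):
        print("parallel tile at offset", i * tile)  # full tiles, spread across workers
    if end % tile != 0:
        print("tail tile at offset", end - tile)  # last full tile, re-run serially

# show_tiles(10, 4) prints offsets 0 and 4, then 6:
# rows 6 and 7 fall in both the tile at offset 4 and the tail tile at offset 6.
```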
@always_inline
fn batch_matmul[
n: Int
@@ -571,45 +583,66 @@ fn batch_matmul[
    rows: Int,
    cols: Int,
):
+    alias nelts = simdwidthof[DType.float32]()
+
+    alias tile_j = 2**n * nelts
+    alias tile_i = 8 // n
+
+    alias stack_size = tile_i * nelts
+
    @parameter
-    fn compute_row(i: Int):
-        var tmp = StaticTuple[n, SIMD[DType.float32, nelts]]()
-        @parameter
-        fn init[k: Int]():
-            tmp[k] = SIMD[DType.float32, nelts](0)
-        unroll[n, init]()
-        let row_offset = i * cols
+    fn calc_tiles_row[tile_i: Int](io: Int):
+        var accumulator = StaticTuple[n, BufferPtrFloat32]()

        @parameter
-        fn dot[_nelts: Int](j: Int):
-            if _nelts < nelts:  # take care of tail array elements with length < nelts
-                let a = A.simd_load[_nelts](j)
+        fn _init[k: Int]():
+            accumulator[k] = stack_allocation[stack_size, DType.float32]()
+            memset_zero(accumulator[k], stack_size)
+
+        unroll[n, _init]()
+
+        @parameter
+        fn calc_cols[tile_j_unroll: Int](jo: Int, tile_j: Int):
+            @parameter
+            fn _batch[k: Int]():
                @parameter
-                fn _multiply_tail[k: Int]():
-                    tmp[k][0] += (
-                        a * B[k].simd_load[_nelts](row_offset + j)
-                    ).reduce_add()
+                fn calc_row[i: Int]():
+                    let row_offset_c = i * nelts
+                    let row_offset_b = (io + i) * cols

-                unroll[n, _multiply_tail]()
-            else:
-                let a = A.simd_load[nelts](j)
+                    @parameter
+                    fn calc_col[_nelts: Int](j: Int):
+                        accumulator[k].simd_store[_nelts](
+                            row_offset_c,
+                            accumulator[k].simd_load[_nelts](row_offset_c)
+                            + A.simd_load[_nelts](jo + j)
+                            * B[k].simd_load[_nelts](row_offset_b + jo + j),
+                        )

-                @parameter
-                fn _multiply[k: Int]():
-                    tmp[k] += a * B[k].simd_load[nelts](row_offset + j)
+                    vectorize_unroll[nelts, tile_j_unroll, calc_col](tile_j)
+
+                unroll[tile_i, calc_row]()

-                unroll[n, _multiply]()
+            unroll[n, _batch]()

-        vectorize[nelts, dot](cols)
+        for jo in range(0, cols - cols % tile_j, tile_j):
+            calc_cols[tile_j // nelts](jo, tile_j)
+
+        calc_cols[1](cols - cols % tile_j, cols % tile_j)

        @parameter
-        fn _reduce[k: Int]():
-            C[k].store(i, tmp[k].reduce_add())
+        fn _copy_values[k: Int]():
+            @parameter
+            fn _reduce[i: Int]():
+                C[k].store(
+                    io + i, accumulator[k].simd_load[nelts](i * nelts).reduce_add()
+                )
+
+            unroll[tile_i, _reduce]()

-        unroll[n, _reduce]()
+        unroll[n, _copy_values]()

-    parallelize[compute_row](rows, workers)
+    tile_parallel[calc_tiles_row, tile_i](rows)


@always_inline
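For a concrete feel of the tile sizes in the rewritten `batch_matmul`: the batch width `n` (the number of weight matrices handled per call) trades column-tile width against row-tile height, keeping the stack accumulator at `tile_i * nelts` floats per matrix. A runnable check of the alias arithmetic, assuming `simdwidthof[DType.float32]()` is 8 (AVX2; the real value is machine-dependent):

```mojo
from sys.info import simdwidthof

fn main():
    alias n = 1                                 # matrices handled per call
    alias nelts = simdwidthof[DType.float32]()  # assumed 8 below
    alias tile_j = 2**n * nelts                 # 2 * 8  = 16 columns per inner tile
    alias tile_i = 8 // n                       # 8 // 1 =  8 rows per tile
    alias stack_size = tile_i * nelts           # 8 * 8  = 64 floats (256 B) per matrix
    print(nelts, tile_j, tile_i, stack_size)
```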
@@ -643,7 +676,9 @@ fn matmul(C: TensorSlice, A: TensorF32, B: TensorSlice) raises:
# B (d,n) @ A (n,) -> C (d,)
matmul_dimension_checks(A.shape(), B.shape())
batch_matmul[1](
-        StaticTuple[1, BufferPtrFloat32](C.data(),),
+        StaticTuple[1, BufferPtrFloat32](
+            C.data(),
+        ),
A.data(),
StaticTuple[1, BufferPtrFloat32](B.data()),
B.dim(0),
@@ -671,8 +706,9 @@ fn rope_rotation_llama(
) -> None:
# stories model, llama2
    let head_size = config.head_size
+
    @parameter
-    fn head_loop(i:Int):
+    fn head_loop(i: Int):
# Simple vectorization with (head_size // 2) steps gave junk transformer output.
# Maybe because the nelt ranges end up overlapping between the steps.
for j in range(0, config.head_size, 2):
@@ -687,8 +723,8 @@ fn rope_rotation_llama(
let k1 = state.k[i * head_size + j + 1]
state.k[i * head_size + j] = k0 * fcr - k1 * fci
state.k[i * head_size + j + 1] = k0 * fci + k1 * fcr
-    parallelize[head_loop](config.n_heads, workers)

+    parallelize[head_loop](config.n_heads, workers)


@always_inline
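On the rotation itself: with `fcr = cos(θ)` and `fci = sin(θ)` for the position-dependent RoPE angle θ (as in the llama2.c reference), each adjacent pair is rotated by

$$\begin{pmatrix} q_0' \\ q_1' \end{pmatrix} = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix} \begin{pmatrix} q_0 \\ q_1 \end{pmatrix},$$

which is exactly `q0 * fcr - q1 * fci` and `q0 * fci + q1 * fcr` above; the key pairs get the same rotation. Each step reads and writes an interleaved `(even, odd)` pair, consistent with the comment about overlapping ranges when the loop is vectorized naively.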
@@ -755,7 +791,7 @@ fn transformer(

# Multihead attention. Iterate over all heads in parallel.
@parameter
-    fn loop_over_heads(h:Int):
+    fn loop_over_heads(h: Int):
# Get the query vector for this head
let q_offset = h * head_size

@@ -1020,8 +1056,17 @@ fn main() raises:
var tok = Tokenizer(config.vocab_size, tbuf)

# print the layers number and vocab size
print("checkpoint size: ", fbuf.size, "[", fbuf.size // 1024 // 1024, "MB ]",
"| n layers:", config.n_layers, "| vocab size:", tok.vocab_size)
print(
"checkpoint size: ",
fbuf.size,
"[",
fbuf.size // 1024 // 1024,
"MB ]",
"| n layers:",
config.n_layers,
"| vocab size:",
tok.vocab_size,
)

# Create and initialize the application RunState
var state = RunState(config)