-
Notifications
You must be signed in to change notification settings - Fork 27
add mamba causal-conv1d-update kernel #48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
thoangtrvn
wants to merge
2
commits into
pytorch-labs:main
Choose a base branch
from
thoangtrvn:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,9 @@ | |
from einops import rearrange | ||
from typing import Literal, Optional | ||
|
||
# Sentinel value marking a padded (unused) cache slot, mirrored from
# vllm/attention/backends/utils.py. When `conv_state_indices[i]` equals this
# value, the update kernel skips sequence i entirely (see USE_PAD_SLOT path).
PAD_SLOT_ID = -1
|
||
|
||
@triton.autotune( | ||
configs=[ | ||
|
@@ -328,3 +331,324 @@ def causal_conv1d_fn( | |
final_states_out, | ||
activation, | ||
) | ||
|
||
|
||
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 128}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_N": 64}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_N": 32}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_N": 16}, num_stages=3, num_warps=8),
    ],
    key=["dim"],
    # Autotuning re-runs the kernel; conv_state and x are mutated in place by
    # restore_value so every tuning trial starts from identical input data.
    restore_value=["conv_state_ptr", "x_ptr"],
)
@triton.jit()
def _causal_conv1d_update_kernel(
    # Triton kernel: incremental (decode-time) causal 1-D convolution.
    # For each sequence it convolves the new tokens in `x` with `w`, using
    # `conv_state` to supply the tokens that precede x[..., 0], writes the
    # result to `o`, and then shifts the last `state_len` tokens into
    # `conv_state` for the next call.
    # Grid: (batch, cdiv(dim, BLOCK_N)) — one program per sequence per
    # BLOCK_N-wide slice of features.
    #
    # Pointers to matrices
    x_ptr,  # (batch, dim, seqlen)
    w_ptr,  # (dim, width)
    bias_ptr,  # (dim,) — only read when HAS_BIAS
    conv_state_ptr,  # (num_cache_lines, dim, state_len); read and updated in place
    cache_seqlens_ptr,  # (batch,) — only read when IS_CIRCULAR_BUFFER
    conv_state_indices_ptr,  # (batch,) — only read when IS_CONTINUOUS_BATCHING
    o_ptr,  # (batch, dim, seqlen)
    # Matrix dimensions
    batch,
    dim,
    seqlen,
    state_len,
    num_cache_lines,  # conv_state.size(0); may exceed batch under continuous batching
    # Strides
    stride_x_seq,  # stride to get to next sequence,
    stride_x_dim,  # stride to get to next feature-value,
    stride_x_token,  # stride to get to next token (same feature-index, same sequence-index)
    stride_weight_dim,  # stride to get to next dim-axis value
    stride_weight_width,  # stride to get to next width-axis value
    stride_conv_state_seq,
    stride_conv_state_dim,
    stride_conv_state_tok,
    stride_o_seq,
    stride_o_dim,
    stride_o_token,
    # others
    pad_slot_id,  # sentinel in conv_state_indices marking rows to skip
    # Meta-parameters
    HAS_BIAS: tl.constexpr,
    KERNEL_WIDTH: tl.constexpr,
    SILU_ACTIVATION: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_CIRCULAR_BUFFER: tl.constexpr,  # circular-buffer paths are unimplemented (assert 0 below)
    NP2_STATELEN: tl.constexpr,  # next power of 2 >= state_len (tl.arange needs a power of 2)
    USE_PAD_SLOT: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    idx_seq = tl.program_id(0)
    if idx_seq >= batch:
        return

    # Feature indices handled by this program: [BLOCK_N,]
    idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

    w_base = w_ptr + (idx_feats * stride_weight_dim)
    if IS_CIRCULAR_BUFFER:
        cache_seqlen = tl.load(cache_seqlens_ptr + idx_seq)  # modulo later
    else:
        cache_seqlen = 0
    # store output data at the corresponding tokens (BLOCK_M of them) and feature-indices (BLOCK_N of them) in these tokens
    if IS_CONTINUOUS_BATCHING:
        # Indirect lookup: sequence i reads/writes conv_state row conv_state_indices[i].
        conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq)
    else:
        conv_state_batch_coord = idx_seq
    if USE_PAD_SLOT:
        if conv_state_batch_coord == pad_slot_id:
            # not processing
            return
    conv_state_base = (
        conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim)
    )  # [BLOCK_N,]

    # Phase 1: compute the convolution output one token at a time.
    for idx_token in range(seqlen):
        x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim)  # [BLOCK_N, ]

        if HAS_BIAS:
            bias = bias_ptr + idx_feats
            mask_bias = idx_feats < dim
            acc = tl.load(bias, mask=mask_bias, other=0.0).to(tl.float32)  # [BLOCK_N]
        else:
            acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
        PADDING_W = KERNEL_WIDTH - 1
        for j in range(KERNEL_WIDTH):
            # the token index to multiply with kernel[:, 0], given kernel with width-columns, i.e. kernel[:, 0..(width-1)]
            idx_x_w = j - PADDING_W + idx_token
            x_ptrs = x_base + (idx_x_w * stride_x_token)  # [BLOCK_N]
            mask_x = (idx_x_w >= 0) & (idx_x_w < seqlen) & (idx_feats < dim)
            if IS_CIRCULAR_BUFFER:
                assert 0  # TUAN TODO: double check the logic - it seems correct
                conv_state_ptrs = (
                    conv_state_base + (((idx_x_w + cache_seqlen) % state_len) * stride_conv_state_tok)[:, None]
                )  # [BLOCK_M, BLOCK_N]
            else:
                # idx_x_w < 0 here addresses the tail of conv_state:
                # token -1 maps to conv_state[..., state_len - 1], etc.
                conv_state_ptrs = conv_state_base + ((idx_x_w + state_len) * stride_conv_state_tok)  # [BLOCK_N]
            mask_w = (conv_state_batch_coord < num_cache_lines) & (idx_x_w < 0) & (idx_feats < dim)
            conv_state = tl.load(conv_state_ptrs, mask_w, 0.0)
            # In-bounds token indices read from x; negative ones fall back to
            # the conv_state value loaded above (used as `other`).
            matrix_x = tl.load(x_ptrs, mask=mask_x, other=conv_state)

            w_ptrs = w_base + (j * stride_weight_width)  # [BLOCK_N] tensor
            mask_w = idx_feats < dim
            matrix_w = tl.load(w_ptrs, mask_w, other=0.0)
            acc += matrix_x * matrix_w  # [BLOCK_N]

        if SILU_ACTIVATION:
            # silu(x) = x * sigmoid(x)
            acc = acc / (1 + tl.exp(-acc))
        mask = (idx_token < seqlen) & (idx_feats < dim)  # sequence-index # token-index # feature-index
        o_ptrs = o_ptr + (idx_seq * stride_o_seq) + (idx_token * stride_o_token) + (idx_feats * stride_o_dim)
        tl.store(o_ptrs, acc, mask=mask)

    # Phase 2: roll conv_state left by `seqlen` and append the new tokens so
    # the cache again holds the most recent `state_len` tokens.
    if IS_CIRCULAR_BUFFER:
        # TODO:
        assert 0
    else:
        idx_tokens = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]

        # Source: conv_state positions shifted forward by seqlen.
        conv_state_ptrs_source = (
            conv_state_ptr
            + (conv_state_batch_coord * stride_conv_state_seq)
            + (idx_feats * stride_conv_state_dim)[None, :]
            + ((idx_tokens + seqlen) * stride_conv_state_tok)[:, None]
        )  # [BLOCK_M, BLOCK_N]
        mask = (
            (conv_state_batch_coord < num_cache_lines)
            & ((idx_tokens + seqlen) < state_len)[:, None]
            & (idx_feats < dim)[None, :]
        )
        conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)

        # Positions state_len-seqlen .. state_len-1 are filled from x instead.
        VAL = state_len - seqlen
        x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim)[None, :]  # [1, BLOCK_N]

        x_ptrs = x_base + ((idx_tokens - VAL) * stride_x_token)[:, None]  # [BLOCK_M, BLOCK_N]

        mask_x = (
            (idx_tokens - VAL >= 0)[:, None] & (idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :]
        )  # token-index # token-index # feature-index
        loaded_x = tl.load(x_ptrs, mask_x, 0.0)
        # Ensure all reads of the old conv_state complete before it is
        # overwritten below (source and target ranges overlap).
        tl.debug_barrier()

        new_conv_state = tl.where(mask, conv_state, loaded_x)
        conv_state_ptrs_target = conv_state_base + (idx_tokens * stride_conv_state_tok)[:, None]  # [BLOCK_M, BLOCK_N]
        mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
        tl.store(conv_state_ptrs_target, new_conv_state, mask)
|
||
|
||
def causal_conv1d_update(
    x,
    conv_state,
    weight,
    bias=None,
    activation: Optional[Literal["silu", "swish"]] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
    conv_state_indices: Optional[torch.Tensor] = None,
    pad_slot_id: Optional[int] = None,
):
    """Incremental causal 1-D convolution over new tokens, updating the cache.

    Computes the causal conv1d of the new tokens in `x`, using `conv_state`
    to supply the tokens preceding x[..., 0], then shifts `x` into
    `conv_state` in place so it again holds the most recent `state_len`
    tokens.

    Args:
        x: (batch, dim) or (batch, dim, seqlen)
            New tokens whose causal-conv-1d output is to be computed.
        conv_state: (num_cache_lines, dim, state_len), state_len >= width - 1
            Holds the previous `state_len` tokens (the `init_state` of the
            `causal_conv1d_fn` API). Updated in place. The cache row paired
            with sequence i is:
                j = i                        if conv_state_indices is None
                j = conv_state_indices[i]    otherwise (continuous batching)
        weight: (dim, width)
            The (causal) 1-d conv kernel.
        bias: (dim,) or None
        cache_seqlens: (batch,), dtype int32, or None.
            If not None, `conv_state` is treated as a circular buffer and is
            updated by copying `x` into it starting at index
            cache_seqlens % state_len. NOT YET SUPPORTED (asserted below).
        conv_state_indices: (batch,), dtype int32, or None.
            If present, selects the row of `conv_state` used with sequence
            x[i] (see mapping above). Useful for continuous batching.
        pad_slot_id: int or None
            If given, rows with conv_state_indices[i] == pad_slot_id are
            skipped entirely (padded slots).

    Returns:
        out: (batch, dim) or (batch, dim, seqlen), matching the rank of `x`.
    """
    unsqueeze = x.dim() == 2
    if unsqueeze:
        # make it (batch, dim, seqlen) with seqlen == 1
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    _, width = weight.shape

    # conv_state: (..., dim, state_len), where state_len >= width - 1
    state_len = conv_state.size(2)
    assert state_len >= width - 1
    assert dim == conv_state.size(1)
    if conv_state_indices is None:
        assert conv_state.size(0) >= batch
    else:
        assert (batch,) == conv_state_indices.shape
    num_cache_lines = conv_state.size(0)

    stride_w_dim = weight.stride(0)
    stride_w_width = weight.stride(1)

    def grid(META):
        # One program per sequence per BLOCK_N-wide slice of features.
        return (
            batch,
            triton.cdiv(dim, META["BLOCK_N"]),
        )

    # Circular-buffer mode is not implemented in the kernel yet (not needed
    # for vLLM).
    assert cache_seqlens is None
    out = torch.empty_like(x)
    with torch.cuda.device(x.device.index):
        _causal_conv1d_update_kernel[grid](
            # Pointers to matrices
            x,
            weight,
            bias,
            conv_state,
            cache_seqlens,
            conv_state_indices,
            out,
            # Matrix dimensions
            batch,
            dim,
            seqlen,
            state_len,
            num_cache_lines,
            # stride
            x.stride(0),  # X (batch, dim, seqlen)
            x.stride(1),
            x.stride(2),
            stride_w_dim,
            stride_w_width,
            conv_state.stride(0),
            conv_state.stride(1),
            conv_state.stride(2),
            out.stride(0),
            out.stride(1),
            out.stride(2),
            # others
            pad_slot_id,
            # META
            HAS_BIAS=bias is not None,
            KERNEL_WIDTH=width,
            SILU_ACTIVATION=activation in ["silu", "swish"],
            IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
            IS_CIRCULAR_BUFFER=cache_seqlens is not None,
            NP2_STATELEN=triton.next_power_of_2(state_len),
            USE_PAD_SLOT=pad_slot_id is not None,
        )
    if unsqueeze:
        out = out.squeeze(-1)
    return out
|
||
|
||
def causal_conv1d_update_vllm(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[Literal["silu", "swish"]] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
    conv_state_indices: Optional[torch.Tensor] = None,
    pad_slot_id: int = PAD_SLOT_ID,
):
    """vLLM-signature wrapper around `causal_conv1d_update`.

    Args:
        x: (batch, dim) for single-token prediction, or (batch, dim, seqlen)
            for multi-token prediction.
        conv_state: (batch, dim, state_len), with state_len >= width - 1.
        weight: (dim, width)
        bias: (dim,) or None
        cache_seqlens: (batch,), dtype int32, or None.
            If not None, conv_state is treated as a circular buffer and is
            updated by copying x into it starting at index
            cache_seqlens % state_len. NOT YET SUPPORTED (asserted below).
        conv_state_indices: (batch,), dtype int32, or None.
            If not None, conv_state is a larger tensor along the batch dim
            and the batch coords given by conv_state_indices are selected.
            Useful for continuous batching.
        pad_slot_id: int
            Marks padded entries in conv_state_indices that must be skipped,
            e.g. with cache indices [pad_slot_id, 1, 20, pad_slot_id] the
            kernel leaves entries 0 and 3 untouched.

    Returns:
        out: (batch, dim) or (batch, dim, seqlen), matching the rank of x.
    """
    # Circular-buffer mode is not supported by the underlying kernel yet.
    assert cache_seqlens is None
    # TODO: adopt vLLM's strategy of overwriting 'x' in place instead of
    # allocating a new output tensor.
    out = causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias=bias,
        activation=activation,
        cache_seqlens=cache_seqlens,
        conv_state_indices=conv_state_indices,
        pad_slot_id=pad_slot_id,
    )
    return out
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer comment: nit, but not a fan of using `o` by itself... `out` or `output` etc. makes it more clear imo. |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you are returning o but it is not listed here in your function signature?