Implement cached inference #409
Conversation
Wow!
@borzunov This is tremendous Alexander! You made it fit with so many of the other features in the repository too. Truly a wizard!
Thank you :)
This is awesome but I'm having issues loading old models. It's caused by the extra layers from the cached classes.
@EmaadKhwaja That's true, sorry for the inconvenience. You can use code like this to load older checkpoints:

```python
import torch
from collections import OrderedDict

# Rename the attention parameters to match the new (wrapped) module hierarchy
state_dict = torch.load('model_state.pt')
state_dict = OrderedDict([(key.replace('to_qkv', 'fn.fn.to_qkv').replace('to_out', 'fn.fn.to_out'), value)
                          for key, value in state_dict.items()])
print(model.load_state_dict(state_dict))
```

Other mismatches (if any) can be fixed in the same way.
Note: This PR is based on the branch from #408, and the diff shows changes from both PRs by default. See this diff to keep only the changes related to cached inference.
Description
This PR optimizes inference, reducing its time complexity from O(n^3) to O(n^2), where n is the sequence length.
When someone runs generate_images(..., use_cache=True):

- The text is processed as before, but the outputs of all Attention and PreShiftToken layers are cached.
- The image is generated token-by-token. For each token, we only pass the last token through the network (as if seq_len = 1). The layers involving token interaction (Attention and PreShiftToken) look up the necessary info (e.g., the keys and values of previous tokens) in the cache. A simplified sketch of this key/value caching is shown below.

This implementation should be more efficient than ai-forever/ru-dalle#12 since it doesn't concat the outputs of FFN layers and doesn't create excess tensors with a seq_len dimension.
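For readers less familiar with this trick, here is a minimal sketch of key/value caching in a single-head causal self-attention layer. It is written from scratch for illustration and is not the repository's Attention class; the class name, cache layout, and the toy usage at the bottom are all assumptions. It shows why each generated token costs O(n) instead of O(n^2): the new token's query only attends to the keys and values accumulated in the cache.

```python
import torch
from torch import nn

class CachedCausalAttention(nn.Module):
    """Single-head causal self-attention with a key/value cache (illustrative only)."""

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** -0.5
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, cache=None):
        # x: (batch, seq_len, dim); during cached generation seq_len == 1
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        if cache is not None:
            if 'k' in cache:                          # reuse keys/values of all previous tokens
                k = torch.cat([cache['k'], k], dim=1)
                v = torch.cat([cache['v'], v], dim=1)
            cache['k'], cache['v'] = k, v             # remember them for the next step
        attn = q @ k.transpose(-2, -1) * self.scale
        if q.shape[1] == k.shape[1]:                  # full (prompt) pass: apply the causal mask
            mask = torch.triu(torch.ones(q.shape[1], k.shape[1], device=x.device), 1).bool()
            attn = attn.masked_fill(mask, float('-inf'))
        return self.to_out(attn.softmax(dim=-1) @ v)

# Toy usage: process the "prompt" once, then feed one new token at a time.
layer, cache = CachedCausalAttention(dim=64), {}
out = layer(torch.randn(1, 10, 64), cache)   # fills the cache (seq_len = 10)
out = layer(torch.randn(1, 1, 64), cache)    # new token: O(n) work instead of O(n^2)
```

The PR does the analogous bookkeeping for every Attention and PreShiftToken layer when generate_images(..., use_cache=True) is called.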
Supporting sparse attention

It's cumbersome to implement caching in sparse attention layers efficiently, so they are not cached by default. If you'd like to use cached inference with sparse attention layers, you can load their weights into a full attention layer and use an appropriate mask to simulate the sparse layer (a toy example of such a mask is sketched below).

The sparse implementation (like SparseAxialCausalAttention(...)) is faster for training, but the masked implementation (like Attention(..., static_mask=...)) is much faster for inference (thanks to caching). To enable the masked implementation, create the model as DALLE(..., optimize_for_inference=True).
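To make the "appropriate mask" idea concrete, here is a small illustrative sketch. The helper below is not part of the repository (its name and the exact sparsity pattern are assumptions); it just builds a boolean mask that restricts a dense causal attention layer to one axis of the image-token grid, which is the kind of static mask that can stand in for an axial sparse layer.

```python
import torch

def axial_causal_static_mask(grid_size: int, axis: int) -> torch.Tensor:
    """Boolean mask (True = may attend) restricting dense causal attention over a
    grid_size x grid_size image-token grid to one axis: axis=0 keeps attention
    within each row, axis=1 within each column. Illustrative only."""
    n = grid_size * grid_size
    idx = torch.arange(n)
    rows, cols = idx // grid_size, idx % grid_size
    line = rows if axis == 0 else cols
    same_line = line.unsqueeze(1) == line.unsqueeze(0)   # pairs on the same row/column
    causal = idx.unsqueeze(1) >= idx.unsqueeze(0)        # attend only to earlier (or same) positions
    return same_line & causal

mask = axial_causal_static_mask(grid_size=32, axis=1)    # 32x32 grid -> (1024, 1024) mask
print(mask.shape)
```

A mask like this could then be supplied to a dense layer in the spirit of Attention(..., static_mask=...) while loading the weights trained with the sparse layer.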
Experiments
In the "Training Transformers Together" demo, we train a model created as DALLE(..., optimize_for_inference=False) (see details on the model config in Add options for weight sharing #408). In our Colab for inference, we load its latest checkpoint into a model created as DALLE(..., optimize_for_inference=True) (actually, there's an older param name, but it doesn't matter).

Next, running generate_images(..., use_cache=True) gives a >9x speed-up compared to the inference time without the code from this PR. The results seem to be equivalent.
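For orientation, here is a rough, hedged sketch of how such a comparison could be reproduced. Everything here is a placeholder rather than the demo's actual setup: the DiscreteVAE stand-in, the model config, and the commented-out checkpoint path are illustrative, and the constructor arguments follow the dalle_pytorch README, so they may need adjusting for your installed version. Only optimize_for_inference and use_cache come from this PR.

```python
import time

import torch
from dalle_pytorch import DALLE, DiscreteVAE

# Placeholder VAE and model config -- NOT the configuration used in the demo
vae = DiscreteVAE(image_size=256, num_layers=3, num_tokens=8192,
                  codebook_dim=512, hidden_dim=64)
dalle = DALLE(dim=512, vae=vae, num_text_tokens=10000, text_seq_len=128,
              depth=4, heads=8, optimize_for_inference=True)
# dalle.load_state_dict(torch.load('checkpoint.pt'))  # hypothetical checkpoint path

text = torch.randint(1, 10000, (1, 128))  # a fake batch of tokenized captions

for use_cache in (False, True):
    start = time.time()
    with torch.no_grad():
        dalle.generate_images(text, use_cache=use_cache)
    print(f'use_cache={use_cache}: {time.time() - start:.1f} s')
```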