
Implement cached inference #409

Merged
34 commits merged into lucidrains:main on Jan 11, 2022

Conversation

@borzunov (Contributor) commented on Jan 10, 2022

Note: This PR is based on the branch from #408, so the diff shows changes from both PRs by default. See this diff to view only the changes related to cached inference.

Description

This PR optimizes inference, reducing its time complexity from O(n^3) to O(n^2), where n is the sequence length.

When someone runs generate_images(..., use_cache=True):

  1. The text is processed as before, but the outputs of all Attention and PreShiftToken layers are cached.

  2. The image is generated token-by-token. For each token, we only pass the last token through the network (as if seq_len = 1). The layers involving token interaction (Attention and PreShiftToken) look up the necessary information (e.g., the keys and values of previous tokens) in the cache.

This implementation should be more efficient than ai-forever/ru-dalle#12, since it doesn't concatenate the outputs of FFN layers and doesn't create excess tensors with a seq_len dimension.
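
For illustration, the core of the key/value caching described above looks roughly like this sketch (a hypothetical single-head cached_attention helper with externally supplied to_qkv/to_out projections, not the actual layer code from this PR):

import torch

def cached_attention(x, to_qkv, to_out, cache):
    # x: (batch, new_tokens, dim). During cached generation, new_tokens == 1:
    # only the last token is projected, while the keys/values of earlier
    # tokens are taken from the cache instead of being recomputed.
    q, k, v = to_qkv(x).chunk(3, dim=-1)
    if 'k' in cache:
        k = torch.cat([cache['k'], k], dim=1)
        v = torch.cat([cache['v'], v], dim=1)
    cache['k'], cache['v'] = k, v
    # Attending to every cached position is already causal when tokens are
    # generated one at a time (masking and multiple heads omitted for brevity).
    attn = (q @ k.transpose(-1, -2) / k.shape[-1] ** 0.5).softmax(dim=-1)
    return to_out(attn @ v)

The cache dict is created once per generated sample and reused across steps, which is what keeps the per-token cost linear in the sequence length.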

Supporting sparse attention

It's cumbersome to implement caching efficiently in sparse attention layers, so they are not cached by default. If you'd like to use cached inference with sparse attention layers, you can load their weights into a full attention layer and use an appropriate mask to simulate the sparse layer.

The sparse implementation (like SparseAxialCausalAttention(...)) is faster for training, but the masked implementation (like Attention(..., static_mask=...)) is much faster for inference (thanks to caching).

To enable the masked implementation, create the model as DALLE(..., optimize_for_inference=True).
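
As an illustration of the "appropriate mask" idea, a row-wise axial sparse pattern could be approximated with a static boolean mask along these lines (the exact visibility pattern here is an assumption and must match whichever sparse layer is being emulated):

import torch

def axial_row_static_mask(text_len, image_fmap_size):
    # Image tokens see all text tokens plus the image tokens in their own row;
    # the causal constraint is applied on top.
    img_len = image_fmap_size ** 2
    seq_len = text_len + img_len
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    mask[:, :text_len] = True  # text positions are always visible
    for i in range(img_len):
        row_start = text_len + (i // image_fmap_size) * image_fmap_size
        mask[text_len + i, row_start:text_len + i + 1] = True
    return mask & torch.ones(seq_len, seq_len, dtype=torch.bool).tril()

Since the masked full-attention layer uses the same projection weights as the sparse layer, the sparse checkpoint can then be loaded into the masked model directly.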

Experiments

  1. In the "Training Transformers Together" demo, we train a model created as DALLE(..., optimize_for_inference=False) (see details on model config in Add options for weight sharing #408).

  2. In our Colab for inference, we load its latest checkpoint to a model created as DALLE(..., optimize_for_inference=True) (actually, there's an older param name, but it doesn't matter).

  3. Next, running generate_images(..., use_cache=True) gives a >9x speed-up compared to the inference time without the code from this PR. The results appear to be equivalent.
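
A hedged end-to-end sketch of steps 2-3 (all hyperparameters below are placeholders; in practice the config must match the training run from #408, with only optimize_for_inference flipped to True):

import torch
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(image_size=64, num_layers=3, num_tokens=512, codebook_dim=256)

dalle = DALLE(
    dim=256,
    vae=vae,
    num_text_tokens=10000,
    text_seq_len=64,
    depth=2,
    heads=4,
    optimize_for_inference=True,  # masked full attention, so caching applies everywhere
)

# dalle.load_state_dict(torch.load('dalle_checkpoint.pt'))  # checkpoint path is a placeholder

text = torch.randint(0, 10000, (1, 64))  # dummy text tokens for the sketch
images = dalle.generate_images(text, use_cache=True)  # cached, roughly O(n^2) generation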

@borzunov mentioned this pull request on Jan 10, 2022

@afiaka87 (Contributor) commented on Jan 11, 2022

This PR optimizes inference, reducing its time complexity from O(n^3) to O(n^2), where n is the sequence length.

Wow!

@lucidrains (Owner) commented on Jan 11, 2022

@borzunov This is tremendous Alexander! You made it fit with so many of the other features in the repository too. Truly a wizard!

@lucidrains merged commit 2094474 into lucidrains:main on Jan 11, 2022

@borzunov (Contributor, Author) commented

Thank you :)

@EmaadKhwaja commented

This is awesome, but I'm having issues loading old models. It's caused by the extra layers from the cached classes,
e.g. transformer.layers.layers.0.0.fn.fn.fn.to_qkv.weight became transformer.layers.layers.0.0.fn.fn.fn.fn.fn.to_qkv.weight

@borzunov (Contributor, Author) commented

@EmaadKhwaja That's true, sorry for the inconvenience. You can use code like this to load older models:

import torch
from collections import OrderedDict

# Insert the extra 'fn.fn.' added by the caching wrappers into the old keys.
state_dict = torch.load('model_state.pt')
state_dict = OrderedDict([(key.replace('to_qkv', 'fn.fn.to_qkv').replace('to_out', 'fn.fn.to_out'), value)
                          for key, value in state_dict.items()])
print(model.load_state_dict(state_dict))

Other mismatches (if any) can be fixed in the same way.
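
If you're not sure which keys still disagree, loading with strict=False first and printing the returned key lists is a generic PyTorch way to find them:

result = model.load_state_dict(state_dict, strict=False)
print('missing keys:', result.missing_keys)
print('unexpected keys:', result.unexpected_keys)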
