
[V1] Use Triton(ROCm) Attention backend as fallback for Turing GPUs #14071

Open
wants to merge 5 commits into main
Conversation

@Isotr0py (Collaborator) commented Mar 1, 2025

Related issue: #12724

  • Rename the v1 ROCmAttention backend to TritonAttention and add a fallback for Turing GPUs.
  • Since the v1 ROCm attention backend is implemented in Triton, it can serve as a fallback for older GPUs that don't support FlashAttention (see the sketch below).
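A minimal sketch of the capability check this fallback relies on. The helper name and backend paths below are illustrative assumptions, not the actual vLLM code changed by this PR; the key point is that FlashAttention needs compute capability 8.0 (Ampere) or newer, so Turing (SM 7.5) falls through to the Triton-based backend.

```python
# Illustrative sketch only: the function name and backend class paths are
# assumptions, not the actual vLLM implementation touched by this PR.
import logging

import torch

logger = logging.getLogger(__name__)

# FlashAttention requires Ampere (SM 8.0) or newer; Turing is SM 7.5.
FLASH_ATTN_MIN_CAPABILITY = (8, 0)


def select_v1_attn_backend() -> str:
    """Return the dotted path of the V1 attention backend for the current GPU."""
    capability = torch.cuda.get_device_capability()
    if capability >= FLASH_ATTN_MIN_CAPABILITY:
        logger.info("Using Flash Attention backend on V1 engine.")
        return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
    # Turing and older: fall back to the Triton-based backend that was
    # previously registered only as the ROCm backend.
    logger.info("Cannot use Flash Attention backend for Turing GPUs.")
    logger.info("Using Triton backend on V1 engine.")
    return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
```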

Tested on 2x T4 GPUs (-tp 2):
$ VLLM_USE_V1=1 python examples/offline_inference/basic/generate.py --model Qwen/Qwen2.5-3B-Instruct --dtype half --enforce-eager --max-num-seqs 1 -tp 2
INFO 03-01 03:23:53 [__init__.py:207] Automatically detected platform cuda.
WARNING 03-01 03:23:55 [arg_utils.py:1410] Setting max_num_batched_tokens to 8192 for LLM_CLASS usage context.
WARNING 03-01 03:23:56 [config.py:2552] Casting torch.bfloat16 to torch.float16.
INFO 03-01 03:24:05 [config.py:575] This model supports multiple tasks: {'reward', 'classify', 'generate', 'score', 'embed'}. Defaulting to 'generate'.
INFO 03-01 03:24:05 [config.py:1485] Defaulting to use mp for distributed inference
INFO 03-01 03:24:05 [config.py:1660] Chunked prefill is enabled with max_num_batched_tokens=8192.
WARNING 03-01 03:24:05 [cuda.py:95] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
INFO 03-01 03:24:05 [core.py:50] Initializing a V1 LLM engine (v0.1.dev4784+g18e5059) with config: model='Qwen/Qwen2.5-3B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2.5-3B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=Qwen/Qwen2.5-3B-Instruct, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"compile_sizes":[],"cudagraph_capture_sizes":[],"max_capture_size":0}
WARNING 03-01 03:24:05 [multiproc_worker_utils.py:309] Reducing Torch parallelism from 2 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 03-01 03:24:05 [custom_cache_manager.py:19] Setting Triton cache manager to: vllm.triton_utils.custom_cache_manager:CustomCacheManager
INFO 03-01 03:24:05 [shm_broadcast.py:258] vLLM message queue communication handle: Handle(local_reader_ranks=[0, 1], buffer_handle=(2, 10485760, 10, 'psm_1e082990'), local_subscribe_addr='ipc:///tmp/f7491eb8-ecd9-486b-bf1a-dd08898e8aee', remote_subscribe_addr=None, remote_addr_ipv6=False)
WARNING 03-01 03:24:06 [utils.py:2298] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x78579a1cc0b0>
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:06 [shm_broadcast.py:258] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_be1dd2a6'), local_subscribe_addr='ipc:///tmp/35d2b250-45a4-411a-a7bd-c25874c79c67', remote_subscribe_addr=None, remote_addr_ipv6=False)
WARNING 03-01 03:24:06 [utils.py:2298] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x78579a1d0950>
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:06 [shm_broadcast.py:258] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_7b4b9014'), local_subscribe_addr='ipc:///tmp/c981e423-630d-4330-b6a4-201ee387fccf', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:07 [utils.py:939] Found nccl from library libnccl.so.2
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [utils.py:939] Found nccl from library libnccl.so.2
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:07 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [custom_all_reduce_utils.py:244] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:07 [custom_all_reduce_utils.py:244] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [shm_broadcast.py:258] vLLM message queue communication handle: Handle(local_reader_ranks=[1], buffer_handle=(1, 4194304, 6, 'psm_4f618998'), local_subscribe_addr='ipc:///tmp/49b66037-1022-4554-aee6-b094b86ce383', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [parallel_state.py:948] rank 0 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 0
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:07 [parallel_state.py:948] rank 1 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 1
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [cuda.py:202] Cannot use Flash Attention backend for Turing GPUs.
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [cuda.py:204] Using Triton backend on V1 engine.
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:07 [cuda.py:202] Cannot use Flash Attention backend for Turing GPUs.
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:07 [cuda.py:204] Using Triton backend on V1 engine.
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [gpu_model_runner.py:1054] Starting to load model Qwen/Qwen2.5-3B-Instruct...
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:07 [gpu_model_runner.py:1054] Starting to load model Qwen/Qwen2.5-3B-Instruct...
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:07 [cuda.py:202] Cannot use Flash Attention backend for Turing GPUs.
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [cuda.py:202] Cannot use Flash Attention backend for Turing GPUs.
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:07 [cuda.py:204] Using Triton backend on V1 engine.
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [cuda.py:204] Using Triton backend on V1 engine.
(VllmWorker rank=1 pid=42647) WARNING 03-01 03:24:07 [topk_topp_sampler.py:46] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
(VllmWorker rank=0 pid=42605) WARNING 03-01 03:24:07 [topk_topp_sampler.py:46] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:08 [weight_utils.py:254] Using model weights format ['*.safetensors']
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:08 [weight_utils.py:254] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  2.00s/it]
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:12 [gpu_model_runner.py:1066] Loading model weights took 2.9348 GB and 4.151953 seconds
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:03<00:00,  1.45s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:03<00:00,  1.53s/it]
(VllmWorker rank=0 pid=42605) 
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:12 [gpu_model_runner.py:1066] Loading model weights took 2.9348 GB and 4.580717 seconds
INFO 03-01 03:24:18 [kv_cache_utils.py:524] GPU KV cache size: 559,584 tokens
INFO 03-01 03:24:18 [kv_cache_utils.py:527] Maximum concurrency for 32,768 tokens per request: 17.08x
INFO 03-01 03:24:18 [kv_cache_utils.py:524] GPU KV cache size: 559,584 tokens
INFO 03-01 03:24:18 [kv_cache_utils.py:527] Maximum concurrency for 32,768 tokens per request: 17.08x
INFO 03-01 03:24:18 [core.py:116] init engine (profile, create kv cache, warmup model) took 5.81 seconds
Processed prompts: 100%|███████████████████████████████████████████████████| 4/4 [00:06<00:00,  1.67s/it, est. speed input: 3.29 toks/s, output: 9.56 toks/s]
Prompt: 'Hello, my name is', Generated text: " Josh and I'm here to personally walk you through the process of enabling swap with"
Prompt: 'The president of the United States is', Generated text: ' a very important and respected position in the country. A pop quiz came up a'
Prompt: 'The capital of France is', Generated text: ' Paris across the Seine. Paris itself is divided into 20 districts,'
Prompt: 'The future of AI is', Generated text: ' heavily linked to quantum computing, but it will first need to become better at handling'
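
For reference, the same smoke test can be run through the Python API instead of the example script. This is a minimal sketch mirroring the command above (argument values are taken from that command, not from this PR's code); `VLLM_USE_V1` must be set before vLLM is imported.

```python
# Minimal sketch mirroring the command-line run above; not part of this PR.
import os

os.environ["VLLM_USE_V1"] = "1"  # select the V1 engine before importing vLLM

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-3B-Instruct",
    dtype="half",
    enforce_eager=True,
    tensor_parallel_size=2,
    max_num_seqs=1,
)

outputs = llm.generate(
    ["Hello, my name is", "The capital of France is"],
    SamplingParams(temperature=0.8, max_tokens=16),
)
for output in outputs:
    print(output.prompt, "->", output.outputs[0].text)
```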


github-actions bot commented Mar 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Mar 1, 2025
Signed-off-by: Isotr0py <[email protected]>
@DarkLight1337 DarkLight1337 added the rocm Related to AMD ROCm label Mar 1, 2025

mergify bot commented Mar 1, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Isotr0py.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 1, 2025
@mergify mergify bot removed the needs-rebase label Mar 1, 2025