enable finegrained_fp8 and granite_speech cases on XPU #38036

Merged · 9 commits · May 14, 2025
12 changes: 7 additions & 5 deletions src/transformers/integrations/finegrained_fp8.py
@@ -15,7 +15,7 @@

from typing import List, Optional, Tuple

-from ..utils import is_accelerate_available, is_torch_available, logging
+from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging


if is_torch_available():
@@ -332,8 +332,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.weight.element_size() > 1:
return F.linear(input, self.weight, self.bias)
else:
-# Context manager used to switch among the available cuda devices
-with torch.cuda.device(input.device):
+# Context manager used to switch among the available accelerators
+device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
+torch_accelerator_module = getattr(torch, device_type, torch.cuda)
+with torch_accelerator_module.device(input.device):
qinput, scale = act_quant(input, self.block_size[1])
output = w8a8_block_fp8_matmul_triton(
qinput,
@@ -343,9 +345,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
self.block_size,
output_dtype=input.dtype,
)
-# Blocks the CPU until all CUDA operations on the specified device are complete. It is used to ensure that the results of the
+# Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
# preceding operations are ready before proceeding
-torch.cuda.synchronize()
+torch_accelerator_module.synchronize()
if self.bias is not None:
output = output + self.bias
return output.to(dtype=input.dtype)
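
For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the device dispatch used above (a sketch only: `get_accelerator` is a hypothetical helper, `torch.accelerator` is assumed to exist from PyTorch 2.6 onwards, and `is_torch_accelerator_available` in the diff plays the role of the `hasattr` check):

```python
import torch

def get_accelerator():
    # Pick the device type and the matching torch backend module (torch.cuda, torch.xpu, ...).
    device_type = "cuda"
    if hasattr(torch, "accelerator") and torch.accelerator.current_accelerator() is not None:
        device_type = torch.accelerator.current_accelerator().type  # e.g. "cuda" or "xpu"
    return device_type, getattr(torch, device_type, torch.cuda)

# Device-agnostic "switch device, run kernel, wait for completion", as in FP8Linear.forward above.
device_type, accel = get_accelerator()
# Only CUDA and XPU expose the exact device()/synchronize() surface used here.
if device_type in ("cuda", "xpu") and accel.is_available():
    x = torch.randn(64, 64, device=device_type)
    with accel.device(x.device):  # same role as torch.cuda.device(...)
        y = x @ x
    accel.synchronize()           # same role as torch.cuda.synchronize()
```
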
23 changes: 12 additions & 11 deletions src/transformers/quantizers/quantizer_finegrained_fp8.py
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional

-from ..utils import is_accelerate_available, is_torch_available, logging
+from ..utils import is_accelerate_available, is_torch_available, is_torch_xpu_available, logging
from .base import HfQuantizer
from .quantizers_utils import get_module_from_name

@@ -44,16 +44,17 @@ def validate_environment(self, *args, **kwargs):
"please make sure the weights are in PyTorch format."
)

-if not torch.cuda.is_available():
-    raise RuntimeError("No GPU found. A GPU is needed for FP8 quantization.")
+if not (torch.cuda.is_available() or is_torch_xpu_available()):
+    raise RuntimeError("No GPU or XPU found. A GPU or XPU is needed for FP8 quantization.")

-compute_capability = torch.cuda.get_device_capability()
-major, minor = compute_capability
-if (major < 8) or (major == 8 and minor < 9):
-    raise ValueError(
-        "FP8 quantized models is only supported on GPUs with compute capability >= 8.9 (e.g 4090/H100)"
-        f", actual = `{major}.{minor}`"
-    )
+if torch.cuda.is_available():
+    compute_capability = torch.cuda.get_device_capability()
+    major, minor = compute_capability
+    if (major < 8) or (major == 8 and minor < 9):
+        raise ValueError(
+            "FP8 quantized models is only supported on GPUs with compute capability >= 8.9 (e.g 4090/H100)"
+            f", actual = `{major}.{minor}`"
+        )

device_map = kwargs.get("device_map", None)
if device_map is None:
@@ -217,7 +218,7 @@ def update_tp_plan(self, config):

config.base_model_tp_plan = text_plan

-return config
+return config

def is_serializable(self, safe_serialization=None):
return True
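
Read in isolation, the relaxed environment check amounts to roughly the following (a sketch, not the actual quantizer code; `check_fp8_environment` is a hypothetical name, while `is_torch_xpu_available` does come from `transformers.utils`):

```python
import torch
from transformers.utils import is_torch_xpu_available

def check_fp8_environment():
    # Accept either a CUDA GPU or an Intel XPU ...
    if not (torch.cuda.is_available() or is_torch_xpu_available()):
        raise RuntimeError("No GPU or XPU found. A GPU or XPU is needed for FP8 quantization.")
    # ... and only apply the compute-capability gate when CUDA is actually present.
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        if (major < 8) or (major == 8 and minor < 9):
            raise ValueError(f"FP8 quantization needs compute capability >= 8.9, got {major}.{minor}")
```
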
7 changes: 4 additions & 3 deletions tests/models/granite_speech/test_processor_granite_speech.py
@@ -23,8 +23,9 @@
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers.testing_utils import (
require_torch,
-require_torch_gpu,
+require_torch_accelerator,
require_torchaudio,
+torch_device,
)
from transformers.utils import is_torchaudio_available

@@ -195,7 +196,7 @@ def test_audio_token_filling_varying_len_feature_list(self):
assert num_calculated_features == [90, 171]
assert sum(num_expected_features) == num_audio_tokens

-@require_torch_gpu
+@require_torch_accelerator
def test_device_override(self):
"""Ensure that we regardless of the processing device, the tensors
produced are on the CPU.
@@ -214,7 +215,7 @@ def test_device_override(self):
text=f"{processor.audio_token} Can you transcribe this audio?",
audio=wav,
return_tensors="pt",
device="cuda",
device=torch_device,
)

assert inputs["input_features"].device.type == "cpu"
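
The same substitution (require_torch_gpu to require_torch_accelerator, and a hard-coded "cuda" to torch_device) is the general recipe for making a test backend-agnostic. A minimal sketch (test name hypothetical):

```python
import torch
from transformers.testing_utils import require_torch_accelerator, torch_device

@require_torch_accelerator  # skipped unless some accelerator (CUDA, XPU, ...) is present
def test_tensor_lands_on_accelerator():
    # torch_device resolves to "cuda", "xpu", ... depending on the environment,
    # so the same test runs unchanged on NVIDIA GPUs and Intel XPUs.
    x = torch.ones(2, 2, device=torch_device)
    assert x.device.type == torch_device
```
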
34 changes: 19 additions & 15 deletions tests/quantization/finegrained_fp8/test_fp8.py
@@ -18,11 +18,13 @@

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config, OPTForCausalLM
from transformers.testing_utils import (
+backend_empty_cache,
require_accelerate,
require_read_token,
-require_torch_gpu,
-require_torch_multi_gpu,
+require_torch_accelerator,
+require_torch_multi_accelerator,
slow,
+torch_device,
)
from transformers.utils import is_accelerate_available, is_torch_available

@@ -34,7 +36,7 @@
from accelerate import init_empty_weights


-@require_torch_gpu
+@require_torch_accelerator
class FineGrainedFP8ConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
@@ -60,13 +62,13 @@ def test_from_dict(self):
@slow
@require_accelerate
@require_read_token
-@require_torch_gpu
+@require_torch_accelerator
class FP8QuantizerTest(unittest.TestCase):
model_name = "meta-llama/Llama-3.2-1B"
input_text = "Once upon a time"
max_new_tokens = 10
EXPECTED_OUTPUT = "Once upon a time, there was a man who was very rich."
device_map = "cuda"
device_map = torch_device
offload_device_map = {
"model.embed_tokens": 0,
"model.layers.0": 0,
@@ -103,7 +105,7 @@ def setUpClass(cls):

def tearDown(self):
gc.collect()
-torch.cuda.empty_cache()
+backend_empty_cache(torch_device)
gc.collect()

def test_quantized_model_conversion(self):
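
`backend_empty_cache` is the `testing_utils` counterpart of `torch.cuda.empty_cache()` that dispatches on the device string. A sketch of the tearDown pattern used here (helper name hypothetical):

```python
import gc

from transformers.testing_utils import backend_empty_cache, torch_device

def release_accelerator_memory():
    # Same shape as the tearDown above: collect Python garbage, then ask the
    # active backend (CUDA, XPU, ...) to release its cached allocations.
    gc.collect()
    backend_empty_cache(torch_device)
    gc.collect()
```
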
@@ -151,7 +153,8 @@ def test_quantized_model(self):
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)

output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
-self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+output_tokens = self.tokenizer.decode(output[0], skip_special_tokens=True)
+self.assertEqual(output_tokens, self.EXPECTED_OUTPUT)

def test_save_pretrained(self):
"""
@@ -188,11 +191,12 @@ def test_block_size(self):
)
self.assertEqual(quantized_model.config.quantization_config.weight_block_size, (32, 32))

-@require_torch_multi_gpu
-def test_quantized_model_multi_gpu(self):
Contributor Author (@yao-matrix):

Regarding the ground truth for the multi_accelerator case: on both XPU and A100 I see every module placed on device 0, while the current ground truth expects {0, 1}.
Looking a bit deeper, the target model meta-llama/Llama-3.2-1B actually has 2 modules_to_treat in accelerate's infer_auto_device_map, as shown below, and lm_head is tied to model, so there is effectively only 1 module to treat, which can naturally only be placed on device 0.

There may be other scenarios I haven't considered, but for this case the correct ground truth seems to be 0. @ydshieh, please let me know your insights, thanks.

[('model', LlamaModel(
(embed_tokens): Embedding(128256, 2048)
(layers): ModuleList(
(0-15): 16 x LlamaDecoderLayer(
(self_attn): LlamaAttention(
(q_proj): FP8Linear(in_features=2048, out_features=2048, bias=False)
(k_proj): FP8Linear(in_features=2048, out_features=512, bias=False)
(v_proj): FP8Linear(in_features=2048, out_features=512, bias=False)
(o_proj): FP8Linear(in_features=2048, out_features=2048, bias=False)
)
(mlp): LlamaMLP(
(gate_proj): FP8Linear(in_features=2048, out_features=8192, bias=False)
(up_proj): FP8Linear(in_features=2048, out_features=8192, bias=False)
(down_proj): FP8Linear(in_features=8192, out_features=2048, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
(post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
)
)
(norm): LlamaRMSNorm((2048,), eps=1e-05)
(rotary_emb): LlamaRotaryEmbedding()
)), ('lm_head', Linear(in_features=2048, out_features=128256, bias=False))]
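
To make the placement behaviour above concrete, here is a toy reproduction (hypothetical model, not Llama; assumes `accelerate` is installed): when the tied graph fits on device 0, `infer_auto_device_map` never assigns anything to device 1.

```python
import torch.nn as nn
from accelerate import infer_auto_device_map

class TinyTiedLM(nn.Module):
    """Toy LM whose head is tied to the embeddings, mimicking the tied lm_head above."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(1000, 32)
        self.layers = nn.ModuleList([nn.Linear(32, 32) for _ in range(4)])
        self.head = nn.Linear(32, 1000, bias=False)
        self.head.weight = self.embed.weight  # weight tying

model = TinyTiedLM()
# Two devices with plenty of room: the whole (tied) model fits on device 0,
# so the resulting map only ever mentions device 0.
print(infer_auto_device_map(model, max_memory={0: "1GB", 1: "1GB"}))
```
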

Member (@IlyasMoutawwakil), May 12, 2025:

I believe this depends on the VRAM of your GPUs/XPUs; it will only use both if one is not enough. Otherwise, maybe it would make sense to use another device map strategy here, like "balanced".

Collaborator (@ydshieh):

I will let @SunMarc or @MekkCyber share their thoughts on this.

On our CI, these tests are not collected; I believe it is due to the require_read_token decorator at the class level.

@yao-matrix You are able to run this test ...? I am surprised. I will take a look at this issue.

Member:

embed_tokens is indeed tied to lm_head, but the layers can be dispatched to other GPUs. Setting "auto" in device_map will default to the "balanced" strategy.
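
For reference, a sketch of explicitly requesting the balanced placement being discussed (requires access to the gated checkpoint and at least two visible devices; `hf_device_map` shows where modules ended up). Note that, as observed in this thread, placement granularity is the top-level module, so a small model may still land on a single device:

```python
from transformers import AutoModelForCausalLM, FineGrainedFP8Config

# Sketch only: "balanced" asks accelerate to spread modules across all visible
# devices, whereas "auto" may keep everything on device 0 if it fits there.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    quantization_config=FineGrainedFP8Config(),
    device_map="balanced",
)
print(set(model.hf_device_map.values()))  # e.g. {0} or {0, 1} depending on placement
```
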

Contributor Author (@yao-matrix):

> I will let @SunMarc or @MekkCyber share their thoughts on this.
>
> On our CI, these tests are not collected; I believe it is due to the require_read_token decorator at the class level.
>
> @yao-matrix You are able to run this test ...? I am surprised. I will take a look at this issue.

I removed require_read_token in my local env and ran this case.

Contributor Author (@yao-matrix), May 12, 2025:

@IlyasMoutawwakil @SunMarc Yes, I tried "balanced" in my local env too, with the same consideration as yours (my XPU has 64 GB of VRAM), but the result is still a single device. It seems the split granularity in infer_auto_device_map is the top-level module when the available memory is enough to fit it.

Member (@SunMarc), May 13, 2025:

Will investigate a bit more. @MekkCyber tested locally and it works, but it fails when running with pytest.

Collaborator (@ydshieh):

@SunMarc @MekkCyber You will need to remove the require_read_token decorator in order to run these tests (if not done so already).

Related issue: #38093

Collaborator (@ydshieh):

If you think no further change is required for test_quantized_model_multi_gpu, feel free to give a ✅ 🙏.

From my side, I am just waiting on a nit change regarding a variable name.

Contributor:

I think it's good to go from my side. We need to figure out why it fails with pytest, but there's no need to include that in this PR. Thanks @ydshieh for the advice about require_read_token.

+@require_torch_multi_accelerator
+def test_quantized_model_multi_accelerator(self):
"""
-Simple test that checks if the quantized model is working properly with multiple GPUs
-set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs
+Simple test that checks if the quantized model is working properly with multiple accelerators
+set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs; or set ZE_AFFINITY_MASK=0,1 if you
+have more than 2 XPUs.
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
quantization_config = FineGrainedFP8Config()
@@ -204,8 +208,8 @@ def test_quantized_model_multi_gpu(self):
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

-@require_torch_multi_gpu
-def test_save_pretrained_multi_gpu(self):
+@require_torch_multi_accelerator
+def test_save_pretrained_multi_accelerators(self):
"""
Simple test that checks if the quantized model is working properly after being saved and loaded
"""
@@ -245,9 +249,9 @@ def test_save_pretrained_offload(self):
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)


-@require_torch_gpu
+@require_torch_accelerator
class FP8LinearTest(unittest.TestCase):
device = "cuda"
device = torch_device

@unittest.skipIf(
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9,