diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py
index 22d1e9e30ade..c259a7dbfaf9 100644
--- a/src/transformers/integrations/finegrained_fp8.py
+++ b/src/transformers/integrations/finegrained_fp8.py
@@ -15,7 +15,7 @@
 
 from typing import List, Optional, Tuple
 
-from ..utils import is_accelerate_available, is_torch_available, logging
+from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
 
 
 if is_torch_available():
@@ -332,8 +332,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.weight.element_size() > 1:
             return F.linear(input, self.weight, self.bias)
         else:
-            # Context manager used to switch among the available cuda devices
-            with torch.cuda.device(input.device):
+            # Context manager used to switch among the available accelerators
+            device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
+            torch_accelerator_module = getattr(torch, device_type, torch.cuda)
+            with torch_accelerator_module.device(input.device):
                 qinput, scale = act_quant(input, self.block_size[1])
                 output = w8a8_block_fp8_matmul_triton(
                     qinput,
@@ -343,9 +345,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
                     self.block_size,
                     output_dtype=input.dtype,
                 )
-            # Blocks the CPU until all CUDA operations on the specified device are complete. It is used to ensure that the results of the
+            # Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
             # preceding operations are ready before proceeding
-            torch.cuda.synchronize()
+            torch_accelerator_module.synchronize()
             if self.bias is not None:
                 output = output + self.bias
             return output.to(dtype=input.dtype)
diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py
index 880d573a3192..ed87b70bbc9b 100644
--- a/src/transformers/quantizers/quantizer_finegrained_fp8.py
+++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py
@@ -1,6 +1,6 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
-from ..utils import is_accelerate_available, is_torch_available, logging
+from ..utils import is_accelerate_available, is_torch_available, is_torch_xpu_available, logging
 from .base import HfQuantizer
 from .quantizers_utils import get_module_from_name
 
@@ -44,16 +44,17 @@ def validate_environment(self, *args, **kwargs):
                 "please make sure the weights are in PyTorch format."
             )
 
-        if not torch.cuda.is_available():
-            raise RuntimeError("No GPU found. A GPU is needed for FP8 quantization.")
+        if not (torch.cuda.is_available() or is_torch_xpu_available()):
+            raise RuntimeError("No GPU or XPU found. A GPU or XPU is needed for FP8 quantization.")
 
-        compute_capability = torch.cuda.get_device_capability()
-        major, minor = compute_capability
-        if (major < 8) or (major == 8 and minor < 9):
-            raise ValueError(
-                "FP8 quantized models is only supported on GPUs with compute capability >= 8.9 (e.g 4090/H100)"
-                f", actual = `{major}.{minor}`"
-            )
+        if torch.cuda.is_available():
+            compute_capability = torch.cuda.get_device_capability()
+            major, minor = compute_capability
+            if (major < 8) or (major == 8 and minor < 9):
+                raise ValueError(
+                    "FP8 quantized models is only supported on GPUs with compute capability >= 8.9 (e.g 4090/H100)"
+                    f", actual = `{major}.{minor}`"
+                )
 
         device_map = kwargs.get("device_map", None)
         if device_map is None:
@@ -217,7 +218,7 @@ def update_tp_plan(self, config):
             config.base_model_tp_plan = text_plan
 
-        return config
+        return config
 
     def is_serializable(self, safe_serialization=None):
         return True
 
diff --git a/tests/models/granite_speech/test_processor_granite_speech.py b/tests/models/granite_speech/test_processor_granite_speech.py
index 2f8825ce6e46..569ac9cfbc19 100644
--- a/tests/models/granite_speech/test_processor_granite_speech.py
+++ b/tests/models/granite_speech/test_processor_granite_speech.py
@@ -23,8 +23,9 @@
 
 from transformers import AutoTokenizer, GPT2TokenizerFast
 from transformers.testing_utils import (
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_torchaudio,
+    torch_device,
 )
 from transformers.utils import is_torchaudio_available
@@ -195,7 +196,7 @@ def test_audio_token_filling_varying_len_feature_list(self):
         assert num_calculated_features == [90, 171]
         assert sum(num_expected_features) == num_audio_tokens
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_device_override(self):
         """Ensure that we regardless of the processing device, the tensors
         produced are on the CPU.
@@ -214,7 +215,7 @@ def test_device_override(self):
             text=f"{processor.audio_token} Can you transcribe this audio?",
             audio=wav,
             return_tensors="pt",
-            device="cuda",
+            device=torch_device,
         )
 
         assert inputs["input_features"].device.type == "cpu"
diff --git a/tests/quantization/finegrained_fp8/test_fp8.py b/tests/quantization/finegrained_fp8/test_fp8.py
index 83bdea1e6635..b5a586b0302f 100644
--- a/tests/quantization/finegrained_fp8/test_fp8.py
+++ b/tests/quantization/finegrained_fp8/test_fp8.py
@@ -18,11 +18,13 @@
 
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config, OPTForCausalLM
 from transformers.testing_utils import (
+    backend_empty_cache,
     require_accelerate,
     require_read_token,
-    require_torch_gpu,
-    require_torch_multi_gpu,
+    require_torch_accelerator,
+    require_torch_multi_accelerator,
     slow,
+    torch_device,
 )
 from transformers.utils import is_accelerate_available, is_torch_available
 
@@ -34,7 +36,7 @@
     from accelerate import init_empty_weights
 
 
-@require_torch_gpu
+@require_torch_accelerator
 class FineGrainedFP8ConfigTest(unittest.TestCase):
     def test_to_dict(self):
         """
@@ -60,13 +62,13 @@ def test_from_dict(self):
 @slow
 @require_accelerate
 @require_read_token
-@require_torch_gpu
+@require_torch_accelerator
 class FP8QuantizerTest(unittest.TestCase):
     model_name = "meta-llama/Llama-3.2-1B"
     input_text = "Once upon a time"
     max_new_tokens = 10
     EXPECTED_OUTPUT = "Once upon a time, there was a man who was very rich."
-    device_map = "cuda"
+    device_map = torch_device
     offload_device_map = {
         "model.embed_tokens": 0,
         "model.layers.0": 0,
@@ -103,7 +105,7 @@ def setUpClass(cls):
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()
 
     def test_quantized_model_conversion(self):
@@ -151,7 +153,8 @@ def test_quantized_model(self):
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
 
         output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
-        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+        output_tokens = self.tokenizer.decode(output[0], skip_special_tokens=True)
+        self.assertEqual(output_tokens, self.EXPECTED_OUTPUT)
 
     def test_save_pretrained(self):
         """
@@ -188,11 +191,12 @@ def test_block_size(self):
         )
         self.assertEqual(quantized_model.config.quantization_config.weight_block_size, (32, 32))
 
-    @require_torch_multi_gpu
-    def test_quantized_model_multi_gpu(self):
+    @require_torch_multi_accelerator
+    def test_quantized_model_multi_accelerator(self):
         """
-        Simple test that checks if the quantized model is working properly with multiple GPUs
-        set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs
+        Simple test that checks if the quantized model is working properly with multiple accelerators
+        set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs; or set ZE_AFFINITY_MASK=0,1 if you
+        have more than 2 XPUs.
         """
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
         quantization_config = FineGrainedFP8Config()
@@ -204,8 +208,8 @@ def test_quantized_model_multi_gpu(self):
 
         output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
-    @require_torch_multi_gpu
-    def test_save_pretrained_multi_gpu(self):
+    @require_torch_multi_accelerator
+    def test_save_pretrained_multi_accelerators(self):
         """
         Simple test that checks if the quantized model is working properly after being saved and loaded
@@ -245,9 +249,9 @@ def test_save_pretrained_offload(self):
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
 
-@require_torch_gpu
+@require_torch_accelerator
class FP8LinearTest(unittest.TestCase):
-    device = "cuda"
+    device = torch_device
 
     @unittest.skipIf(
         torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9,
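For reference, below is a minimal standalone sketch of the device-agnostic pattern the patch switches to. The helper names (get_accelerator_module, run_on_input_device) are hypothetical and not part of transformers; the sketch assumes PyTorch >= 2.6, where torch.accelerator exists, and falls back to torch.cuda otherwise, mirroring what FP8Linear.forward now does via is_torch_accelerator_available().

import torch


def get_accelerator_module():
    # Resolve the torch submodule (torch.cuda, torch.xpu, ...) that matches the
    # current accelerator. Assumes PyTorch >= 2.6 for torch.accelerator; falls
    # back to CUDA on older builds or when no accelerator is present.
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        device_type = torch.accelerator.current_accelerator().type  # e.g. "cuda" or "xpu"
    else:
        device_type = "cuda"
    return getattr(torch, device_type, torch.cuda)


def run_on_input_device(fn, x):
    # Mirrors the shape of FP8Linear.forward: switch to the (accelerator) device
    # of the input, run the kernel, then block until the accelerator has finished
    # before reading the result back.
    accel = get_accelerator_module()
    with accel.device(x.device):
        out = fn(x)
    accel.synchronize()
    return out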