From 967ea76517fb540a617e3b8c555f962c7aa2c6b0 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Tue, 1 Apr 2025 13:29:09 -0700 Subject: [PATCH 01/30] Add profiler and Perfetto UI link with comprehensive tests (#1984, #1992) ghstack-source-id: a5f8301acb77a180a395aa8dd4c1aa9c2ccd7522 ghstack-comment-id: 2770609971 Pull Request resolved: https://github.com/pytorch/ao/pull/1997 --- .../microbenchmarks/benchmark_inference.py | 154 ++++---- .../microbenchmarks/benchmark_runner.py | 22 +- .../microbenchmarks/test/benchmark_config.yml | 61 ++-- .../test/test_benchmark_profiler.py | 224 ++++++++++++ benchmarks/microbenchmarks/utils.py | 335 +++++++++++++++--- 5 files changed, 638 insertions(+), 158 deletions(-) create mode 100644 benchmarks/microbenchmarks/test/test_benchmark_profiler.py diff --git a/benchmarks/microbenchmarks/benchmark_inference.py b/benchmarks/microbenchmarks/benchmark_inference.py index c084d18d3a..15d62d1386 100644 --- a/benchmarks/microbenchmarks/benchmark_inference.py +++ b/benchmarks/microbenchmarks/benchmark_inference.py @@ -20,6 +20,8 @@ BenchmarkResult, clean_caches, create_model_and_input, + generate_memory_profile, + generate_model_profile, model_inference_time_in_ms, string_to_config, ) @@ -29,70 +31,92 @@ def run(config: BenchmarkConfig) -> BenchmarkResult: """Run inference benchmarks""" - clean_caches() # Clean caches - - # Create output directory if it doesn't exist - Path(config.output_dir).mkdir(parents=True, exist_ok=True) - - base_model, input_data = create_model_and_input( - config.model_type, - config.m, - config.k, - config.n, - high_precision_dtype=config.high_precision_dtype, - device=config.device, - ) - - # Use quantize_ to apply each quantization function to the model - m_copy = deepcopy(base_model).eval().to(config.device) - ao_base_config = string_to_config( - config.quantization, - config.sparsity, - high_precision_dtype=config.high_precision_dtype, - ) - - # Check if sparsity is requested and if the device is CUDA (sparsity operations require CUDA) - is_cuda = config.device == "cuda" and torch.cuda.is_available() - - if config.sparsity is not None and ( - config.quantization is None or "baseline" in config.quantization - ): - if is_cuda: - print(f"Applying {config.sparsity} sparsity to model") - sparsify_(m_copy, ao_base_config) + try: + clean_caches() # Clean caches + + # Create output directory if it doesn't exist + Path(config.output_dir).mkdir(parents=True, exist_ok=True) + + base_model, input_data = create_model_and_input( + config.model_type, + config.m, + config.k, + config.n, + high_precision_dtype=config.high_precision_dtype, + device=config.device, + ) + + # Use quantize_ to apply each quantization function to the model + m_copy = deepcopy(base_model).eval().to(config.device) + ao_base_config = string_to_config( + config.quantization, + config.sparsity, + high_precision_dtype=config.high_precision_dtype, + ) + + # Check if sparsity is requested and if the device is CUDA (sparsity operations require CUDA) + is_cuda = config.device == "cuda" and torch.cuda.is_available() + + if config.sparsity is not None and ( + config.quantization is None or "baseline" in config.quantization + ): + if is_cuda: + print(f"Applying {config.sparsity} sparsity to model") + sparsify_(m_copy, ao_base_config) + else: + print( + f"Warning: Skipping {config.sparsity} sparsity as it requires CUDA, but device is {config.device}" + ) + elif config.sparsity is None and ( + config.quantization is None or "baseline" in config.quantization + ): + pass # No quantization or sparsity specified, do nothing else: - print( - f"Warning: Skipping {config.sparsity} sparsity as it requires CUDA, but device is {config.device}" + print("Quantizing model....") + quantize_(m_copy, ao_base_config) + + if config.use_torch_compile: + print("Compiling model....") + m_copy = torch.compile( + m_copy, mode=config.torch_compile_mode, fullgraph=True ) - elif config.sparsity is None and ( - config.quantization is None or "baseline" in config.quantization - ): - pass # No quantization or sparsity specified, do nothing - else: - print("Quantizing model....") - quantize_(m_copy, ao_base_config) - - if config.use_torch_compile: - print("Compiling model....") - m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True) - - # Run benchmarks - result = BenchmarkResult(config=config) - - # Benchmark time to run an inference call for quantized model - result.model_inference_time_in_ms = model_inference_time_in_ms( - model=m_copy, input_data=input_data - ) - - # TODO: Benchmark time using profiler - # Profile dtype model evaluation - # prof_dtype = benchmark_model_op_with_profiler_in_microseconds(m_copy, input_data, quantized_dtype) - # prof_dtype.export_chrome_trace(f"{quantization}_model_{input_data[0].size()[0]}.json") # Save profiling details - - # TODO: Benchmark gemm time using cuda graph - # gemm_time = benchmark_torch_function_in_microseconds(gemm_op, *args, **kwargs) - - # TODO: Benchmark op with cuda graph - # time = benchmark_op_with_cuda_graph(op, args) - - return result + + # Run benchmarks + result = BenchmarkResult(config=config) + # Store result in model for memory profiling + m_copy._benchmark_result = result + + # Benchmark time to run an inference call for quantized model + result.model_inference_time_in_ms = model_inference_time_in_ms( + model=m_copy, input_data=input_data + ) + + # Run profiler if enabled + if config.enable_profiler: + print("Running profiler...") + try: + result.profiler_json_path, result.perfetto_url = generate_model_profile( + m_copy, input_data, config.profiler_file_name + ) + except Exception as e: + print(f"Error running profiler: {e}") + + # Run memory profiler if enabled + if config.enable_memory_profile: + print("Running memory profiler...") + try: + result.memory_profile_path, result.memory_stats = ( + generate_memory_profile( + m_copy, input_data, config.memory_profile_file_name + ) + ) + except Exception as e: + print(f"Error running memory profiler: {e}") + + return result + except Exception as e: + print(f"Error in benchmark run: {e}") + import traceback + + print(traceback.format_exc()) + return None diff --git a/benchmarks/microbenchmarks/benchmark_runner.py b/benchmarks/microbenchmarks/benchmark_runner.py index 7152542eec..1a60ca6b16 100644 --- a/benchmarks/microbenchmarks/benchmark_runner.py +++ b/benchmarks/microbenchmarks/benchmark_runner.py @@ -164,16 +164,22 @@ def run_inference_benchmarks_from_config(configs: List[BenchmarkConfig]) -> None f"Running: {config.name} for Quantization: {config.quantization} and Sparsity: {config.sparsity}" ) result = run_inference(config) # Pass the config object directly - results.append(result) - except Exception: - print(f"Error running benchmark {config.name}") - continue + if result is not None: # Only add successful results + results.append(result) + except Exception as e: + import traceback - # Add results to csv - generate_results_csv(results, configs[0].output_dir) + print(f"Error running benchmark {config.name} with error: {e}") + print(traceback.format_exc()) + continue - # Print results - print_results(results) + # Add results to csv if there are any + if results: + generate_results_csv(results, configs[0].output_dir) + # Print results + print_results(results) + else: + print("No benchmark results were collected. All benchmarks failed.") # TODO: Process results: Speedups: # 1. For different shapes for same model and quantization diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml index 97a38469de..227cb90948 100644 --- a/benchmarks/microbenchmarks/test/benchmark_config.yml +++ b/benchmarks/microbenchmarks/test/benchmark_config.yml @@ -2,46 +2,51 @@ benchmark_mode: "inference" quantization_config_recipe_names: # Will run a baseline inference for model by default, without quantization for comparison - - "int4wo-32" - - "marlin" -sparsity_config_recipe_names: + # - "int4wo-32" + # - "marlin" + - "int8wo" +# sparsity_config_recipe_names: # Will run a baseline inference for model by default, without sparsity for comparison - - "semi-sparse" - - "block" + # - "semi-sparse" + # - "block" output_dir: "benchmarks/microbenchmarks/results" model_params: - - name: "small_bf16_linear" - matrix_shapes: - - name: "custom" - shapes: [ - [1024, 1024, 1024], # [m, k, n] - ] - high_precision_dtype: "torch.bfloat16" - use_torch_compile: true - torch_compile_mode: "max-autotune" - device: "cuda" - model_type: "linear" + # - name: "small_bf16_linear" + # matrix_shapes: + # - name: "custom" + # shapes: [ + # [1024, 1024, 1024], # [m, k, n] + # ] + # high_precision_dtype: "torch.bfloat16" + # use_torch_compile: true + # torch_compile_mode: "max-autotune" + # device: "cuda" + # model_type: "linear" + # enable_profiler: true # Enable profiling for this model - name: "large_bf16_ln_linear" matrix_shapes: - name: "custom" shapes: [ [2048, 4096, 1024], - [4096, 4096, 1024] + # [4096, 4096, 1024] ] high_precision_dtype: "torch.bfloat16" use_torch_compile: true torch_compile_mode: "max-autotune" device: "cuda" - model_type: "ln_linear_sigmoid" - - - name: "cpu_fp32_linear" - matrix_shapes: - - name: "custom" - shapes: [ - [4096, 4096, 1024] - ] - high_precision_dtype: "torch.float32" - use_torch_compile: false - device: "cpu" model_type: "linear" + enable_profiler: true # Enable profiling for this model + enable_memory_profile: true # Enable memory profiling for this model + + # - name: "cpu_fp32_linear" + # matrix_shapes: + # - name: "custom" + # shapes: [ + # [4096, 4096, 1024] + # ] + # high_precision_dtype: "torch.float32" + # use_torch_compile: false + # device: "cpu" + # model_type: "linear" + # enable_profiler: true # Enable profiling for this model diff --git a/benchmarks/microbenchmarks/test/test_benchmark_profiler.py b/benchmarks/microbenchmarks/test/test_benchmark_profiler.py new file mode 100644 index 0000000000..42ad8af895 --- /dev/null +++ b/benchmarks/microbenchmarks/test/test_benchmark_profiler.py @@ -0,0 +1,224 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +import unittest + +import torch + +from benchmarks.microbenchmarks.utils import ( + BenchmarkConfig, + ToyLinearModel, + generate_memory_profile, + generate_model_profile, +) + + +class TestBenchmarkProfiler(unittest.TestCase): + def setUp(self): + self.test_dir = os.path.dirname(os.path.abspath(__file__)) + self.results_dir = os.path.join(self.test_dir, "results") + os.makedirs(self.results_dir, exist_ok=True) + + # Set up a simple model and input for testing + self.m, self.k, self.n = 1024, 1024, 1024 + self.dtype = torch.bfloat16 + self.model = ToyLinearModel(k=self.k, n=self.n, dtype=self.dtype) + self.input_data = torch.randn(1, self.k, dtype=self.dtype) + + # Move to appropriate device + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.model = self.model.to(self.device) + self.input_data = self.input_data.to(self.device) + + def tearDown(self): + # Clean up any generated files + import shutil + + if os.path.exists(self.results_dir): + shutil.rmtree(self.results_dir) + + def test_profiler_enabled(self): + """Test that profiler works when enabled""" + config = BenchmarkConfig( + quantization=None, + sparsity=None, + params={ + "enable_profiler": True, + "device": self.device, + }, + shape_name="test", + shape=[self.m, self.k, self.n], + output_dir=self.results_dir, + benchmark_mode="inference", + ) + + profile_path = os.path.join( + self.results_dir, + "profiler", + f"{config.name}_{self.m}_{self.k}_{self.n}_profile.json", + ) + + # Generate profile + result_path, _ = generate_model_profile( + self.model, self.input_data, profile_path + ) + + # Check that profile file exists and is not empty + self.assertTrue(os.path.exists(result_path)) + self.assertGreater(os.path.getsize(result_path), 0) + + # Verify it's valid JSON + with open(result_path) as f: + profile_data = json.load(f) + self.assertIsInstance(profile_data, dict) + + def test_profiler_basic_output(self): + """Test that profiler output contains expected basic fields""" + config = BenchmarkConfig( + quantization=None, + sparsity=None, + params={ + "enable_profiler": True, + "device": self.device, + }, + shape_name="test", + shape=[self.m, self.k, self.n], + output_dir=self.results_dir, + benchmark_mode="inference", + ) + + profile_path = os.path.join( + self.results_dir, + "profiler", + f"{config.name}_{self.m}_{self.k}_{self.n}_profile.json", + ) + + result_path, _ = generate_model_profile( + self.model, self.input_data, profile_path + ) + + with open(result_path) as f: + data = json.load(f) + + # Check for required Chrome Trace Event format fields + self.assertIn("traceEvents", data) + self.assertTrue(isinstance(data["traceEvents"], list)) + + # Check that we have some events + self.assertGreater(len(data["traceEvents"]), 0) + + # Check event format + event = data["traceEvents"][0] + self.assertIn("name", event) + self.assertIn("ph", event) # Phase + self.assertIn("ts", event) # Timestamp + self.assertIn("pid", event) # Process ID + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_cuda_profiling(self): + """Test CUDA profiling when available""" + config = BenchmarkConfig( + quantization=None, + sparsity=None, + params={ + "enable_profiler": True, + "device": "cuda", + }, + shape_name="test", + shape=[self.m, self.k, self.n], + output_dir=self.results_dir, + benchmark_mode="inference", + ) + + profile_path = os.path.join( + self.results_dir, + "profiler", + f"{config.name}_{self.m}_{self.k}_{self.n}_profile.json", + ) + + result_path, _ = generate_model_profile( + self.model.cuda(), self.input_data.cuda(), profile_path + ) + + with open(result_path) as f: + data = json.load(f) + + # Check for CUDA events + cuda_events = [ + event for event in data["traceEvents"] if "CUDA" in event.get("name", "") + ] + self.assertGreater(len(cuda_events), 0) + + def test_memory_profiling(self): + """Test memory profiling functionality for CUDA""" + if not torch.cuda.is_available(): + self.skipTest("CUDA not available") + + config = BenchmarkConfig( + quantization=None, + sparsity=None, + params={ + "enable_memory_profile": True, + "device": "cuda", + }, + shape_name="test", + shape=[self.m, self.k, self.n], + output_dir=self.results_dir, + benchmark_mode="inference", + ) + + profile_path = os.path.join( + self.results_dir, + "profiler", + f"{config.name}_{self.m}_{self.k}_{self.n}_memory_profile.json", + ) + + result_path, memory_stats = generate_memory_profile( + self.model.cuda(), self.input_data.cuda(), profile_path + ) + + # Check that JSON profile file exists and is not empty + self.assertTrue(os.path.exists(result_path)) + self.assertGreater(os.path.getsize(result_path), 0) + + # Check that pickle profile file exists and is not empty + pickle_path = result_path.replace(".json", ".pickle") + self.assertTrue(os.path.exists(pickle_path)) + self.assertGreater(os.path.getsize(pickle_path), 0) + + # Verify memory stats structure + self.assertIn("peak_memory_allocated", memory_stats) + self.assertIn("peak_memory_reserved", memory_stats) + self.assertIn("total_memory_allocated", memory_stats) + self.assertIn("total_memory_reserved", memory_stats) + + # Verify memory values are reasonable + self.assertGreaterEqual(memory_stats["peak_memory_allocated"], 0) + self.assertGreaterEqual(memory_stats["peak_memory_reserved"], 0) + self.assertGreaterEqual(memory_stats["total_memory_allocated"], 0) + self.assertGreaterEqual(memory_stats["total_memory_reserved"], 0) + + # Verify pickle file can be loaded + with open(pickle_path, "rb") as f: + from pickle import load + + snapshot = load(f) + self.assertIsNotNone(snapshot) + + # Check that HTML visualization was generated + html_path = pickle_path.replace(".pickle", ".html") + if os.path.exists( + os.path.dirname(os.path.dirname(torch.__file__)) + + "/torch/cuda/_memory_viz.py" + ): + self.assertTrue(os.path.exists(html_path)) + self.assertGreater(os.path.getsize(html_path), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index fd3db11591..677f66ac75 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -4,11 +4,15 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import csv +import json import os +import subprocess +import uuid from typing import Any, Dict, List, Optional import torch from tabulate import tabulate +from torch.profiler import ProfilerActivity from torch.utils.benchmark import Timer from torchao.core.config import AOBaseConfig @@ -50,6 +54,211 @@ def get_default_device(device: str = "cuda") -> str: return "cpu" +def upload_trace_file(local_path: str, overwrite: bool = False) -> Optional[str]: + MANIFOLD_FOLDER = "perfetto_internal_traces/tree/shared_trace" + DEFAULT_TTL_SEC = 28 * 24 * 60 * 60 + file_name = os.path.basename(local_path) + manifold_path = os.path.join( + MANIFOLD_FOLDER, f"{os.getlogin()}_{str(uuid.uuid4())}_{file_name}" + ) + cmd = [ + "manifold", + "put", + local_path, + manifold_path, + "--ttl", + str(DEFAULT_TTL_SEC), + "--userData", + "false", + ] + ret = subprocess.run( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True + ) + if ret.returncode == 0: + print("Upload trace successfully.") + return manifold_path + else: + print("[ERROR] Upload failed, maybe the trace file exists.") + return None + + +def print_perfetto_ui_url(manifold_path: str) -> Optional[str]: + """Generate and print the Perfetto UI URL for a Manifold trace file. + + Args: + manifold_path: Path to the trace file in Manifold + + Returns: + The URL to the Perfetto UI or None if there was an error + """ + try: + url = ( + "https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/" + + manifold_path + ) + print(f"The trace is accessible at:\n{url}") + return url + except Exception as e: + print(f"Error generating Perfetto UI URL: {e}") + return None + + +def generate_model_profile(model, input_data, profile_file_path): + """Function to benchmark model evaluation with profiling. + + Args: + model: The model to profile + input_data: Input data for the model + profile_file_path: Path to save the profiler output + + Returns: + Tuple of (profile_file_path, perfetto_url) + """ + # Create parent directory if it doesn't exist + os.makedirs(os.path.dirname(profile_file_path), exist_ok=True) + + # Set up profiler activities based on device + activities = [ProfilerActivity.CPU] + device = next(model.parameters()).device + if device.type == "cuda" and torch.cuda.is_available(): + activities.append(ProfilerActivity.CUDA) + + # Run profiler with minimal settings to ensure compatibility + prof = torch.profiler.profile( + activities=activities, + record_shapes=True, + with_stack=False, # Disable stack traces to reduce overhead + profile_memory=False, # Disable memory profiling as it's not reliable across all devices + ) + + # Warm up + with torch.no_grad(): + for _ in range(3): + _ = model(input_data) + if device.type == "cuda": + torch.cuda.synchronize() + + # Profile + with prof: + with torch.no_grad(): + for _ in range(3): + _ = model(input_data) + if device.type == "cuda": + torch.cuda.synchronize() + + # Save profiling details + prof.export_chrome_trace(profile_file_path) + print(f"Profile saved to: {profile_file_path}") + + # Try to upload to Perfetto UI + perfetto_url = None + try: + manifold_path = upload_trace_file(profile_file_path) + if manifold_path: + perfetto_url = print_perfetto_ui_url(manifold_path) + except Exception as e: + print(f"Warning: Failed to upload profile to Perfetto UI: {e}") + + return profile_file_path, perfetto_url + + +# def visualize_memory_profile(snapshot, output_html_path) -> Optional[str]: +# from torch.cuda._memory_viz import trace_plot + +# # Convert to HTML +# html = trace_plot(snapshot) + +# # Save to file +# with open(output_html_path, "w") as f: +# f.write(html) + + +def generate_memory_profile(model, input_data, profile_file_path): + """Function to generate memory profile for model evaluation. + + Args: + model: The model to profile + input_data: Input data for the model + profile_file_path: Path to save the memory profile output + + Returns: + Tuple of (profile_file_path, memory_stats) + """ + # Create parent directory if it doesn't exist + os.makedirs(os.path.dirname(profile_file_path), exist_ok=True) + + device = next(model.parameters()).device + memory_stats = { + "peak_memory_allocated": 0, + "peak_memory_reserved": 0, + "total_memory_allocated": 0, + "total_memory_reserved": 0, + "memory_events": [], + } + + if device.type == "cuda": + # Enable memory history recording for CUDA + torch.cuda.memory._record_memory_history( + True, trace_alloc_max_entries=250000, trace_alloc_record_context=True + ) + + # Reset CUDA memory stats + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + + # Warm up + with torch.no_grad(): + for _ in range(3): + _ = model(input_data) + torch.cuda.synchronize() + + # Profile memory + with torch.no_grad(): + _ = model(input_data) + torch.cuda.synchronize() + + # Collect memory stats + memory_stats.update( + { + "peak_memory_allocated": torch.cuda.max_memory_allocated() + / 1024**2, # Convert to MB + "peak_memory_reserved": torch.cuda.max_memory_reserved() / 1024**2, + "total_memory_allocated": torch.cuda.memory_allocated() / 1024**2, + "total_memory_reserved": torch.cuda.memory_reserved() / 1024**2, + } + ) + + # Get detailed memory snapshot + snapshot = torch.cuda.memory._snapshot() + + # Save memory profile as pickle file + pickle_path = profile_file_path.replace(".json", ".pickle") + with open(pickle_path, "wb") as f: + from pickle import dump + + dump(snapshot, f) + + print(f"Memory profile saved to: {pickle_path}") + + # TODO: Add memory visualization + # visualize_memory_profile(snapshot, pickle_path.replace(".pickle", ".html")) + # print(f"Memory visualization saved to: {pickle_path.replace('.pickle', '.html')}") + + # Disable memory history recording + torch.cuda.memory._record_memory_history(False) + + else: + print("Memory profiling only works on CUDA devices") + # TODO: Add XPU support when available + return profile_file_path, memory_stats + + # Save basic stats as JSON for easy access + with open(profile_file_path, "w") as f: + json.dump(memory_stats, f, indent=2) + + return profile_file_path, memory_stats + + class BenchmarkConfig: def __init__( self, @@ -84,6 +293,18 @@ def __init__( "name", f"benchmark_{self.quantization}_{self.model_type}_m{self.m}_k{self.k}_n{self.n}{'_compile' if self.use_torch_compile else ''}", ) + self.enable_profiler = bool(params.get("enable_profiler", False)) + self.enable_memory_profile = bool(params.get("enable_memory_profile", False)) + # Create profiler directory path without leading slash + profiler_dir = os.path.join(self.output_dir, "profiler") + os.makedirs(profiler_dir, exist_ok=True) + file_name = f"{self.name}_{self.m}_{self.k}_{self.n}_quant_{self.quantization}_sparsity_{self.sparsity}" + self.profiler_file_name = os.path.join( + profiler_dir, f"{file_name}_profile.json" + ) + self.memory_profile_file_name = os.path.join( + profiler_dir, f"{file_name}_memory_profile.json" + ) @staticmethod def _parse_precision(precision_str: str) -> torch.dtype: @@ -105,6 +326,8 @@ def to_dict(self) -> Dict[str, Any]: "device": self.device, "model_type": self.model_type, "output_dir": self.output_dir, + "enable_profiler": self.enable_profiler, + "enable_memory_profile": self.enable_memory_profile, } @@ -116,13 +339,24 @@ def __init__( self.config = config self.output_dir = config.output_dir self.model_inference_time_in_ms = 0.0 + self.profiler_json_path: Optional[str] = None + self.perfetto_url: Optional[str] = None + self.memory_profile_path: Optional[str] = None + self.memory_stats: Optional[Dict[str, Any]] = None + # self.memory_visualization_path: Optional[str] = None def to_dict(self) -> Dict[str, Any]: """Convert result to dictionary for main function""" - return { + result_dict = { **self.config.to_dict(), "model_inference_time_in_ms": self.model_inference_time_in_ms, + "profiler_json_path": self.profiler_json_path, + "perfetto_url": self.perfetto_url, + "memory_profile_path": self.memory_profile_path, + "memory_stats": self.memory_stats, + # "memory_visualization_path": self.memory_visualization_path, } + return result_dict class ToyLinearModel(torch.nn.Module): @@ -373,6 +607,11 @@ def generate_results_csv( output_dir (str): Directory to save the CSV file. file_name (str, optional): Name of the CSV file. Defaults to "results.csv". """ + # Check if results list is empty + if not results: + print("No results to save to CSV.") + return + # Create the output directory if it doesn't exist os.makedirs(output_dir, exist_ok=True) file_path = os.path.join(output_dir, file_name) @@ -390,68 +629,50 @@ def generate_results_csv( def print_results(results: List[BenchmarkResult]): - """Print benchmark results in a formatted table. - - Args: - results (List[BenchmarkResult]): List of benchmark results - """ + """Print results in a table format""" if not results: print("No results to display") return - # Extract relevant columns for display - display_columns = [ - "quantization", - "sparsity", - "model_type", - "m", - "k", - "n", - "model_inference_time_in_ms", - "use_torch_compile", - ] - - # Format data for tabulate - headers = { - "quantization": "Quantization", - "sparsity": "Sparsity", - "model_type": "Model Type", - "m": "M", - "k": "K", - "n": "N", - "model_inference_time_in_ms": "Time (μs)", - "use_torch_compile": "Compile Mode", - } - - # Extract and format data table_data = [] for result in results: - result_dict = result.to_dict() - row = [] - for col in display_columns: - value = result_dict.get(col, "N/A") - if value is None: - value = "N/A" - if col == "model_inference_time_in_ms": - value = f"{value:.2f}" if isinstance(value, (int, float)) else value - elif col == "use_torch_compile": - # Show compile mode if compile is True, otherwise show False - value = ( - result_dict.get("torch_compile_mode", "default") - if result_dict.get("use_torch_compile") - else "False" + if result is None: + continue + + row = [ + result.config.name, + result.config.quantization or "baseline", + result.config.sparsity or "none", + f"{result.model_inference_time_in_ms:.2f}", + str(result.config.enable_profiler), + str(result.config.enable_memory_profile), + ] + + # Add memory profile data if enabled + if result.config.enable_memory_profile: + if result.memory_stats: + row.append( + f"Peak memory: {result.memory_stats['peak_memory_allocated']:.2f}MB" ) - row.append(value) + else: + row.append("Memory profiling failed") + table_data.append(row) - # Print formatted table - print("\nBenchmark Results:") - print( - tabulate( - table_data, - headers=[headers[col] for col in display_columns], - tablefmt="grid", - floatfmt=".2f", - ) - ) - print() + # Define headers + headers = [ + "Name", + "Quantization", + "Sparsity", + "Inference Time (ms)", + "Profiler Enabled", + "Memory Profiling Enabled", + ] + if any(r.config.enable_memory_profile for r in results if r is not None): + headers.append("Memory Profile Data") + + if table_data: + print("\nBenchmark Results:") + print(tabulate(table_data, headers=headers, tablefmt="grid")) + else: + print("\nNo valid results to display") From dd9f50d1b8b899a204cbe63d622b49ad2a4a954b Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 4 Apr 2025 00:17:35 -0700 Subject: [PATCH 02/30] More models --- .../microbenchmarks/test/benchmark_config.yml | 28 +++++ benchmarks/microbenchmarks/test/test_utils.py | 115 ++++++++++++++++++ benchmarks/microbenchmarks/utils.py | 110 +++++++++++++++++ 3 files changed, 253 insertions(+) diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml index 227cb90948..4394d0208b 100644 --- a/benchmarks/microbenchmarks/test/benchmark_config.yml +++ b/benchmarks/microbenchmarks/test/benchmark_config.yml @@ -50,3 +50,31 @@ model_params: # device: "cpu" # model_type: "linear" # enable_profiler: true # Enable profiling for this model + + - name: "bf16_rms_norm_linear_activation" + matrix_shapes: + - name: "custom" + shapes: [ + [2048, 4096, 1024], + ] + high_precision_dtype: "torch.bfloat16" + use_torch_compile: true + torch_compile_mode: "max-autotune" + device: "cuda" + model_type: "rms_norm_linear_activation" + enable_profiler: true + enable_memory_profile: true + + - name: "bf16_transformer_block" + matrix_shapes: + - name: "custom" + shapes: [ + [2048, 4096, 1024], # For transformer_block, k is the hidden dimension + ] + high_precision_dtype: "torch.bfloat16" + use_torch_compile: true + torch_compile_mode: "max-autotune" + device: "cuda" + model_type: "transformer_block" + enable_profiler: true + enable_memory_profile: true diff --git a/benchmarks/microbenchmarks/test/test_utils.py b/benchmarks/microbenchmarks/test/test_utils.py index 14f226bd7e..46f6a74685 100644 --- a/benchmarks/microbenchmarks/test/test_utils.py +++ b/benchmarks/microbenchmarks/test/test_utils.py @@ -17,8 +17,11 @@ Float8DynamicActivationFloat8SemiSparseWeightConfig, Int4WeightOnlyConfig, LNLinearSigmoid, + RMSNorm, + RMSNormLinearActivation, SemiSparseWeightConfig, ToyLinearModel, + TransformerBlock, clean_caches, create_model_and_input, generate_results_csv, @@ -162,6 +165,61 @@ def test_ln_linear_sigmoid(self): torch.all((out >= 0) & (out <= 1)) ) # Check sigmoid output range + def test_rms_norm(self): + # Test RMSNorm + rms_norm = RMSNorm(dim=64) + x = torch.randn(16, 64) + out = rms_norm(x) + self.assertEqual(out.shape, (16, 64)) + + # Test with different eps + rms_norm = RMSNorm(dim=64, eps=1e-5) + out = rms_norm(x) + self.assertEqual(out.shape, (16, 64)) + + def test_rms_norm_linear_activation(self): + # Test with default GELU activation + model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32) + x = torch.randn(16, 64) + out = model(x) + self.assertEqual(out.shape, (16, 32)) + self.assertEqual(out.dtype, torch.float32) + + # Test with ReLU activation + model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu") + out = model(x) + self.assertEqual(out.shape, (16, 32)) + self.assertTrue(torch.all(out >= 0)) # Check ReLU output range + + # Test with SiLU activation + model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu") + out = model(x) + self.assertEqual(out.shape, (16, 32)) + + # Test with invalid activation + with self.assertRaises(ValueError): + RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid") + + def test_transformer_block(self): + # Test with default parameters + model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32) + x = torch.randn(16, 16, 64) # [batch_size, seq_len, hidden_dim] + out = model(x) + self.assertEqual(out.shape, (16, 16, 64)) + self.assertEqual(out.dtype, torch.float32) + + # Test with different parameters + model = TransformerBlock(hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32) + x = torch.randn(8, 32, 128) + out = model(x) + self.assertEqual(out.shape, (8, 32, 128)) + + # Test with different head dimensions + model = TransformerBlock(hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32) + x = torch.randn(4, 8, 96) + out = model(x) + self.assertEqual(out.shape, (4, 8, 96)) + def test_create_model_and_input(self): m, k, n = 16, 64, 32 model, input_data = create_model_and_input( @@ -186,6 +244,63 @@ def test_create_model_and_input(self): self.assertIsInstance(model, LNLinearSigmoid) self.assertEqual(input_data.shape, (m, k)) + # Test RMSNormLinearActivation + model, input_data = create_model_and_input( + model_type="rms_norm_linear_activation", + m=m, + k=k, + n=n, + high_precision_dtype=torch.float32, + device="cpu", + ) + self.assertIsInstance(model, RMSNormLinearActivation) + self.assertEqual(input_data.shape, (m, k)) + + # Test TransformerBlock + model, input_data = create_model_and_input( + model_type="transformer_block", + m=m, + k=k, + n=n, # n is not used for transformer_block + high_precision_dtype=torch.float32, + device="cpu", + ) + self.assertIsInstance(model, TransformerBlock) + self.assertEqual(input_data.shape, (m, 16, k)) # [batch_size, seq_len, hidden_dim] + + def test_quantization_on_models(self): + # Test quantization on RMSNormLinearActivation + model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32) + x = torch.randn(16, 64) + + # Test with Int8WeightOnlyConfig + config = string_to_config(quantization="int8wo", sparsity=None) + if config is not None: + # Skip quantization test if torchao.quantization.quantize is not available + try: + from torchao.quantization import quantize + quantized_model = quantize(model, config) + out = quantized_model(x) + self.assertEqual(out.shape, (16, 32)) + except ImportError: + print("Skipping quantization test: torchao.quantization.quantize not available") + + # Test quantization on TransformerBlock + model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32) + x = torch.randn(16, 16, 64) + + # Test with Int8WeightOnlyConfig + config = string_to_config(quantization="int8wo", sparsity=None) + if config is not None: + # Skip quantization test if torchao.quantization.quantize is not available + try: + from torchao.quantization import quantize + quantized_model = quantize(model, config) + out = quantized_model(x) + self.assertEqual(out.shape, (16, 16, 64)) + except ImportError: + print("Skipping quantization test: torchao.quantization.quantize not available") + def test_generate_results_csv(self): results = [ BenchmarkResult( diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 677f66ac75..9e978f70fa 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -383,6 +383,108 @@ def forward(self, x): return x +class RMSNorm(torch.nn.Module): + def __init__(self, dim, eps=1e-6, dtype=torch.bfloat16): + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + + def forward(self, x): + norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + return x * norm * self.weight + + +class RMSNormLinearActivation(torch.nn.Module): + def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16, activation="gelu"): + super().__init__() + self.rms_norm = RMSNorm(fc_dim1, dtype=dtype) + self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype) + + if activation == "gelu": + self.activation = torch.nn.GELU() + elif activation == "relu": + self.activation = torch.nn.ReLU() + elif activation == "silu": + self.activation = torch.nn.SiLU() + else: + raise ValueError(f"Unsupported activation: {activation}") + + def forward(self, x): + x = self.rms_norm(x) + x = self.fc(x) + x = self.activation(x) + return x + + +class TransformerBlock(torch.nn.Module): + def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.bfloat16): + super().__init__() + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads + + # Self-attention + self.qkv = torch.nn.Linear(hidden_dim, 3 * hidden_dim, bias=False).to(dtype) + self.proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False).to(dtype) + + # MLP + self.mlp_ratio = mlp_ratio + self.mlp_hidden_dim = int(hidden_dim * mlp_ratio) + self.mlp_fc1 = torch.nn.Linear(hidden_dim, self.mlp_hidden_dim, bias=False).to(dtype) + self.mlp_fc2 = torch.nn.Linear(self.mlp_hidden_dim, hidden_dim, bias=False).to(dtype) + + # Layer norms + self.norm1 = RMSNorm(hidden_dim, dtype=dtype) + self.norm2 = RMSNorm(hidden_dim, dtype=dtype) + + # Activation + self.activation = torch.nn.GELU() + + def forward(self, x): + batch_size, seq_len, _ = x.shape + + # Self-attention + residual = x + x = self.norm1(x) + + # Reshape qkv projection for better memory layout + qkv = self.qkv(x) # [batch_size, seq_len, 3 * hidden_dim] + qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim) + qkv = qkv.permute(2, 0, 3, 1, 4) # [3, batch_size, num_heads, seq_len, head_dim] + q, k, v = qkv # Each has shape [batch_size, num_heads, seq_len, head_dim] + + # Scaled dot-product attention with proper reshaping + # Reshape for better memory layout and avoid broadcasting issues + q = q.reshape(batch_size * self.num_heads, seq_len, self.head_dim) + k = k.reshape(batch_size * self.num_heads, seq_len, self.head_dim) + v = v.reshape(batch_size * self.num_heads, seq_len, self.head_dim) + + # Compute attention scores + attn = (q @ k.transpose(-2, -1)) * (1.0 / (self.head_dim ** 0.5)) + attn = torch.softmax(attn, dim=-1) + + # Apply attention to values + x = attn @ v # [batch_size * num_heads, seq_len, head_dim] + + # Reshape back to original dimensions + x = x.reshape(batch_size, self.num_heads, seq_len, self.head_dim) + x = x.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_dim) + + # Project back to hidden dimension + x = self.proj(x) + x = residual + x + + # MLP + residual = x + x = self.norm2(x) + x = self.mlp_fc1(x) + x = self.activation(x) + x = self.mlp_fc2(x) + x = residual + x + + return x + + def string_to_config( quantization: Optional[str], sparsity: Optional[str], **kwargs ) -> AOBaseConfig: @@ -576,6 +678,14 @@ def create_model_and_input( elif model_type == "ln_linear_sigmoid": model = LNLinearSigmoid(k, n, high_precision_dtype).to(device) input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) + elif model_type == "rms_norm_linear_activation": + model = RMSNormLinearActivation(k, n, high_precision_dtype).to(device) + input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) + elif model_type == "transformer_block": + # For transformer block, k is the hidden dimension + model = TransformerBlock(k, num_heads=8, mlp_ratio=4, dtype=high_precision_dtype).to(device) + # Input shape for transformer is [batch_size, seq_len, hidden_dim] + input_data = torch.randn(m, 16, k, device=device, dtype=high_precision_dtype) else: raise ValueError(f"Unknown model type: {model_type}") return model, input_data From 328b7bf8ee84028f3ce5b4767d9927f9384a6b35 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 4 Apr 2025 10:47:22 -0700 Subject: [PATCH 03/30] Update [ghstack-poisoned] --- .../microbenchmarks/test/benchmark_config.yml | 28 ----- benchmarks/microbenchmarks/test/test_utils.py | 115 ------------------ benchmarks/microbenchmarks/utils.py | 110 ----------------- 3 files changed, 253 deletions(-) diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml index 4394d0208b..227cb90948 100644 --- a/benchmarks/microbenchmarks/test/benchmark_config.yml +++ b/benchmarks/microbenchmarks/test/benchmark_config.yml @@ -50,31 +50,3 @@ model_params: # device: "cpu" # model_type: "linear" # enable_profiler: true # Enable profiling for this model - - - name: "bf16_rms_norm_linear_activation" - matrix_shapes: - - name: "custom" - shapes: [ - [2048, 4096, 1024], - ] - high_precision_dtype: "torch.bfloat16" - use_torch_compile: true - torch_compile_mode: "max-autotune" - device: "cuda" - model_type: "rms_norm_linear_activation" - enable_profiler: true - enable_memory_profile: true - - - name: "bf16_transformer_block" - matrix_shapes: - - name: "custom" - shapes: [ - [2048, 4096, 1024], # For transformer_block, k is the hidden dimension - ] - high_precision_dtype: "torch.bfloat16" - use_torch_compile: true - torch_compile_mode: "max-autotune" - device: "cuda" - model_type: "transformer_block" - enable_profiler: true - enable_memory_profile: true diff --git a/benchmarks/microbenchmarks/test/test_utils.py b/benchmarks/microbenchmarks/test/test_utils.py index 46f6a74685..14f226bd7e 100644 --- a/benchmarks/microbenchmarks/test/test_utils.py +++ b/benchmarks/microbenchmarks/test/test_utils.py @@ -17,11 +17,8 @@ Float8DynamicActivationFloat8SemiSparseWeightConfig, Int4WeightOnlyConfig, LNLinearSigmoid, - RMSNorm, - RMSNormLinearActivation, SemiSparseWeightConfig, ToyLinearModel, - TransformerBlock, clean_caches, create_model_and_input, generate_results_csv, @@ -165,61 +162,6 @@ def test_ln_linear_sigmoid(self): torch.all((out >= 0) & (out <= 1)) ) # Check sigmoid output range - def test_rms_norm(self): - # Test RMSNorm - rms_norm = RMSNorm(dim=64) - x = torch.randn(16, 64) - out = rms_norm(x) - self.assertEqual(out.shape, (16, 64)) - - # Test with different eps - rms_norm = RMSNorm(dim=64, eps=1e-5) - out = rms_norm(x) - self.assertEqual(out.shape, (16, 64)) - - def test_rms_norm_linear_activation(self): - # Test with default GELU activation - model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32) - x = torch.randn(16, 64) - out = model(x) - self.assertEqual(out.shape, (16, 32)) - self.assertEqual(out.dtype, torch.float32) - - # Test with ReLU activation - model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu") - out = model(x) - self.assertEqual(out.shape, (16, 32)) - self.assertTrue(torch.all(out >= 0)) # Check ReLU output range - - # Test with SiLU activation - model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu") - out = model(x) - self.assertEqual(out.shape, (16, 32)) - - # Test with invalid activation - with self.assertRaises(ValueError): - RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid") - - def test_transformer_block(self): - # Test with default parameters - model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32) - x = torch.randn(16, 16, 64) # [batch_size, seq_len, hidden_dim] - out = model(x) - self.assertEqual(out.shape, (16, 16, 64)) - self.assertEqual(out.dtype, torch.float32) - - # Test with different parameters - model = TransformerBlock(hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32) - x = torch.randn(8, 32, 128) - out = model(x) - self.assertEqual(out.shape, (8, 32, 128)) - - # Test with different head dimensions - model = TransformerBlock(hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32) - x = torch.randn(4, 8, 96) - out = model(x) - self.assertEqual(out.shape, (4, 8, 96)) - def test_create_model_and_input(self): m, k, n = 16, 64, 32 model, input_data = create_model_and_input( @@ -244,63 +186,6 @@ def test_create_model_and_input(self): self.assertIsInstance(model, LNLinearSigmoid) self.assertEqual(input_data.shape, (m, k)) - # Test RMSNormLinearActivation - model, input_data = create_model_and_input( - model_type="rms_norm_linear_activation", - m=m, - k=k, - n=n, - high_precision_dtype=torch.float32, - device="cpu", - ) - self.assertIsInstance(model, RMSNormLinearActivation) - self.assertEqual(input_data.shape, (m, k)) - - # Test TransformerBlock - model, input_data = create_model_and_input( - model_type="transformer_block", - m=m, - k=k, - n=n, # n is not used for transformer_block - high_precision_dtype=torch.float32, - device="cpu", - ) - self.assertIsInstance(model, TransformerBlock) - self.assertEqual(input_data.shape, (m, 16, k)) # [batch_size, seq_len, hidden_dim] - - def test_quantization_on_models(self): - # Test quantization on RMSNormLinearActivation - model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32) - x = torch.randn(16, 64) - - # Test with Int8WeightOnlyConfig - config = string_to_config(quantization="int8wo", sparsity=None) - if config is not None: - # Skip quantization test if torchao.quantization.quantize is not available - try: - from torchao.quantization import quantize - quantized_model = quantize(model, config) - out = quantized_model(x) - self.assertEqual(out.shape, (16, 32)) - except ImportError: - print("Skipping quantization test: torchao.quantization.quantize not available") - - # Test quantization on TransformerBlock - model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32) - x = torch.randn(16, 16, 64) - - # Test with Int8WeightOnlyConfig - config = string_to_config(quantization="int8wo", sparsity=None) - if config is not None: - # Skip quantization test if torchao.quantization.quantize is not available - try: - from torchao.quantization import quantize - quantized_model = quantize(model, config) - out = quantized_model(x) - self.assertEqual(out.shape, (16, 16, 64)) - except ImportError: - print("Skipping quantization test: torchao.quantization.quantize not available") - def test_generate_results_csv(self): results = [ BenchmarkResult( diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 9e978f70fa..677f66ac75 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -383,108 +383,6 @@ def forward(self, x): return x -class RMSNorm(torch.nn.Module): - def __init__(self, dim, eps=1e-6, dtype=torch.bfloat16): - super().__init__() - self.eps = eps - self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) - - def forward(self, x): - norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - return x * norm * self.weight - - -class RMSNormLinearActivation(torch.nn.Module): - def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16, activation="gelu"): - super().__init__() - self.rms_norm = RMSNorm(fc_dim1, dtype=dtype) - self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype) - - if activation == "gelu": - self.activation = torch.nn.GELU() - elif activation == "relu": - self.activation = torch.nn.ReLU() - elif activation == "silu": - self.activation = torch.nn.SiLU() - else: - raise ValueError(f"Unsupported activation: {activation}") - - def forward(self, x): - x = self.rms_norm(x) - x = self.fc(x) - x = self.activation(x) - return x - - -class TransformerBlock(torch.nn.Module): - def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.bfloat16): - super().__init__() - self.hidden_dim = hidden_dim - self.num_heads = num_heads - self.head_dim = hidden_dim // num_heads - - # Self-attention - self.qkv = torch.nn.Linear(hidden_dim, 3 * hidden_dim, bias=False).to(dtype) - self.proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False).to(dtype) - - # MLP - self.mlp_ratio = mlp_ratio - self.mlp_hidden_dim = int(hidden_dim * mlp_ratio) - self.mlp_fc1 = torch.nn.Linear(hidden_dim, self.mlp_hidden_dim, bias=False).to(dtype) - self.mlp_fc2 = torch.nn.Linear(self.mlp_hidden_dim, hidden_dim, bias=False).to(dtype) - - # Layer norms - self.norm1 = RMSNorm(hidden_dim, dtype=dtype) - self.norm2 = RMSNorm(hidden_dim, dtype=dtype) - - # Activation - self.activation = torch.nn.GELU() - - def forward(self, x): - batch_size, seq_len, _ = x.shape - - # Self-attention - residual = x - x = self.norm1(x) - - # Reshape qkv projection for better memory layout - qkv = self.qkv(x) # [batch_size, seq_len, 3 * hidden_dim] - qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim) - qkv = qkv.permute(2, 0, 3, 1, 4) # [3, batch_size, num_heads, seq_len, head_dim] - q, k, v = qkv # Each has shape [batch_size, num_heads, seq_len, head_dim] - - # Scaled dot-product attention with proper reshaping - # Reshape for better memory layout and avoid broadcasting issues - q = q.reshape(batch_size * self.num_heads, seq_len, self.head_dim) - k = k.reshape(batch_size * self.num_heads, seq_len, self.head_dim) - v = v.reshape(batch_size * self.num_heads, seq_len, self.head_dim) - - # Compute attention scores - attn = (q @ k.transpose(-2, -1)) * (1.0 / (self.head_dim ** 0.5)) - attn = torch.softmax(attn, dim=-1) - - # Apply attention to values - x = attn @ v # [batch_size * num_heads, seq_len, head_dim] - - # Reshape back to original dimensions - x = x.reshape(batch_size, self.num_heads, seq_len, self.head_dim) - x = x.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_dim) - - # Project back to hidden dimension - x = self.proj(x) - x = residual + x - - # MLP - residual = x - x = self.norm2(x) - x = self.mlp_fc1(x) - x = self.activation(x) - x = self.mlp_fc2(x) - x = residual + x - - return x - - def string_to_config( quantization: Optional[str], sparsity: Optional[str], **kwargs ) -> AOBaseConfig: @@ -678,14 +576,6 @@ def create_model_and_input( elif model_type == "ln_linear_sigmoid": model = LNLinearSigmoid(k, n, high_precision_dtype).to(device) input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) - elif model_type == "rms_norm_linear_activation": - model = RMSNormLinearActivation(k, n, high_precision_dtype).to(device) - input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) - elif model_type == "transformer_block": - # For transformer block, k is the hidden dimension - model = TransformerBlock(k, num_heads=8, mlp_ratio=4, dtype=high_precision_dtype).to(device) - # Input shape for transformer is [batch_size, seq_len, hidden_dim] - input_data = torch.randn(m, 16, k, device=device, dtype=high_precision_dtype) else: raise ValueError(f"Unknown model type: {model_type}") return model, input_data From acc3c791346f0170279ba076a6e66609c0a0b3e3 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 1 Apr 2025 13:43:54 -0700 Subject: [PATCH 04/30] Reintroduce has_weight_zeros as a template param Differential Revision: D71503133 Pull Request resolved: https://github.com/pytorch/ao/pull/1991 --- ..._8bit_activation_groupwise_lowbit_weight.h | 8 +-- .../kernel_1x8x16_f32_neondot-impl.h | 5 +- .../kernels/cpu/aarch64/linear/linear.h | 7 +- .../kernels/cpu/aarch64/tests/test_linear.cpp | 44 ++++++------ .../embedding_xbit/op_embedding_xbit-impl.h | 13 ++-- .../kernel_selector.h | 71 +++++++++++++------ .../packed_weights_format.h | 6 +- .../test_linear_8bit_act_xbit_weight.cpp | 2 +- 8 files changed, 92 insertions(+), 64 deletions(-) diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h index 4ca9cef54d..9ff75e3344 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h @@ -245,7 +245,7 @@ void kernel_1x4x16_f32_neondot( has_clamp); } -template +template void kernel_1x8x16_f32_neondot( // Outputs float32_t* output, @@ -260,10 +260,11 @@ void kernel_1x8x16_f32_neondot( // Ignored if has_clamp = false float clamp_min, float clamp_max, - bool has_weight_zeros, + bool has_weight_zeros_, bool has_bias, bool has_clamp) { - kernel::kernel_1x8x16_f32_neondot( + (void)has_weight_zeros_; // unused + kernel::kernel_1x8x16_f32_neondot( output, output_m_stride, m, @@ -274,7 +275,6 @@ void kernel_1x8x16_f32_neondot( packed_activations, clamp_min, clamp_max, - has_weight_zeros, has_bias, has_clamp); } diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h index 81f6e6b023..7a53c7302c 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h @@ -58,7 +58,7 @@ vec_clamp(float32x4_t x, float32x4_t vec_min, float32x4_t vec_max) { // Roughly inspired by // https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c?ref_type=heads -template +template void kernel_1x8x16_f32_neondot( // Outputs float32_t* output, @@ -73,7 +73,6 @@ void kernel_1x8x16_f32_neondot( // Ignored if has_clamp is false float clamp_min, float clamp_max, - bool has_weight_zeros, bool has_bias, bool has_clamp) { assert(k % group_size == 0); @@ -267,7 +266,7 @@ void kernel_1x8x16_f32_neondot( int32x4_t term1_4567 = vmulq_n_s32(weight_qvals_sum, activation_zero); - if (has_weight_zeros) { + if constexpr (has_weight_zeros) { // Compute term2 and term3 int32_t activation_qvals_sum = *((int32_t*)activation_ptr); diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/linear.h b/torchao/experimental/kernels/cpu/aarch64/linear/linear.h index cd816dba46..7b983a1929 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/linear.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/linear.h @@ -320,7 +320,7 @@ void prepare_weight_data( bias); } -template +template void kernel( // Outputs float32_t* output, @@ -335,12 +335,13 @@ void kernel( // Ignored if has_clamp = false float clamp_min, float clamp_max, - bool has_weight_zeros, + bool has_weight_zeros_, bool has_bias, bool has_clamp) { + (void)has_weight_zeros_; // unused torchao::kernels::cpu::aarch64::linear:: channelwise_8bit_activation_groupwise_lowbit_weight:: - kernel_1x8x16_f32_neondot( + kernel_1x8x16_f32_neondot( output, output_m_stride, m, diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp index 2e19a524e5..0157769fec 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp @@ -311,7 +311,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot bias_ptr); std::vector output(m * n); - kernel( + kernel( output.data(), /*output_m_stride=*/n, m, @@ -388,13 +388,12 @@ TEST( } } -template +template void test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( int m, int k, int n, int group_size, - bool has_weight_zeros, bool has_bias, bool has_clamp) { constexpr int mr = 1; @@ -453,7 +452,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( has_bias ? test_case.bias.data() : nullptr); std::vector output(m * n); - kernel_1x8x16_f32_neondot( + kernel_1x8x16_f32_neondot( output.data(), /*output_m_stride=*/n, m, @@ -476,85 +475,90 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( TEST(test_channelwise_8bit_activation_groupwise_lowbit_weight, LUT) { constexpr int weight_nbit = 4; - test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( + test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< + weight_nbit, + /*has_weight_zeros*/ false>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/false, /*has_clamp=*/false); // has_weight_zeros - test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( + test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< + weight_nbit, + /*has_weight_zeros*/ true>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/true, /*has_bias=*/false, /*has_clamp=*/false); // has_bias - test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( + test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< + weight_nbit, + /*has_weight_zeros=*/false>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/true, /*has_clamp=*/false); // has_clamp - test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( + test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< + weight_nbit, + /*has_weight_zeros*/ false>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/false, /*has_clamp=*/true); // n less than 8 (nr) for (int n = 1; n < 8; n++) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( + test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< + weight_nbit, + /*has_weight_zeros=*/false>( /*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/false, /*has_clamp=*/false); } // Other bitwidths test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< - /*weight_nbit*/ 1>( + /*weight_nbit*/ 1, + /*has_weight_zeros=*/false>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/false, /*has_clamp=*/false); test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< - /*weight_nbit*/ 2>( + /*weight_nbit*/ 2, + /*has_weight_zeros=*/false>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/false, /*has_clamp=*/false); test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< - /*weight_nbit*/ 3>( + /*weight_nbit*/ 3, + /*has_weight_zeros=*/false>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/false, /*has_clamp=*/false); } diff --git a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h index 22b87cfb9e..8113a0566b 100644 --- a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h +++ b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h @@ -253,9 +253,11 @@ Tensor shared_embedding_out_cpu( torchao::ops::PackedWeightsHeader::read(packed_weights.const_data_ptr()); auto format = torchao::ops::linear_8bit_act_xbit_weight::PackedWeightsFormat:: from_packed_weights_header(header); - torchao::ops::linear_8bit_act_xbit_weight::check_format( + + torchao::ops::linear_8bit_act_xbit_weight::check_format( format, - torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal); + torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal, + weight_nbit); constexpr int nr = 8; constexpr int kr = 16; constexpr int sr = 2; @@ -316,12 +318,7 @@ Tensor shared_embedding_cpu( const Tensor& indices) { Tensor output_tensor = torch::empty({}, torch::kFloat32); shared_embedding_out_cpu( - packed_weights, - group_size, - n, - k, - indices, - output_tensor); + packed_weights, group_size, n, k, indices, output_tensor); return output_tensor; } #endif // USE_ATEN diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h index 17d7ec13b1..e960a918d8 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h @@ -89,9 +89,11 @@ void register_ukernel_config_universal( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - check_format( + + check_format( format, - torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal); + torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal, + weight_nbit); if (format.nr == 8 && format.kr == 16 && format.sr == 2) { #if defined(TORCHAO_BUILD_CPU_AARCH64) @@ -99,25 +101,50 @@ void register_ukernel_config_universal( log_registration(format, "universal"); namespace kernel = torchao::kernels::cpu::aarch64::linear:: channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - table.register_ukernel_config( - format, - uarch, - UKernelConfig{ - /*preferred_alignment*/ 16, - /*nr*/ 8, - /*weight_packing_config*/ - {/*weight_data_size_fn*/ - &kernel::weight_data_size, - /*prepare_weight_data_fn*/ - &kernel::prepare_weight_data}, - /*linear_configs*/ - {{{/*mr*/ 1, - /*activation_data_size_fn*/ - &kernel::activation_data_size, - /*prepare_activation_data_fn*/ - &kernel::prepare_activation_data, - /*kernel*/ - &kernel::kernel}}}}); + + if (format.has_weight_zeros) { + constexpr bool has_weight_zeros = true; + table.register_ukernel_config( + format, + uarch, + UKernelConfig{ + /*preferred_alignment*/ 16, + /*nr*/ 8, + /*weight_packing_config*/ + {/*weight_data_size_fn*/ + &kernel::weight_data_size, + /*prepare_weight_data_fn*/ + &kernel::prepare_weight_data}, + /*linear_configs*/ + {{{/*mr*/ 1, + /*activation_data_size_fn*/ + &kernel::activation_data_size, + /*prepare_activation_data_fn*/ + &kernel::prepare_activation_data, + /*kernel*/ + &kernel::kernel}}}}); + } else { + constexpr bool has_weight_zeros = false; + table.register_ukernel_config( + format, + uarch, + UKernelConfig{ + /*preferred_alignment*/ 16, + /*nr*/ 8, + /*weight_packing_config*/ + {/*weight_data_size_fn*/ + &kernel::weight_data_size, + /*prepare_weight_data_fn*/ + &kernel::prepare_weight_data}, + /*linear_configs*/ + {{{/*mr*/ 1, + /*activation_data_size_fn*/ + &kernel::activation_data_size, + /*prepare_activation_data_fn*/ + &kernel::prepare_activation_data, + /*kernel*/ + &kernel::kernel}}}}); + } return; } #endif // TORCHAO_BUILD_CPU_AARCH64 @@ -166,7 +193,7 @@ void register_ukernel_config_kleidi( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - check_format(format, torchao::ops::PackedWeightsType::kleidi_ai); + check_format(format, torchao::ops::PackedWeightsType::kleidi_ai, weight_nbit); namespace op = torchao::kernels::cpu::aarch64::kleidi:: kai_matmul_clamp_f32_qai8dxp_qsi4c32p; diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h index 82beea43fb..e22082f9f1 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h @@ -53,10 +53,10 @@ struct PackedWeightsFormat { } }; -template -void check_format( +inline void check_format( PackedWeightsFormat format, - torchao::ops::PackedWeightsType type) { + torchao::ops::PackedWeightsType type, + int weight_nbit) { if (format.type != type) { throw std::runtime_error( "Kernel expects packed_weights type=" + diff --git a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp index ae11b56e42..caaf8baf74 100644 --- a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp +++ b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp @@ -42,7 +42,7 @@ UKernelConfig get_ukernel_config() { /*prepare_activation_data_fn*/ &kernel::prepare_activation_data, /*kernel*/ - &kernel::kernel}}}}; + &kernel::kernel}}}}; } template < From 77c4ef194dc64a776885d9f37350088b72afd9bd Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 1 Apr 2025 16:09:36 -0700 Subject: [PATCH 05/30] Claen up op interface Differential Revision: D72179480 Pull Request resolved: https://github.com/pytorch/ao/pull/1998 --- .../kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h | 95 +++- ..._8bit_activation_groupwise_lowbit_weight.h | 24 +- .../kernels/cpu/aarch64/linear/linear.h | 365 ------------- .../kernels/cpu/aarch64/tests/test_linear.cpp | 485 ++++++++++-------- .../kernel_config.h | 238 +++++++++ .../kernel_selector.h | 230 ++++----- .../linear_8bit_act_xbit_weight.cpp | 416 +++++---------- .../linear_8bit_act_xbit_weight.h | 144 +----- .../op_linear_8bit_act_xbit_weight-impl.h | 112 ++-- .../test_linear_8bit_act_xbit_weight.cpp | 326 +++++------- 10 files changed, 1006 insertions(+), 1429 deletions(-) delete mode 100644 torchao/experimental/kernels/cpu/aarch64/linear/linear.h create mode 100644 torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h index 2e8d0aa453..2a8e668fa7 100644 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h @@ -60,27 +60,47 @@ namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel; -template -size_t -activation_data_size(int m, int k, int group_size, bool has_weight_zeros) { +size_t packed_activations_size( + int m, + int k, + int group_size, + bool has_weight_zeros, + int mr, + int kr, + int sr) { (void)group_size; // unused (void)has_weight_zeros; // unused auto lhs_packing = get_lhs_packing(); return lhs_packing.get_lhs_packed_size(m, k, mr, kr, sr); } -template -void prepare_activation_data( - void* activation_data, +size_t packed_activations_offset( + int m_idx, + int k, + int group_size, + bool has_weight_zeros, + int mr, + int kr, + int sr) { + (void)group_size; // unused + (void)has_weight_zeros; // unused + auto lhs_pack = get_lhs_packing(); + return lhs_pack.get_lhs_packed_offset(m_idx, k, mr, kr, sr); +} + +void pack_activations( + void* packed_activations, int m, int k, int group_size, const float* activations, - bool has_weight_zeros) { + bool has_weight_zeros, + int mr, + int kr, + int sr) { (void)group_size; // unused (void)has_weight_zeros; // unused auto lhs_pack = get_lhs_packing(); - lhs_pack.run_lhs_pack( m, k, @@ -90,33 +110,62 @@ void prepare_activation_data( /*m_index_start=*/0, activations, /*lhs_stride=*/k * sizeof(float), - activation_data); + packed_activations); } -template -size_t weight_data_size( +size_t packed_weights_size( int n, int k, int group_size, + int weight_nbit, bool has_weight_zeros, - bool has_bias) { + bool has_bias, + int nr, + int kr, + int sr) { + (void)weight_nbit; // unused (void)has_weight_zeros; // unused (void)has_bias; // unused auto rhs_pack = get_rhs_packing(); return rhs_pack.get_rhs_packed_size( - n, k, nr, kr, sr, group_size, kai_datatype::kai_dt_bf16); + internal::adjust_n(n), + k, + nr, + kr, + sr, + group_size, + kai_datatype::kai_dt_bf16); +} + +size_t packed_weights_offset( + int n_idx, + int k, + int group_size, + int weight_nbit, + bool has_weight_zeros, + bool has_bias, + int nr, + int kr, + int sr) { + (void)has_weight_zeros; // unused + (void)has_bias; // unused + auto rhs_pack = get_rhs_packing(); + return rhs_pack.get_rhs_packed_offset( + n_idx, k, nr, kr, sr, group_size, kai_datatype::kai_dt_bf16); } -template -void prepare_weight_data( - void* weight_data, +void pack_weights( + void* packed_weights, int n, int k, int group_size, const int8_t* weight_qvals, const float* weight_scales, const int8_t* weight_zeros, - const float* bias) { + const float* bias, + int nr, + int kr, + int sr) { if (group_size % 32 != 0) { throw std::runtime_error( "Group size must be a multiple of 32, but got group_size=" + @@ -187,7 +236,7 @@ void prepare_weight_data( reinterpret_cast(weight_scales_bf16_padded.data()), /*scale_stride=*/sizeof(uint16_t) * (internal::roundup(k, group_size) / group_size), - /*rhs_packed=*/weight_data, + /*rhs_packed=*/packed_weights, /*extra_bytes=*/0, /*qparams=*/&qparams); } @@ -220,8 +269,8 @@ size_t get_preferred_alignement() { int n, \ int k, \ int group_size, \ - const void* weight_data, \ - const void* activation_data, \ + const void* packed_weights, \ + const void* packed_activations, \ float clamp_min, \ float clamp_max, \ bool has_weight_zeros, \ @@ -235,11 +284,11 @@ size_t get_preferred_alignement() { } \ get_ukernel().run_matmul( \ m, \ - internal::adjust_n(n), \ + n, \ k, \ group_size, \ - activation_data, \ - weight_data, \ + packed_activations, \ + packed_weights, \ output, \ /*dst_stride_row=*/output_m_stride * sizeof(float), \ /*dst_stride_col=*/sizeof(float), \ diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h index 9ff75e3344..95ecb79dc0 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h @@ -49,15 +49,21 @@ inline size_t packed_activations_offset( return (m_idx / mr) * packed_activations_size_mr_rows; } -template +template void pack_activations( void* packed_activations, int m, int k, int group_size, const float* activations, - bool has_weight_zeros) { - activation_packing::pack_activations( + bool has_weight_zeros, + int mr, + int kr, + int sr) { + (void)mr; // unused + (void)kr; // unused + (void)sr; // unused + activation_packing::pack_activations( packed_activations, m, k, group_size, activations, has_weight_zeros); } @@ -93,7 +99,7 @@ inline size_t packed_weights_offset( return (n_idx / nr) * packed_weights_size_nr_cols; } -template +template void pack_weights( void* packed_weights, int n, @@ -102,8 +108,14 @@ void pack_weights( const int8_t* weight_qvals, const float* weight_scales, const int8_t* weight_zeros, - const float* bias) { - weight_packing::pack_weights( + const float* bias, + int nr, + int kr, + int sr) { + (void)nr; // unused + (void)kr; // unused + (void)sr; // unused + weight_packing::pack_weights( packed_weights, n, k, diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/linear.h b/torchao/experimental/kernels/cpu/aarch64/linear/linear.h deleted file mode 100644 index 7b983a1929..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/linear/linear.h +++ /dev/null @@ -1,365 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -// TODO: this file will be deleted and replaced by -// torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/include.h -// It exists now to prevent breaking existing code in the interim. - -#pragma once - -#if defined(__aarch64__) || defined(__ARM_NEON) - -#include -#include -#include - -namespace torchao::kernels::cpu::aarch64::linear { -namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot { - -inline size_t -activation_data_size(int m, int k, int group_size, bool has_weight_zeros) { - return torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - packed_activations_size( - m, - k, - group_size, - has_weight_zeros, - /*mr*/ 1, - /*kr*/ 32, - /*sr*/ 1); -} - -inline void prepare_activation_data( - void* activation_data, - // Inputs - int m, - int k, - // Ignored if has_weight_zeros = false - int group_size, - const float* activations, - bool has_weight_zeros) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - pack_activations( - activation_data, m, k, group_size, activations, has_weight_zeros); -} - -template -size_t weight_data_size( - int n, - int k, - int group_size, - bool has_weight_zeros, - bool has_bias) { - return torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight::packed_weights_size( - n, - k, - group_size, - weight_nbit, - has_weight_zeros, - has_bias, - /*nr*/ 1, - /*kr*/ 32, - /*sr*/ 1); -} - -template -void prepare_weight_data( - void* weight_data, - // Inputs - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - pack_weights( - weight_data, - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros, - bias); -} - -template -void kernel( - // Outputs - float32_t* output, - // Inputs - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - // Ignored if has_clamp = false - float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - kernel_1x1x32_f32_neondot( - output, - output_m_stride, - m, - n, - k, - group_size, - weight_data, - activation_data, - clamp_min, - clamp_max, - has_weight_zeros, - has_bias, - has_clamp); -} - -} // namespace - // channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot - -namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot { - -inline size_t -activation_data_size(int m, int k, int group_size, bool has_weight_zeros) { - return torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - packed_activations_size( - m, - k, - group_size, - has_weight_zeros, - /*mr*/ 1, - /*kr*/ 16, - /*sr*/ 2); -} - -inline void prepare_activation_data( - void* activation_data, - // Inputs - int m, - int k, - // Ignored if has_weight_zeros = false - int group_size, - const float* activations, - bool has_weight_zeros) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - pack_activations( - activation_data, m, k, group_size, activations, has_weight_zeros); -} - -template -inline size_t weight_data_size( - int n, - int k, - int group_size, - bool has_weight_zeros, - bool has_bias) { - return torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight::packed_weights_size( - n, - k, - group_size, - weight_nbit, - has_weight_zeros, - has_bias, - /*nr*/ 4, - /*kr*/ 16, - /*sr*/ 2); -} - -template -inline void prepare_weight_data( - void* weight_data, - // Inputs - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - pack_weights( - weight_data, - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros, - bias); -} - -template -void kernel( - // Outputs - float32_t* output, - // Inputs - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - // Ignored if has_clamp = false - float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - kernel_1x4x16_f32_neondot( - output, - output_m_stride, - m, - n, - k, - group_size, - weight_data, - activation_data, - clamp_min, - clamp_max, - has_weight_zeros, - has_bias, - has_clamp); -} - -} // namespace - // channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot - -namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot { - -inline size_t -activation_data_size(int m, int k, int group_size, bool has_weight_zeros) { - return torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - packed_activations_size( - m, - k, - group_size, - has_weight_zeros, - /*mr*/ 1, - /*kr*/ 16, - /*sr*/ 2); -} - -inline void prepare_activation_data( - void* activation_data, - // Inputs - int m, - int k, - // Ignored if has_weight_zeros = false - int group_size, - const float* activations, - bool has_weight_zeros) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - pack_activations( - activation_data, m, k, group_size, activations, has_weight_zeros); -} - -template -size_t weight_data_size( - int n, - int k, - int group_size, - bool has_weight_zeros, - bool has_bias) { - return torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight::packed_weights_size( - n, - k, - group_size, - weight_nbit, - has_weight_zeros, - has_bias, - /*nr*/ 8, - /*kr*/ 16, - /*sr*/ 2); -} - -template -void prepare_weight_data( - void* weight_data, - // Inputs - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - pack_weights( - weight_data, - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros, - bias); -} - -template -void kernel( - // Outputs - float32_t* output, - // Inputs - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - // Ignored if has_clamp = false - float clamp_min, - float clamp_max, - bool has_weight_zeros_, - bool has_bias, - bool has_clamp) { - (void)has_weight_zeros_; // unused - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - kernel_1x8x16_f32_neondot( - output, - output_m_stride, - m, - n, - k, - group_size, - weight_data, - activation_data, - clamp_min, - clamp_max, - has_weight_zeros, - has_bias, - has_clamp); -} - -} // namespace - // channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot - -} // namespace torchao::kernels::cpu::aarch64::linear - -#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp index 0157769fec..671ee3f0b9 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp @@ -12,17 +12,23 @@ #include #include #include -#include #include float kTol = 0.0001; -template -void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot( +template +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32( int m, int k, int n, - int group_size) { + int group_size, + bool has_bias, + bool has_clamp) { + constexpr int mr = 1; + constexpr int nr = 1; + constexpr int kr = 32; + constexpr int sr = 1; + auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( m, @@ -35,48 +41,46 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot has_clamp); using namespace torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot; + channelwise_8bit_activation_groupwise_lowbit_weight; - std::vector activation_data( - activation_data_size(m, k, group_size, has_weight_zeros)); - prepare_activation_data( - (void*)activation_data.data(), + std::vector packed_activations( + packed_activations_size(m, k, group_size, has_weight_zeros, mr, kr, sr)); + pack_activations( + (void*)packed_activations.data(), m, k, group_size, test_case.activations.data(), - has_weight_zeros); + has_weight_zeros, + mr, + kr, + sr); - std::vector weight_data(weight_data_size( - n, k, group_size, has_weight_zeros, has_bias)); - int8_t* weight_zeros_ptr = nullptr; - if (has_weight_zeros) { - weight_zeros_ptr = test_case.weight_zeros.data(); - } - float* bias_ptr = nullptr; - if (has_bias) { - bias_ptr = test_case.bias.data(); - } - prepare_weight_data( - (void*)weight_data.data(), + std::vector packed_weights(packed_weights_size( + n, k, group_size, weight_nbit, has_weight_zeros, has_bias, nr, kr, sr)); + pack_weights( + (void*)packed_weights.data(), n, k, group_size, test_case.weight_qvals.data(), test_case.weight_scales.data(), - /*weight_zeros=*/weight_zeros_ptr, - bias_ptr); + has_weight_zeros ? test_case.weight_zeros.data() : nullptr, + has_bias ? test_case.bias.data() : nullptr, + nr, + kr, + sr); std::vector output(m * n); - kernel( + kernel_1x1x32_f32_neondot( output.data(), /*output_m_stride=*/n, m, n, k, group_size, - weight_data.data(), - activation_data.data(), + packed_weights.data(), + packed_activations.data(), /*clamp_min=*/test_case.clamp_min, /*clamp_max=*/test_case.clamp_max, has_weight_zeros, @@ -88,56 +92,19 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot } } -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, - Standard) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, - HasWeightZeros) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, - HasBias) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, - HasClamp) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); -} - -template -void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot( +template +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16( int m, int k, int n, - int group_size) { + int group_size, + bool has_bias, + bool has_clamp) { + constexpr int mr = 1; + constexpr int nr = 4; + constexpr int kr = 16; + constexpr int sr = 2; + auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( m, @@ -150,48 +117,46 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot has_clamp); using namespace torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot; + channelwise_8bit_activation_groupwise_lowbit_weight; - std::vector activation_data( - activation_data_size(m, k, group_size, has_weight_zeros)); - prepare_activation_data( - (void*)activation_data.data(), + std::vector packed_activations( + packed_activations_size(m, k, group_size, has_weight_zeros, mr, kr, sr)); + pack_activations( + (void*)packed_activations.data(), m, k, group_size, test_case.activations.data(), - has_weight_zeros); + has_weight_zeros, + mr, + kr, + sr); - std::vector weight_data(weight_data_size( - n, k, group_size, has_weight_zeros, has_bias)); - int8_t* weight_zeros_ptr = nullptr; - if (has_weight_zeros) { - weight_zeros_ptr = test_case.weight_zeros.data(); - } - float* bias_ptr = nullptr; - if (has_bias) { - bias_ptr = test_case.bias.data(); - } - prepare_weight_data( - (void*)weight_data.data(), + std::vector packed_weights(packed_weights_size( + n, k, group_size, weight_nbit, has_weight_zeros, has_bias, nr, kr, sr)); + pack_weights( + (void*)packed_weights.data(), n, k, group_size, test_case.weight_qvals.data(), test_case.weight_scales.data(), - /*weight_zeros=*/weight_zeros_ptr, - bias_ptr); + has_weight_zeros ? test_case.weight_zeros.data() : nullptr, + has_bias ? test_case.bias.data() : nullptr, + nr, + kr, + sr); std::vector output(m * n); - kernel( + kernel_1x4x16_f32_neondot( output.data(), /*output_m_stride=*/n, m, n, k, group_size, - weight_data.data(), - activation_data.data(), + packed_weights.data(), + packed_activations.data(), /*clamp_min=*/test_case.clamp_min, /*clamp_max=*/test_case.clamp_max, has_weight_zeros, @@ -203,69 +168,19 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot } } -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, - Standard) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, - HasWeightZeros) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, - HasBias) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, - HasClamp) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, - NLessThan4) { - for (int n = 1; n < 4; n++) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16); - } -} - -template -void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot( +template +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16( int m, int k, int n, - int group_size) { + int group_size, + bool has_bias, + bool has_clamp) { + constexpr int mr = 1; + constexpr int nr = 8; + constexpr int kr = 16; + constexpr int sr = 2; + auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( m, @@ -278,48 +193,46 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot has_clamp); using namespace torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + channelwise_8bit_activation_groupwise_lowbit_weight; - std::vector activation_data( - activation_data_size(m, k, group_size, has_weight_zeros)); - prepare_activation_data( - (void*)activation_data.data(), + std::vector packed_activations( + packed_activations_size(m, k, group_size, has_weight_zeros, mr, kr, sr)); + pack_activations( + (void*)packed_activations.data(), m, k, group_size, test_case.activations.data(), - has_weight_zeros); + has_weight_zeros, + mr, + kr, + sr); - std::vector weight_data(weight_data_size( - n, k, group_size, has_weight_zeros, has_bias)); - int8_t* weight_zeros_ptr = nullptr; - if (has_weight_zeros) { - weight_zeros_ptr = test_case.weight_zeros.data(); - } - float* bias_ptr = nullptr; - if (has_bias) { - bias_ptr = test_case.bias.data(); - } - prepare_weight_data( - (void*)weight_data.data(), + std::vector packed_weights(packed_weights_size( + n, k, group_size, weight_nbit, has_weight_zeros, has_bias, nr, kr, sr)); + pack_weights( + (void*)packed_weights.data(), n, k, group_size, test_case.weight_qvals.data(), test_case.weight_scales.data(), - /*weight_zeros=*/weight_zeros_ptr, - bias_ptr); + has_weight_zeros ? test_case.weight_zeros.data() : nullptr, + has_bias ? test_case.bias.data() : nullptr, + nr, + kr, + sr); std::vector output(m * n); - kernel( + kernel_1x8x16_f32_neondot( output.data(), /*output_m_stride=*/n, m, n, k, group_size, - weight_data.data(), - activation_data.data(), + packed_weights.data(), + packed_activations.data(), /*clamp_min=*/test_case.clamp_min, /*clamp_max=*/test_case.clamp_max, has_weight_zeros, @@ -331,60 +244,173 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot } } -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, - Standard) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} +TEST(test_channelwise_8bit_activation_groupwise_lowbit_weight, tile_1x1x32) { + constexpr int weight_nbit = 4; -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, - HasWeightZeros) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} + // Standard + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/32, + /*has_bias=*/false, + /*has_clamp=*/false); -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, - HasBias) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); + // With weight zeros + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32< + weight_nbit, + /*has_weight_zeros=*/true>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/32, + /*has_bias=*/false, + /*has_clamp=*/false); + + // With bias + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/32, + /*has_bias=*/true, + /*has_clamp=*/false); + + // With clamp + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/32, + /*has_bias=*/false, + /*has_clamp=*/true); } -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, - HasClamp) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); +TEST(test_channelwise_8bit_activation_groupwise_lowbit_weight, tile_1x4x16) { + constexpr int weight_nbit = 4; + + // Standard + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/false); + + // With weight zeros + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16< + weight_nbit, + /*has_weight_zeros=*/true>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/false); + + // With bias + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/true, + /*has_clamp=*/false); + + // With clamp + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/true); + + // n less than 4 + for (int n = 1; n < 4; n++) { + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/n, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/false); + } } -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, - NLessThan8) { +TEST(test_channelwise_8bit_activation_groupwise_lowbit_weight, tile_1x8x16) { + constexpr int weight_nbit = 4; + + // Standard + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/false); + + // With weight zeros + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16< + weight_nbit, + /*has_weight_zeros=*/true>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/false); + + // With bias + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/true, + /*has_clamp=*/false); + + // With clamp + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/true); + + // n less than 8 for (int n = 1; n < 8; n++) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16); + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/n, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/false); } } @@ -423,7 +449,10 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( k, group_size, test_case.activations.data(), - has_weight_zeros); + has_weight_zeros, + mr, + kr, + sr); // Define equivalent LUT for affine quantization constexpr int lut_size = (1 << weight_nbit); diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h new file mode 100644 index 0000000000..1e4a9ef670 --- /dev/null +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h @@ -0,0 +1,238 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include +#include +#include + +namespace torchao::ops::linear_8bit_act_xbit_weight { + +constexpr int kMaxLinearConfigs = 4; +struct UKernelConfig { + // Size of packed_activations buffer + using packed_activations_size_fn_type = size_t (*)( + int m, + int k, + int group_size, + bool has_weight_zeros, + int mr, + int kr, + int sr); + + // Offset in packed_activations buffer for a given m_idx + // m_idx is index in unpacked activations matrix; it will be a multiple of + // m_step + using packed_activations_offset_fn_type = size_t (*)( + int m_idx, + int k, + int group_size, + bool has_weight_zeros, + int mr, + int kr, + int sr); + + // Pack activations into packed_activations buffer + using pack_activations_fn_type = void (*)( + void* packed_activations, + int m, + int k, + int group_size, + const float* activations, + bool has_weight_zeros, + int mr, + int kr, + int sr); + + // Size of packed_weights buffer + using packed_weights_size_fn_type = size_t (*)( + int n, + int k, + int group_size, + int weight_nbit, + bool has_weight_zeros, + bool has_bias, + int nr, + int kr, + int sr); + + // Offset in packed_weights buffer for a given n_idx + // n_inx is index in unpacked weights matrix; it will be a multiple of n_step + using packed_weights_offset_fn_type = size_t (*)( + int n_idx, + int k, + int group_size, + int weight_nbit, + bool has_weight_zeros, + bool has_bias, + int nr, + int kr, + int sr); + + // Pack weights into packed_weights buffer + using pack_weights_fn_type = void (*)( + void* packed_weights, + int n, + int k, + int group_size, + const int8_t* weight_qvals, + const float* weight_scales, + const int8_t* weight_zeros, + const float* bias, + int nr, + int kr, + int sr); + + // Run matmul kernel + using kernel_fn_type = void (*)( + float* output, + int output_m_stride, + int m, + int n, + int k, + int group_size, + const void* packed_weights, + const void* packed_activations, + float clamp_min, + float clamp_max, + bool has_weight_zeros, + bool has_bias, + bool has_clamp); + + struct linear_config_type { + int m_step{0}; // m_idx will be a multiple of this + int mr{0}; + packed_activations_size_fn_type packed_activations_size{nullptr}; + packed_activations_offset_fn_type packed_activations_offset{nullptr}; + pack_activations_fn_type pack_activations{nullptr}; + kernel_fn_type kernel{nullptr}; + }; + + // preferred_alignment for packed_activations and packed_weights + // Integration surfaces are not required to respect this alignment, and the + // kernel must behave correctly no matter how buffers are aligned + size_t preferred_alignment{0}; + int n_step{0}; // n_idx will be a multiple of this + int nr{0}; + int kr{0}; + int sr{0}; + int weight_nbit{0}; + bool has_weight_zeros{false}; + bool has_bias{false}; + packed_weights_size_fn_type packed_weights_size{nullptr}; + packed_weights_offset_fn_type packed_weights_offset{nullptr}; + pack_weights_fn_type pack_weights{nullptr}; + + // linear_configs must be sorted in ascending m_step + std::array linear_configs; + + static UKernelConfig make( + size_t preferred_alignment, + int n_step, + int nr, + int kr, + int sr, + int weight_nbit, + bool has_weight_zeros, + bool has_bias, + packed_weights_size_fn_type packed_weights_size, + packed_weights_offset_fn_type packed_weights_offset, + pack_weights_fn_type pack_weights, + std::array linear_configs); + + inline void validate() const { + TORCHAO_CHECK(preferred_alignment >= 1, "preferred_alignment must be >= 1"); + TORCHAO_CHECK(n_step >= 1, "n_step must be >= 1"); + TORCHAO_CHECK(nr >= 1, "nr must be >= 1"); + TORCHAO_CHECK(kr >= 1, "kr must be >= 1"); + TORCHAO_CHECK(sr >= 1, "sr must be >= 1"); + TORCHAO_CHECK(weight_nbit >= 1, "weight_nbit must be >= 1"); + TORCHAO_CHECK( + packed_weights_size != nullptr, "packed_weights_size must be set"); + TORCHAO_CHECK( + packed_weights_offset != nullptr, "packed_weights_offset must be set"); + TORCHAO_CHECK(pack_weights != nullptr, "pack_weights must be set"); + + bool linear_configs_set = true; // first linear config must be set + for (int i = 0; i < linear_configs.size(); i++) { + if (linear_configs_set) { + TORCHAO_CHECK( + linear_configs[i].m_step >= 1, + "linear_configs[i].m_step must be >= 1"); + TORCHAO_CHECK( + linear_configs[i].mr >= 1, "linear_configs[i].mr must be >= 1"); + TORCHAO_CHECK( + linear_configs[i].packed_activations_size != nullptr, + "linear_configs[i].packed_activations_size must be set"); + TORCHAO_CHECK( + linear_configs[i].packed_activations_offset != nullptr, + "linear_configs[i].packed_activations_offset must be set"); + TORCHAO_CHECK( + linear_configs[i].pack_activations != nullptr, + "linear_configs[i].pack_activations must be set"); + TORCHAO_CHECK( + linear_configs[i].kernel != nullptr, + "linear_configs[i].kernel must be set"); + if (i >= 1) { + TORCHAO_CHECK( + linear_configs[i - 1].m_step < linear_configs[i].m_step, + "set linear_configs must be increasing in m_step"); + } + if (i + 1 < linear_configs.size()) { + linear_configs_set = (linear_configs[i + 1].m_step >= 1); + } + } + } + } + + inline int select_linear_config_idx(int m) const { + assert(m >= 1); + assert(linear_configs[0].m_step >= 1); + + int i = 0; + while (i + 1 < linear_configs.size() && linear_configs[i + 1].m_step >= 1 && + linear_configs[i + 1].m_step <= m) { + assert(linear_configs[i].m_step < linear_configs[i + 1].m_step); + i++; + } + + assert(i < linear_configs.size()); + assert(linear_configs[i].m_step >= 1); + assert(i == 0 || linear_configs[i].m_step <= m); + return i; + } +}; + +inline UKernelConfig UKernelConfig::make( + size_t preferred_alignment, + int n_step, + int nr, + int kr, + int sr, + int weight_nbit, + bool has_weight_zeros, + bool has_bias, + packed_weights_size_fn_type packed_weights_size, + packed_weights_offset_fn_type packed_weights_offset, + pack_weights_fn_type pack_weights, + std::array linear_configs) { + return UKernelConfig{ + preferred_alignment, + n_step, + nr, + kr, + sr, + weight_nbit, + has_weight_zeros, + has_bias, + packed_weights_size, + packed_weights_offset, + pack_weights, + std::move(linear_configs)}; +} + +} // namespace torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h index e960a918d8..719c2e01e4 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h @@ -6,11 +6,11 @@ #pragma once #include -#include +#include #include #if defined(TORCHAO_BUILD_CPU_AARCH64) -#include +#include #endif // TORCHAO_BUILD_CPU_AARCH64 #include @@ -50,6 +50,7 @@ struct UKernelConfigRegistrationTable { throw std::runtime_error( "UKernelConfig is already registered for this format"); } + config.validate(); registration_table_[key] = config; } std::optional get_ukernel_config( @@ -95,94 +96,90 @@ void register_ukernel_config_universal( torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal, weight_nbit); + namespace kernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight; + + constexpr bool has_lut = false; + int preferred_alignment = 16; + if (format.nr == 8 && format.kr == 16 && format.sr == 2) { + constexpr int n_step = 8; + constexpr int nr = 8; + constexpr int kr = 16; + constexpr int sr = 2; + constexpr int mr = 1; + constexpr int m_step = 1; + #if defined(TORCHAO_BUILD_CPU_AARCH64) if (cpuinfo_has_arm_neon_dot()) { - log_registration(format, "universal"); - namespace kernel = torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + log_registration(format, "universal: kernel_1x8x16_f32_neondot"); + auto uk = UKernelConfig::make( + preferred_alignment, + n_step, + nr, + kr, + sr, + weight_nbit, + format.has_weight_zeros, + format.has_bias, + &kernel::packed_weights_size, + &kernel::packed_weights_offset, + &kernel::pack_weights, + /*linear_configs*/ {}); if (format.has_weight_zeros) { constexpr bool has_weight_zeros = true; - table.register_ukernel_config( - format, - uarch, - UKernelConfig{ - /*preferred_alignment*/ 16, - /*nr*/ 8, - /*weight_packing_config*/ - {/*weight_data_size_fn*/ - &kernel::weight_data_size, - /*prepare_weight_data_fn*/ - &kernel::prepare_weight_data}, - /*linear_configs*/ - {{{/*mr*/ 1, - /*activation_data_size_fn*/ - &kernel::activation_data_size, - /*prepare_activation_data_fn*/ - &kernel::prepare_activation_data, - /*kernel*/ - &kernel::kernel}}}}); + uk.linear_configs[0] = UKernelConfig::linear_config_type( + {m_step, + mr, + &kernel::packed_activations_size, + &kernel::packed_activations_offset, + &kernel::pack_activations, + &kernel::kernel_1x8x16_f32_neondot< + weight_nbit, + has_weight_zeros, + has_lut>}); + + table.register_ukernel_config(format, uarch, std::move(uk)); + return; } else { constexpr bool has_weight_zeros = false; - table.register_ukernel_config( - format, - uarch, - UKernelConfig{ - /*preferred_alignment*/ 16, - /*nr*/ 8, - /*weight_packing_config*/ - {/*weight_data_size_fn*/ - &kernel::weight_data_size, - /*prepare_weight_data_fn*/ - &kernel::prepare_weight_data}, - /*linear_configs*/ - {{{/*mr*/ 1, - /*activation_data_size_fn*/ - &kernel::activation_data_size, - /*prepare_activation_data_fn*/ - &kernel::prepare_activation_data, - /*kernel*/ - &kernel::kernel}}}}); + uk.linear_configs[0] = UKernelConfig::linear_config_type( + {m_step, + mr, + &kernel::packed_activations_size, + &kernel::packed_activations_offset, + &kernel::pack_activations, + &kernel::kernel_1x8x16_f32_neondot< + weight_nbit, + has_weight_zeros, + has_lut>}); + + table.register_ukernel_config(format, uarch, std::move(uk)); + return; } - return; } #endif // TORCHAO_BUILD_CPU_AARCH64 } } #if defined(TORCHAO_ENABLE_KLEIDI) -template < - typename kernel_struct, - int m_step, - int mr, - int n_step, - int nr, - int kr, - int sr> -UKernelConfig::linear_config_type get_linear_config_kleidi() { +template +UKernelConfig::linear_config_type +get_linear_config_kleidi(int n_step, int nr, int kr, int sr) { namespace op = torchao::kernels::cpu::aarch64::kleidi:: kai_matmul_clamp_f32_qai8dxp_qsi4c32p; - assert(m_step == kernel_struct::get_ukernel().get_m_step()); - assert(mr == kernel_struct::get_ukernel().get_mr()); assert(n_step == kernel_struct::get_ukernel().get_n_step()); assert(nr == kernel_struct::get_ukernel().get_nr()); assert(kr == kernel_struct::get_ukernel().get_kr()); assert(sr == kernel_struct::get_ukernel().get_sr()); - return UKernelConfig::linear_config_type{ - /*mr*/ m_step, - /*activation_data_size_fn*/ &op::activation_data_size, - /*prepare_activation_data_fn*/ &op::prepare_activation_data, - /*kernel*/ &kernel_struct::kernel}; -} - -template -UKernelConfig::weight_packing_config_type get_weight_packing_config_kleidi() { - namespace op = torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p; - return UKernelConfig::weight_packing_config_type( - {/*weight_data_size_fn*/ &op::weight_data_size, - /*prepare_weight_data_fn*/ &op::prepare_weight_data}); + return UKernelConfig::linear_config_type( + {static_cast(kernel_struct::get_ukernel().get_m_step()), + static_cast(kernel_struct::get_ukernel().get_mr()), + &op::packed_activations_size, + &op::packed_activations_offset, + &op::pack_activations, + &kernel_struct::kernel}); } template @@ -197,89 +194,62 @@ void register_ukernel_config_kleidi( namespace op = torchao::kernels::cpu::aarch64::kleidi:: kai_matmul_clamp_f32_qai8dxp_qsi4c32p; + auto uk = UKernelConfig::make( + /*preferred_alignment*/ op::get_preferred_alignement(), + /*n_step*/ format.nr, + format.nr, + format.kr, + format.sr, + format.weight_nbit, + format.has_weight_zeros, + format.has_bias, + &op::packed_weights_size, + &op::packed_weights_offset, + &op::pack_weights, + {} /*linear_configs*/); + if (format.nr == 8 && format.kr == 16 && format.sr == 2) { - constexpr int nr = 8; - constexpr int kr = 16; - constexpr int sr = 2; + uk.n_step = 8; + #if defined(TORCHAO_ENABLE_ARM_I8MM) if (cpuinfo_has_arm_i8mm()) { - constexpr int n_step = 8; + /*m_step=4*/ + uk.linear_configs[0] = get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm>( + uk.n_step, uk.nr, uk.kr, uk.sr); log_registration( format, "kleidiai: matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm"); - table.register_ukernel_config( - format, - uarch, - UKernelConfig{ - /*preferred_alignment*/ op::get_preferred_alignement(), - /*nr*/ n_step, - /*weight_packing_config*/ - get_weight_packing_config_kleidi(), - /*linear_configs*/ - {{get_linear_config_kleidi< - op::matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - /*m_step*/ 4, - /*mr*/ 4, - n_step, - nr, - kr, - sr>()}}}); + table.register_ukernel_config(format, uarch, std::move(uk)); return; } #endif // TORCHAO_ENABLE_ARM_I8MM if (cpuinfo_has_arm_neon_dot()) { - constexpr int n_step = 8; log_registration( format, "kleidiai: matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod"); - table.register_ukernel_config( - format, - uarch, - UKernelConfig{ - /*preferred_alignment*/ op::get_preferred_alignement(), - /*nr*/ n_step, - /*weight_packing_config*/ - get_weight_packing_config_kleidi(), - /*linear_configs*/ - {{get_linear_config_kleidi< - op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - /*m_step*/ 1, - /*mr*/ 1, - n_step, - nr, - kr, - sr>()}}}); + /*m_step=1*/ + uk.linear_configs[0] = get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod>( + uk.n_step, uk.nr, uk.kr, uk.sr); + table.register_ukernel_config(format, uarch, std::move(uk)); return; } } if (format.nr == 4 && format.kr == 16 && format.sr == 2) { - constexpr int nr = 4; - constexpr int kr = 16; - constexpr int sr = 2; + uk.n_step = 4; if (cpuinfo_has_arm_neon_dot()) { - constexpr int n_step = 4; + /*m_step=1*/ + uk.linear_configs[0] = get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod>( + uk.n_step, uk.nr, uk.kr, uk.sr); + log_registration( format, "kleidiai: matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod"); - table.register_ukernel_config( - format, - uarch, - UKernelConfig{ - /*preferred_alignment*/ op::get_preferred_alignement(), - /*nr*/ n_step, - /*weight_packing_config*/ - get_weight_packing_config_kleidi(), - /*linear_configs*/ - {{get_linear_config_kleidi< - op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /*m_step*/ 1, - /*mr*/ 1, - n_step, - nr, - kr, - sr>()}}}); + table.register_ukernel_config(format, uarch, std::move(uk)); return; } } @@ -361,7 +331,7 @@ PackedWeightsFormat select_packed_weights_format( torchao::ops::PackedWeightsType::kleidi_ai, weight_nbit, has_weight_zeros, - /*has_bias*/ true, + has_bias, /*nr*/ 8, /*kr*/ 16, /*sr*/ 2); diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp index 0421e6a25f..6929e6e4a4 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp @@ -7,43 +7,19 @@ #include #include #include +#include #include #include #include #include +#include namespace torchao::ops::linear_8bit_act_xbit_weight { -PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( - const UKernelConfig& ukernel_config, - int n, - int target_panels_per_thread) { - TORCHAO_CHECK(n >= 1, "n must be >= 1"); - TORCHAO_CHECK( - target_panels_per_thread >= 1, "target_panels_per_thread must be >= 1"); - - PackWeightDataTilingParams tiling_params; - int nr = ukernel_config.nr; - int num_threads = torchao::get_num_threads(); - int numerator = n; - int denominator = num_threads * target_panels_per_thread; - - // Set nc = ceil(numerator / denominator) - int nc = (numerator + denominator - 1) / denominator; - assert(nc >= 1); - - // Replace nc with the next number nr divides - nc = ((nc + nr - 1) / nr) * nr; - tiling_params.nc_by_nr = nc / nr; - - return tiling_params; -} - -void pack_weight_data_operator( - const UKernelConfig& ukernel_config, - const PackWeightDataTilingParams& tiling_params, +void pack_weights_operator( + const UKernelConfig& uk, // Outputs - void* weight_data, + void* packed_weights, // Inputs int n, int k, @@ -54,12 +30,14 @@ void pack_weight_data_operator( const float* bias) { TORCHAO_CHECK(group_size % 16 == 0, "group_size must be a multiple of 16"); TORCHAO_CHECK(k % group_size == 0, "group_size must divide k"); + TORCHAO_CHECK( + uk.has_bias == (bias != nullptr), "bias/has_bias is inconsistent"); + TORCHAO_CHECK( + uk.has_weight_zeros == (weight_zeros != nullptr), + "weight_zeros/has_weight_zeros is inconsistent"); - bool has_weight_zeros = (weight_zeros != nullptr); - bool has_bias = (bias != nullptr); - - int nr = ukernel_config.nr; - int nc = std::min(n, tiling_params.nc_by_nr * ukernel_config.nr); + int n_step = uk.n_step; + int nc = std::min(n, n_step); int num_nc_panels = (n + nc - 1) / nc; torchao::parallel_1d(0, num_nc_panels, [&](int64_t idx) { @@ -67,50 +45,53 @@ void pack_weight_data_operator( int n_idx = nc_tile_idx * nc; int nc_tile_size = std::min(nc, n - n_idx); - int weight_data_offset = (n_idx / nr) * - ukernel_config.weight_packing_config.weight_data_size_fn( - nr, k, group_size, has_weight_zeros, has_bias); + auto packed_weights_offset = uk.packed_weights_offset( + n_idx, + k, + group_size, + uk.weight_nbit, + uk.has_weight_zeros, + uk.has_bias, + uk.nr, + uk.kr, + uk.sr); + int weight_qvals_offset = n_idx * k; int weight_scales_and_zeros_offset = (n_idx * k / group_size); - - const int8_t* weight_zeros_ptr = nullptr; - if (weight_zeros != nullptr) { - weight_zeros_ptr = weight_zeros + weight_scales_and_zeros_offset; - } - const float* bias_ptr = nullptr; - if (bias != nullptr) { - bias_ptr = bias + n_idx; - } - - ukernel_config.weight_packing_config.prepare_weight_data_fn( - (char*)weight_data + weight_data_offset, + uk.pack_weights( + (char*)packed_weights + packed_weights_offset, /*n=*/nc_tile_size, k, group_size, weight_qvals + weight_qvals_offset, weight_scales + weight_scales_and_zeros_offset, - weight_zeros_ptr, - bias_ptr); + (weight_zeros == nullptr) + ? nullptr + : (weight_zeros + weight_scales_and_zeros_offset), + (bias == nullptr) ? nullptr : (bias + n_idx), + uk.nr, + uk.kr, + uk.sr); }); } -// This default mimics XNNPACK behavior if target_tiles_per_thread = 5 -LinearTilingParams get_default_linear_tiling_params( - const UKernelConfig& ukernel_config, +LinearTilingParams LinearTilingParams::from_target_tiles_per_thread( int m, + int m_step, int n, + int n_step, int target_tiles_per_thread) { TORCHAO_CHECK(m >= 1, "m must be >= 1"); + TORCHAO_CHECK(m_step >= 1, "m_step must be >= 1"); + TORCHAO_CHECK(n >= 1, "n must be >= 1"); + TORCHAO_CHECK(n_step >= 1, "n_step must be >= 1"); TORCHAO_CHECK( target_tiles_per_thread >= 1, "target_tiles_per_thread must be >= 1"); - - LinearTilingParams tiling_params; auto num_threads = torchao::get_num_threads(); TORCHAO_CHECK(num_threads >= 1, "num_threads must be >= 1"); - tiling_params.mc_by_mr = 1; - int mc = tiling_params.mc_by_mr * ukernel_config.linear_configs[0].mr; + int mc = m_step; int num_mc_panels = (m + mc - 1) / mc; int numerator = n * num_mc_panels; @@ -120,50 +101,25 @@ LinearTilingParams get_default_linear_tiling_params( int nc = (numerator + denominator - 1) / denominator; assert(nc >= 1); - // Replace nc with next number nr divides - int nr = ukernel_config.nr; - nc = ((nc + nr - 1) / nr) * nr; - assert(nc % nr == 0); - tiling_params.nc_by_nr = nc / nr; + // Replace nc with next number n_step divides + nc = ((nc + n_step - 1) / n_step) * n_step; - assert(tiling_params.mc_by_mr >= 1); - assert(tiling_params.nc_by_nr >= 1); - return tiling_params; -} - -namespace internal { + // Clamp mc, nc to be no larger than m, n + mc = std::min(m, mc); + nc = std::min(n, nc); -inline size_t -get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - int m, - int k, - int group_size, - bool has_weight_zeros) { - return ukernel_config.linear_configs[0].activation_data_size_fn( - tiling_params.mc_by_mr * ukernel_config.linear_configs[0].mr, - k, - group_size, - has_weight_zeros); -} + assert((mc == m) || (mc % m_step == 0)); + assert((nc == n) || (nc % n_step == 0)); -inline size_t -get_activation_data_buffer_size_with_tile_schedule_policy_parallel_mc_parallel_nc( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - int m, - int k, - int group_size, - bool has_weight_zeros) { - return ukernel_config.linear_configs[0].activation_data_size_fn( - m, k, group_size, has_weight_zeros); + LinearTilingParams tiling_params; + tiling_params.mc = mc; + tiling_params.nc = nc; + return tiling_params; } -inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - char* activation_data_buffer, +void linear_operator( + const UKernelConfig& uk, + const std::optional& tiling_params, // Outputs float* output, // Inputs @@ -171,237 +127,101 @@ inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( int n, int k, int group_size, - const void* weight_data, + const void* packed_weights, const float* activations, - // Ignored if has_clamp = false + bool has_clamp, float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp) { - int nr = ukernel_config.nr; - int mc = - std::min(m, tiling_params.mc_by_mr * ukernel_config.linear_configs[0].mr); - int nc = std::min(n, tiling_params.nc_by_nr * nr); + float clamp_max) { + TORCHAO_CHECK(group_size % 16 == 0, "group_size must be a multiple of 16"); + TORCHAO_CHECK(k % group_size == 0, "group_size must divide k"); + + // Select linear config based on m + int linear_config_idx = uk.select_linear_config_idx(m); + auto& linear_config = uk.linear_configs[linear_config_idx]; + int n_step = uk.n_step; + int m_step = linear_config.m_step; + + // Choose tiling params + int mc, nc; + if (tiling_params.has_value()) { + mc = tiling_params->mc; + nc = tiling_params->nc; + } else { + auto params = LinearTilingParams::from_target_tiles_per_thread( + m, + m_step, + n, + n_step, + /*target_tiles_per_thread=*/5); + mc = params.mc; + nc = params.nc; + } + TORCHAO_CHECK(mc >= 1, "mc must be >= 1"); + TORCHAO_CHECK(nc >= 1, "nc must be >= 1"); + TORCHAO_CHECK( + (mc == m) || (mc % m_step == 0), + "mc from tiling_params must be m or a multiple of m_step"); + TORCHAO_CHECK( + (nc == n) || (nc % n_step == 0), + "nc from tiling_params must be n or a multiple of n_step"); + int num_mc_panels = (m + mc - 1) / mc; int num_nc_panels = (n + nc - 1) / nc; - size_t weight_data_size = - ukernel_config.weight_packing_config.weight_data_size_fn( - nr, k, group_size, has_weight_zeros, has_bias); + auto packed_activations_size = linear_config.packed_activations_size( + mc, k, group_size, uk.has_weight_zeros, linear_config.mr, uk.kr, uk.sr); + + auto packed_activations = torchao::make_aligned_byte_ptr( + uk.preferred_alignment, packed_activations_size); for (int mc_tile_idx = 0; mc_tile_idx < num_mc_panels; mc_tile_idx++) { int m_idx = mc_tile_idx * mc; int mc_tile_size = std::min(mc, m - m_idx); int activations_offset = m_idx * k; - ukernel_config.linear_configs[0].prepare_activation_data_fn( - activation_data_buffer, + + linear_config.pack_activations( + packed_activations.get(), /*m=*/mc_tile_size, k, group_size, activations + activations_offset, - has_weight_zeros); + uk.has_weight_zeros, + linear_config.mr, + uk.kr, + uk.sr); torchao::parallel_1d(0, num_nc_panels, [&](int64_t idx) { int nc_tile_idx = idx; int n_idx = nc_tile_idx * nc; int nc_tile_size = std::min(nc, n - n_idx); - int output_offset = m_idx * n + n_idx; - int weight_data_offset = (n_idx / nr) * weight_data_size; - ukernel_config.linear_configs[0].kernel_fn( + auto packed_weights_offset = uk.packed_weights_offset( + n_idx, + k, + group_size, + uk.weight_nbit, + uk.has_weight_zeros, + uk.has_bias, + uk.nr, + uk.kr, + uk.sr); + + linear_config.kernel( output + output_offset, /*output_m_stride=*/n, /*m=*/mc_tile_size, /*n=*/nc_tile_size, k, group_size, - /*weight_data=*/(char*)weight_data + weight_data_offset, - /*activation_data=*/activation_data_buffer, + /*packed_weights=*/(char*)packed_weights + packed_weights_offset, + /*packed_activations=*/packed_activations.get(), clamp_min, clamp_max, - has_weight_zeros, - has_bias, + uk.has_weight_zeros, + uk.has_bias, has_clamp); }); } } -inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - char* activation_data_buffer, - // Outputs - float* output, - // Inputs - int m, - int n, - int k, - int group_size, - const void* weight_data, - const float* activations, - float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp) { - int mr = ukernel_config.linear_configs[0].mr; - int nr = ukernel_config.nr; - int mc = std::min(m, tiling_params.mc_by_mr * mr); - int nc = std::min(n, tiling_params.nc_by_nr * nr); - int num_mc_panels = (m + mc - 1) / mc; - int num_nc_panels = (n + nc - 1) / nc; - - size_t weight_data_size = - ukernel_config.weight_packing_config.weight_data_size_fn( - nr, k, group_size, has_weight_zeros, has_bias); - size_t activation_data_size = - ukernel_config.linear_configs[0].activation_data_size_fn( - mr, k, group_size, has_weight_zeros); - - torchao::parallel_1d(0, num_mc_panels, [&](int64_t idx) { - int mc_tile_idx = idx; - int m_idx = mc_tile_idx * mc; - int mc_tile_size = std::min(mc, m - m_idx); - int activations_offset = m_idx * k; - int activation_data_offset = (m_idx / mr) * activation_data_size; - - ukernel_config.linear_configs[0].prepare_activation_data_fn( - activation_data_buffer + activation_data_offset, - /*m=*/mc_tile_size, - k, - group_size, - activations + activations_offset, - has_weight_zeros); - }); - - torchao::parallel_1d(0, num_mc_panels * num_nc_panels, [&](int64_t idx) { - int mc_tile_idx = idx / num_nc_panels; - int m_idx = mc_tile_idx * mc; - int mc_tile_size = std::min(mc, m - m_idx); - - int nc_tile_idx = idx % num_nc_panels; - int n_idx = nc_tile_idx * nc; - int nc_tile_size = std::min(nc, n - n_idx); - - int activation_data_offset = (m_idx / mr) * activation_data_size; - int output_offset = m_idx * n + n_idx; - int weight_data_offset = (n_idx / nr) * weight_data_size; - - ukernel_config.linear_configs[0].kernel_fn( - output + output_offset, - /*output_m_stride=*/n, - /*m=*/mc_tile_size, - /*n=*/nc_tile_size, - k, - group_size, - /*weight_data=*/(char*)weight_data + weight_data_offset, - /*activation_data=*/activation_data_buffer + activation_data_offset, - clamp_min, - clamp_max, - has_weight_zeros, - has_bias, - has_clamp); - }); -} -} // namespace internal - -void linear_operator( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - LinearTileSchedulingPolicy scheduling_policy, - char* activation_data_buffer, - // Outputs - float* output, - // Inputs - int m, - int n, - int k, - int group_size, - const void* weight_data, - const float* activations, - // Ignored if has_clamp = false - float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp) { - TORCHAO_CHECK(group_size % 16 == 0, "group_size must be a multiple of 16"); - TORCHAO_CHECK(k % group_size == 0, "group_size must divide k"); - switch (scheduling_policy) { - case LinearTileSchedulingPolicy::single_mc_parallel_nc: - internal::linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( - ukernel_config, - tiling_params, - activation_data_buffer, - output, - m, - n, - k, - group_size, - weight_data, - activations, - clamp_min, - clamp_max, - has_weight_zeros, - has_bias, - has_clamp); - break; - case LinearTileSchedulingPolicy::parallel_mc_parallel_nc: - internal:: - linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( - ukernel_config, - tiling_params, - activation_data_buffer, - output, - m, - n, - k, - group_size, - weight_data, - activations, - clamp_min, - clamp_max, - has_weight_zeros, - has_bias, - has_clamp); - break; - default: - TORCHAO_CHECK(false, "Unimplemented LinearTileSchedulingPolicy"); - } -} - -size_t get_activation_data_buffer_size( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - LinearTileSchedulingPolicy scheduling_policy, - int m, - int k, - int group_size, - bool has_weight_zeros) { - switch (scheduling_policy) { - case LinearTileSchedulingPolicy::single_mc_parallel_nc: - return internal:: - get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( - ukernel_config, - tiling_params, - m, - k, - group_size, - has_weight_zeros); - case LinearTileSchedulingPolicy::parallel_mc_parallel_nc: - return internal:: - get_activation_data_buffer_size_with_tile_schedule_policy_parallel_mc_parallel_nc( - ukernel_config, - tiling_params, - m, - k, - group_size, - has_weight_zeros); - default: - TORCHAO_CHECK(false, "Unimplemented LinearTileSchedulingPolicy"); - } -} - -} // namespace - // torchao::ops::linear_8bit_act_xbit_weight +} // namespace torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h index dba0adb32d..accc5be5a1 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h @@ -7,102 +7,17 @@ #pragma once #include #include +#include #include #include +#include namespace torchao::ops::linear_8bit_act_xbit_weight { -struct UKernelConfig { - using activation_data_size_fn_type = - size_t (*)(int m, int k, int group_size, bool has_weight_zeros); - using prepare_activation_data_fn_type = void (*)( - void* activation_data, - int m, - int k, - int group_size, - const float* activations, - bool has_weight_zeros); - using weight_data_size_fn_type = size_t (*)( - int n, - int k, - int group_size, - bool has_weight_zeros, - bool has_bias); - using prepare_weight_data_fn_type = void (*)( - void* weight_data, - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias); - using kernel_fn_type = void (*)( - float* output, - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp); - - struct weight_packing_config_type { - weight_data_size_fn_type weight_data_size_fn{nullptr}; - prepare_weight_data_fn_type prepare_weight_data_fn{nullptr}; - }; - struct linear_config_type { - int mr{0}; - activation_data_size_fn_type activation_data_size_fn{nullptr}; - prepare_activation_data_fn_type prepare_activation_data_fn{nullptr}; - kernel_fn_type kernel_fn{nullptr}; - }; - - // preferred_alignment for activation and weight data - // Integration surfaces are not required to respect this alignment, and the - // ukernel must behave correctly no matter how buffers are aligned - size_t preferred_alignment{0}; - int nr{0}; - weight_packing_config_type weight_packing_config; - std::array linear_configs; -}; - -// Pack weight functions -struct PackWeightDataTilingParams { - int nc_by_nr{1}; -}; - -PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( - const UKernelConfig& ukernel_config, - int n, - int target_panels_per_thread = 1); - -inline size_t get_packed_weight_data_size( - const UKernelConfig& ukernel_config, - int n, - int k, - int group_size, - bool has_weight_zeros, - bool has_bias) { - return ukernel_config.weight_packing_config.weight_data_size_fn( - n, k, group_size, has_weight_zeros, has_bias); -} - -inline size_t get_preferred_packed_weight_data_alignment( - const UKernelConfig& ukernel_config) { - return ukernel_config.preferred_alignment; -} - -void pack_weight_data_operator( - const UKernelConfig& ukernel_config, - const PackWeightDataTilingParams& tiling_params, +void pack_weights_operator( + const UKernelConfig& uk, // Outputs - void* weight_data, + void* packed_weights, // Inputs int n, int k, @@ -114,40 +29,23 @@ void pack_weight_data_operator( // Linear functions struct LinearTilingParams { - int mc_by_mr{1}; - int nc_by_nr{1}; -}; - -LinearTilingParams get_default_linear_tiling_params( - const UKernelConfig& ukernel_config, - int m, - int n, - int target_tiles_per_thread = 5); + int mc{0}; + int nc{0}; -enum class LinearTileSchedulingPolicy { - single_mc_parallel_nc, - parallel_mc_parallel_nc + // Returns LinearTilingParams with mc and nc chosen so that there are + // approximately target_tiles_per_thread tiles per thread. The method + // guarantees 1. mc = m or mc % m_step == 0, and 2. nc = n or nc % n_step == 0 + static LinearTilingParams from_target_tiles_per_thread( + int m, + int m_step, + int n, + int n_step, + int target_tiles_per_thread); }; -size_t get_activation_data_buffer_size( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - LinearTileSchedulingPolicy scheduling_policy, - int m, - int k, - int group_size, - bool has_weight_zeros); - -inline size_t get_preferred_activation_data_buffer_alignment( - const UKernelConfig& ukernel_config) { - return ukernel_config.preferred_alignment; -} - void linear_operator( const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - LinearTileSchedulingPolicy scheduling_policy, - char* activation_data_buffer, + const std::optional& tiling_params, // Outputs float* output, // Inputs @@ -155,13 +53,11 @@ void linear_operator( int n, int k, int group_size, - const void* weight_data, + const void* packed_weights, const float* activations, + bool has_clamp, float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp); + float clamp_max); } // namespace // torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h index 636fc01c64..065a5b0319 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h @@ -69,29 +69,31 @@ Tensor pack_weights_cpu( bias_ptr = bias.value().const_data_ptr(); } - using namespace torchao::ops::linear_8bit_act_xbit_weight; - - auto packed_weights_format = select_packed_weights_format( - target, has_weight_zeros, has_bias); + auto packed_weights_format = + torchao::ops::linear_8bit_act_xbit_weight::select_packed_weights_format< + weight_nbit>(target, has_weight_zeros, has_bias); auto packed_weights_header = packed_weights_format.to_packed_weights_header(); - auto ukernel_config = - select_ukernel_config(packed_weights_header); - - auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params( - ukernel_config, n, /*target_panels_per_thread=*/1); + auto uk = torchao::ops::linear_8bit_act_xbit_weight::select_ukernel_config< + weight_nbit>(packed_weights_header); + + auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() + + uk.packed_weights_size( + n, + k, + group_size, + weight_nbit, + has_weight_zeros, + has_bias, + uk.nr, + uk.kr, + uk.sr); - auto packed_weight_data_size = - torchao::ops::PackedWeightsHeader::size() + - get_packed_weight_data_size( - ukernel_config, n, k, group_size, has_weight_zeros, has_bias); Tensor packed_weights = torch::empty( {static_cast(packed_weight_data_size)}, torch::kInt8); packed_weights_header.write(packed_weights.mutable_data_ptr()); - // TODO: support passing in bias in future - pack_weight_data_operator( - ukernel_config, - pack_weight_tiling_params, + torchao::ops::linear_8bit_act_xbit_weight::pack_weights_operator( + uk, packed_weights.mutable_data_ptr() + torchao::ops::PackedWeightsHeader::size(), n, @@ -122,18 +124,26 @@ Tensor pack_weights_meta( int n = weight_qvals.size(0); int k = weight_qvals.size(1); - using namespace torchao::ops::linear_8bit_act_xbit_weight; - - auto packed_weights_format = select_packed_weights_format( - target, has_weight_zeros, has_bias); - auto ukernel_config = - select_ukernel_config(packed_weights_format); - - auto packed_weight_data_size = - torchao::ops::PackedWeightsHeader::size() + - get_packed_weight_data_size( - ukernel_config, n, k, group_size, has_weight_zeros, has_bias); - auto options = torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8); + auto packed_weights_format = + torchao::ops::linear_8bit_act_xbit_weight::select_packed_weights_format< + weight_nbit>(target, has_weight_zeros, has_bias); + auto uk = torchao::ops::linear_8bit_act_xbit_weight::select_ukernel_config< + weight_nbit>(packed_weights_format); + + auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() + + uk.packed_weights_size( + n, + k, + group_size, + weight_nbit, + has_weight_zeros, + has_bias, + uk.nr, + uk.kr, + uk.sr); + + auto options = + torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8); return torch::empty({static_cast(packed_weight_data_size)}, options); } #endif // USE_ATEN @@ -169,8 +179,6 @@ Tensor linear_out_cpu( // Explicit cast from int64_t to int is required for Executorch TORCHAO_RESIZE_TENSOR(out, {(int)m, (int)n}); - using namespace torchao::ops::linear_8bit_act_xbit_weight; - TORCHAO_CHECK(packed_weights.dim() == 1, "packed_weights must be 1D"); #ifdef USE_ATEN TORCHAO_CHECK( @@ -182,36 +190,12 @@ Tensor linear_out_cpu( auto header = torchao::ops::PackedWeightsHeader::read(packed_weights.const_data_ptr()); - auto format = torchao::ops::linear_8bit_act_xbit_weight::PackedWeightsFormat:: - from_packed_weights_header(header); - - auto ukernel_config = select_ukernel_config(header); - - auto linear_tiling_params = get_default_linear_tiling_params( - ukernel_config, - m, - n, - /*target_tiles_per_thread=*/5); - - auto linear_scheduling_policy = - LinearTileSchedulingPolicy::single_mc_parallel_nc; - - auto activation_data_buffer_size = get_activation_data_buffer_size( - ukernel_config, - linear_tiling_params, - linear_scheduling_policy, - m, - k, - group_size, - format.has_weight_zeros); - - std::vector activation_data_buffer(activation_data_buffer_size); + auto uk = torchao::ops::linear_8bit_act_xbit_weight::select_ukernel_config< + weight_nbit>(header); - linear_operator( - ukernel_config, - linear_tiling_params, - linear_scheduling_policy, - activation_data_buffer.data(), + torchao::ops::linear_8bit_act_xbit_weight::linear_operator( + uk, + std::nullopt, out.mutable_data_ptr(), m, n, @@ -220,13 +204,9 @@ Tensor linear_out_cpu( packed_weights.const_data_ptr() + torchao::ops::PackedWeightsHeader::size(), activations.const_data_ptr(), - // Clamp parameters are ignored because config is created from - // has_clamp = false + /*has_clamp=*/false, /*clamp_min=*/0.0, - /*clamp_max=*/0.0, - format.has_weight_zeros, - format.has_bias, - /*has_clamp*/ false); + /*clamp_max=*/0.0); return out; } diff --git a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp index caaf8baf74..980228a1a8 100644 --- a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp +++ b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp @@ -6,7 +6,9 @@ #include // TODO: move test_utils.h out of aarch64 -#include +#if defined(TORCHAO_BUILD_CPU_AARCH64) +#include +#endif // TORCHAO_BUILD_CPU_AARCH64 #include #include #include @@ -26,23 +28,41 @@ using namespace torchao::ops::linear_8bit_act_xbit_weight; template UKernelConfig get_ukernel_config() { namespace kernel = torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - return UKernelConfig{ - /*preferred_alignment*/ 16, - /*nr*/ 8, - /*weight_packing_config*/ - {/*weight_data_size_fn*/ - &kernel::weight_data_size, - /*prepare_weight_data_fn*/ - &kernel::prepare_weight_data}, - /*linear_configs*/ - {{{/*mr*/ 1, - /*activation_data_size_fn*/ - &kernel::activation_data_size, - /*prepare_activation_data_fn*/ - &kernel::prepare_activation_data, - /*kernel*/ - &kernel::kernel}}}}; + channelwise_8bit_activation_groupwise_lowbit_weight; + + int preferred_alignment = 16; + int n_step = 8; + constexpr int nr = 8; + constexpr int kr = 16; + constexpr int sr = 2; + constexpr int mr = 1; + int m_step = 1; + constexpr bool has_lut = false; + + auto uk = UKernelConfig::make( + preferred_alignment, + n_step, + nr, + kr, + sr, + weight_nbit, + has_weight_zeros, + has_bias, + &kernel::packed_weights_size, + &kernel::packed_weights_offset, + &kernel::pack_weights, + /*linear_configs*/ {}); + + uk.linear_configs[0] = UKernelConfig::linear_config_type{ + m_step, + mr, + &kernel::packed_activations_size, + &kernel::packed_activations_offset, + &kernel::pack_activations, + &kernel:: + kernel_1x8x16_f32_neondot}; + + return uk; } template < @@ -82,87 +102,68 @@ void test_linear_8bit_act_xbit_weight( auto output = std::vector(m * n); - for (auto linear_scheduling_policy : - {LinearTileSchedulingPolicy::single_mc_parallel_nc, - LinearTileSchedulingPolicy::parallel_mc_parallel_nc}) { - for (auto num_threads : {1, 4, 500}) { - torchao::set_num_threads(num_threads); - EXPECT_EQ(torchao::get_num_threads(), num_threads); - - // Pack weights - auto pack_weight_data_tiling_params = - get_default_pack_weight_data_tiling_params(ukernel_config, n); - auto packed_weight_data_size = get_packed_weight_data_size( - ukernel_config, n, k, group_size, has_weight_zeros, has_bias); - auto preferred_packed_weight_data_alignment = - get_preferred_packed_weight_data_alignment(ukernel_config); - auto packed_weight_data = torchao::make_aligned_byte_ptr( - preferred_packed_weight_data_alignment, packed_weight_data_size); - - int8_t* weight_zeros_ptr = nullptr; - if (has_weight_zeros) { - weight_zeros_ptr = test_case.weight_zeros.data(); - } - float* bias_ptr = nullptr; - if (has_bias) { - bias_ptr = test_case.bias.data(); - } - pack_weight_data_operator( - ukernel_config, - pack_weight_data_tiling_params, - packed_weight_data.get(), - n, - k, - group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - weight_zeros_ptr, - bias_ptr); - - // Allocate activation buffer - auto linear_tiling_params = - get_default_linear_tiling_params(ukernel_config, m, n); - - auto activation_data_buffer_size = get_activation_data_buffer_size( - ukernel_config, - linear_tiling_params, - linear_scheduling_policy, - m, - k, - group_size, - has_weight_zeros); - auto activation_data_buffer_alignment = - get_preferred_activation_data_buffer_alignment(ukernel_config); - auto activation_data_buffer = torchao::make_aligned_byte_ptr( - activation_data_buffer_alignment, activation_data_buffer_size); - - // Run linear - linear_operator( - ukernel_config, - linear_tiling_params, - linear_scheduling_policy, - activation_data_buffer.get(), - output.data(), - m, - n, - k, - group_size, - packed_weight_data.get(), - test_case.activations.data(), - test_case.clamp_min, - test_case.clamp_max, - has_weight_zeros, - has_bias, - has_clamp); - - // Test correctness - float tol = kTol; - if (has_kleidi) { - tol = kTolKleidiAI; - } - for (int i = 0; i < m * n; i++) { - EXPECT_NEAR(output[i], test_case.expected_output[i], tol); - } + for (auto num_threads : {1, 4, 500}) { + torchao::set_num_threads(num_threads); + EXPECT_EQ(torchao::get_num_threads(), num_threads); + + // Pack weights + auto packed_weight_data_size = ukernel_config.packed_weights_size( + n, + k, + group_size, + weight_nbit, + has_weight_zeros, + has_bias, + ukernel_config.nr, + ukernel_config.kr, + ukernel_config.sr); + auto preferred_packed_weight_data_alignment = + ukernel_config.preferred_alignment; + auto packed_weights = torchao::make_aligned_byte_ptr( + preferred_packed_weight_data_alignment, packed_weight_data_size); + + int8_t* weight_zeros_ptr = nullptr; + if (has_weight_zeros) { + weight_zeros_ptr = test_case.weight_zeros.data(); + } + float* bias_ptr = nullptr; + // kleidi always has bias in these tests + if (has_bias || has_kleidi) { + bias_ptr = test_case.bias.data(); + } + + pack_weights_operator( + ukernel_config, + packed_weights.get(), + n, + k, + group_size, + test_case.weight_qvals.data(), + test_case.weight_scales.data(), + weight_zeros_ptr, + bias_ptr); + + linear_operator( + ukernel_config, + std::nullopt, + output.data(), + m, + n, + k, + group_size, + packed_weights.get(), + test_case.activations.data(), + has_clamp, + test_case.clamp_min, + test_case.clamp_max); + + // Test correctness + float tol = kTol; + if (has_kleidi) { + tol = kTolKleidiAI; + } + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], tol); } } } @@ -176,102 +177,56 @@ enum kai_kernel_id { i8mm_8x4x32 }; -template < - typename kernel_struct, - int m_step, - int mr, - int n_step, - int nr, - int kr, - int sr> -UKernelConfig get_ukernel_config_kleidi() { +template +UKernelConfig get_ukernel_config_kleidi_impl() { namespace op = torchao::kernels::cpu::aarch64::kleidi:: kai_matmul_clamp_f32_qai8dxp_qsi4c32p; + auto uk = kernel_struct::get_ukernel(); - assert(m_step == uk.get_m_step()); - assert(mr == uk.get_mr()); - assert(n_step == uk.get_n_step()); - assert(nr == uk.get_nr()); - assert(kr == uk.get_kr()); - assert(sr == uk.get_sr()); - return UKernelConfig{ + auto ukernel_config = UKernelConfig::make( op::get_preferred_alignement(), - n_step, - {/*weight_data_size_fn*/ &op::weight_data_size, - /*prepare_weight_data_fn*/ &op::prepare_weight_data}, - {{{m_step, - &op::activation_data_size, - &op::prepare_activation_data, - &kernel_struct::kernel}}}}; + uk.get_n_step(), + uk.get_nr(), + uk.get_kr(), + uk.get_sr(), + /*weight_nbit*/ 4, + /*has_weight_zeros*/ false, + /*has_bias*/ true, + &op::packed_weights_size, + &op::packed_weights_offset, + &op::pack_weights, + /*linear_configs*/ {}); + + ukernel_config.linear_configs[0] = UKernelConfig::linear_config_type{ + static_cast(uk.get_m_step()), + static_cast(uk.get_mr()), + &op::packed_activations_size, + &op::packed_activations_offset, + &op::pack_activations, + &kernel_struct::kernel}; + + return ukernel_config; } template UKernelConfig get_ukernel_config_kleidi() { #if defined(TORCHAO_ENABLE_ARM_I8MM) if constexpr (kernel_id == i8mm_4x8x32) { - constexpr int m_step = 4; - constexpr int mr = 4; - constexpr int n_step = 8; - constexpr int nr = 8; - constexpr int kr = 16; - constexpr int sr = 2; - return get_ukernel_config_kleidi< - matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - m_step, - mr, - n_step, - nr, - kr, - sr>(); + return get_ukernel_config_kleidi_impl< + matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm>(); } if constexpr (kernel_id == i8mm_8x4x32) { - constexpr int m_step = 8; - constexpr int mr = 8; - constexpr int n_step = 4; - constexpr int nr = 4; - constexpr int kr = 16; - constexpr int sr = 2; - return get_ukernel_config_kleidi< - matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - m_step, - mr, - n_step, - nr, - kr, - sr>(); + return get_ukernel_config_kleidi_impl< + matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm>(); } #endif // TORCHAO_ENABLE_ARM_I8MM if constexpr (kernel_id == dotprod_1x8x32) { - constexpr int m_step = 1; - constexpr int mr = 1; - constexpr int n_step = 8; - constexpr int nr = 8; - constexpr int kr = 16; - constexpr int sr = 2; - return get_ukernel_config_kleidi< - matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - m_step, - mr, - n_step, - nr, - kr, - sr>(); + return get_ukernel_config_kleidi_impl< + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod>(); } if constexpr (kernel_id == dotprod_1x4x32) { - constexpr int m_step = 1; - constexpr int mr = 1; - constexpr int n_step = 4; - constexpr int nr = 4; - constexpr int kr = 16; - constexpr int sr = 2; - return get_ukernel_config_kleidi< - matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - m_step, - mr, - n_step, - nr, - kr, - sr>(); + return get_ukernel_config_kleidi_impl< + matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod>(); } throw std::runtime_error("Unsupported kernel_id"); } @@ -332,15 +287,11 @@ TEST(test_linear_8bit_act_xbit_weight, KNotDivisibleByGroupSize) { true /*has_weight_zeros*/, true /*has_bias*/, true /*has_clamp*/>(); - auto pack_weight_data_tiling_params = - get_default_pack_weight_data_tiling_params(ukernel_config, n); - EXPECT_THROW( { - pack_weight_data_operator( + pack_weights_operator( ukernel_config, - pack_weight_data_tiling_params, - /*packed_weight_data=*/nullptr, + /*packed_weights=*/nullptr, n, k, group_size, @@ -362,15 +313,12 @@ TEST(test_linear_8bit_act_xbit_weight, GroupSizeNotDivisibleBy16) { true /*has_weight_zeros*/, true /*has_bias*/, true /*has_clamp*/>(); - auto pack_weight_data_tiling_params = - get_default_pack_weight_data_tiling_params(ukernel_config, n); EXPECT_THROW( { - pack_weight_data_operator( + pack_weights_operator( ukernel_config, - pack_weight_data_tiling_params, - /*packed_weight_data=*/nullptr, + /*packed_weights=*/nullptr, n, k, group_size, From 7959ac33ea5471f06fbbd73a22a7168ad3b75db1 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Wed, 2 Apr 2025 09:51:27 -0700 Subject: [PATCH 06/30] quantized matmul Differential Revision: D71370592 Pull Request resolved: https://github.com/pytorch/ao/pull/1994 --- ...hannelwise_8bit_b_1x16x16_f32_smlal-impl.h | 384 ++++++++++++++++++ ...annelwise_8bit_b_1x8x16_f32_neondot-impl.h | 336 +++++++++++++++ .../kernels/cpu/aarch64/matmul/matmul.h | 74 ++++ .../kernels/cpu/aarch64/matmul/matmul_utils.h | 70 ++++ .../cpu/aarch64/quantization/quantize.cpp | 23 +- .../kernels/cpu/aarch64/tests/CMakeLists.txt | 9 + .../cpu/aarch64/tests/build_and_run_tests.sh | 1 + .../cpu/aarch64/tests/test_qmatmul.cpp | 229 +++++++++++ .../kernels/cpu/aarch64/tests/test_utils.h | 229 +++++++++-- 9 files changed, 1324 insertions(+), 31 deletions(-) create mode 100644 torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h new file mode 100644 index 0000000000..b83c28143f --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h @@ -0,0 +1,384 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include + +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::quantized_matmul { +namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal::internal { + +namespace { +/* +This function loads int8x16_t value from a, and 8 int8x16_t values from b. +For each int8x16_t of b: +- subl to subtarct a_zero_point from a, to get a_low, a_high +- 4 int32x4 accumulated values +- for i in [0, 8]: + - load b[i] + - subl to subtarct b_zero_point from b, to get b_low, b_high + - smlal_lane to multiply a_low[i] and b_low_low. + - smlal_lane to multiply a_low[i] and b_low_high. + - smlal_lane to multiply a_low[i] and b_high_low. + - smlal_lane to multiply a_low[i] and b_high_high. + - This produces 2 int32x4_t values +- for i in [0, 8]: + - load b[i] + - subl to subtarct b_zero_point from b, to get b_low, b_high + - smlal_lane to multiply a_low[i] and b_low_low. + - smlal_lane to multiply a_low[i] and b_low_high. + - smlal_lane to multiply a_low[i] and b_high_low. + - smlal_lane to multiply a_low[i] and b_high_high. + - This produces 2 int32x4_t values +Possibly better to transpose 16x16 of b and use dotprod. Left for future. +*/ + +template +TORCHAO_ALWAYS_INLINE void block_mul_1x16x1( + const int16x4_t& a_vec, + const int8x16_t& b_vec, + const int8x16_t& b_zero_point_vec, + int32x4_t (&partial_sums)[4]) { + int16x8_t b_vec_low = + vsubl_s8(vget_low_s8(b_vec), vget_low_s8(b_zero_point_vec)); + int16x8_t b_vec_high = + vsubl_s8(vget_high_s8(b_vec), vget_high_s8(b_zero_point_vec)); + partial_sums[0] = + vmlal_lane_s16(partial_sums[0], vget_low_s16(b_vec_low), a_vec, lane); + partial_sums[1] = + vmlal_lane_s16(partial_sums[1], vget_high_s16(b_vec_low), a_vec, lane); + partial_sums[2] = + vmlal_lane_s16(partial_sums[2], vget_low_s16(b_vec_high), a_vec, lane); + partial_sums[3] = + vmlal_lane_s16(partial_sums[3], vget_high_s16(b_vec_high), a_vec, lane); +} + +void block_mul_1x16x16( + const int8_t* a, + const int8_t* b, + const size_t ldb, + const int8_t a_zero_point, + const int8_t* b_zero_point, + int32x4_t (&partial_sums)[4]) { + int8x16_t a_vec = vld1q_s8(a); + int8x8_t a_zero_point_vec = vdup_n_s8(a_zero_point); + int8x16_t b_zero_point_vec = vld1q_s8(b_zero_point); + int16x8_t a_vec_low = vsubl_s8(vget_low_s8(a_vec), a_zero_point_vec); + int16x8_t a_vec_high = vsubl_s8(vget_high_s8(a_vec), a_zero_point_vec); + + int8x16_t b_vec = vld1q_s8(b + 0 * ldb); + block_mul_1x16x1<0>( + vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 1 * ldb); + block_mul_1x16x1<1>( + vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 2 * ldb); + block_mul_1x16x1<2>( + vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 3 * ldb); + block_mul_1x16x1<3>( + vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 4 * ldb); + block_mul_1x16x1<0>( + vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 5 * ldb); + block_mul_1x16x1<1>( + vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 6 * ldb); + block_mul_1x16x1<2>( + vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 7 * ldb); + block_mul_1x16x1<3>( + vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + + // Second set of 8 channels + b_vec = vld1q_s8(b + 8 * ldb); + block_mul_1x16x1<0>( + vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 9 * ldb); + block_mul_1x16x1<1>( + vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 10 * ldb); + block_mul_1x16x1<2>( + vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 11 * ldb); + block_mul_1x16x1<3>( + vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 12 * ldb); + block_mul_1x16x1<0>( + vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 13 * ldb); + block_mul_1x16x1<1>( + vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 14 * ldb); + block_mul_1x16x1<2>( + vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 15 * ldb); + block_mul_1x16x1<3>( + vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); +} + +TORCHAO_ALWAYS_INLINE void dequantize_1x16_int32_t( + const int32x4_t (&sums)[4], + const float* lhs_scales, + const float* rhs_scales, + float32x4_t (&outputs)[4]) { + float32x4_t scales_0123 = vmulq_n_f32(vld1q_f32(rhs_scales), lhs_scales[0]); + float32x4_t scales_4567 = + vmulq_n_f32(vld1q_f32(rhs_scales + 4), lhs_scales[0]); + float32x4_t scales_89ab = + vmulq_n_f32(vld1q_f32(rhs_scales + 8), lhs_scales[0]); + float32x4_t scales_cdef = + vmulq_n_f32(vld1q_f32(rhs_scales + 12), lhs_scales[0]); + + outputs[0] = vmulq_f32(vcvtq_f32_s32(sums[0]), scales_0123); + outputs[1] = vmulq_f32(vcvtq_f32_s32(sums[1]), scales_4567); + outputs[2] = vmulq_f32(vcvtq_f32_s32(sums[2]), scales_89ab); + outputs[3] = vmulq_f32(vcvtq_f32_s32(sums[3]), scales_cdef); +} + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +struct KernelImpl { + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride); +}; + +template <> +struct KernelImpl { + /** + * @brief Implements quantized matrix multiplication for 8-bit channelwise + * quantized matrices + * + * This specialized implementation handles the case where: + * - Both LHS and RHS have zero points (true, true) + * - Neither LHS nor RHS are transposed (false, false) + * + * The function performs a quantized matrix multiplication C = A * B where: + * - A is an m×k matrix (LHS) + * - B is a k×n matrix (RHS) + * - C is an m×n matrix (output) + * + * The implementation uses NEON intrinsics for vectorized computation and + * processes data in blocks of 16×16 for optimal performance on ARM + * architecture. + * + * @param m Number of rows in LHS and output + * @param n Number of columns in RHS and output + * @param k Number of columns in LHS and rows in RHS + * @param lhs Pointer to LHS matrix data (int8_t) + * @param lhs_stride_m Stride between rows of LHS + * @param rhs Pointer to RHS matrix data (int8_t) + * @param rhs_stride_n Stride between rows of RHS + * @param output Pointer to output matrix (float32_t) + * @param out_stride_m Stride between rows of output + * @param lhs_zero_points Zero points for LHS quantization (per-channel) + * @param rhs_zero_points Zero points for RHS quantization (per-channel) + * @param lhs_scales Scales for LHS quantization (per-channel) + * @param rhs_scales Scales for RHS quantization (per-channel) + * @param lhs_qparams_stride Stride for LHS quantization parameters + * @param rhs_qparams_stride Stride for RHS quantization parameters + */ + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride) { + // If lhs_zero_points and rhs_zero_points are not contiguous, transpose + std::unique_ptr lhs_zero_points_transposed = + std::make_unique(m); + std::unique_ptr lhs_scales_transposed = + std::make_unique(m); + if (lhs_qparams_stride > 1) { + utils::transpose_scales_and_zero_points( + lhs_zero_points, + lhs_scales, + lhs_zero_points_transposed.get(), + lhs_scales_transposed.get(), + m, + lhs_qparams_stride); + lhs_zero_points = lhs_zero_points_transposed.get(); + lhs_scales = lhs_scales_transposed.get(); + } + std::unique_ptr rhs_zero_points_transposed = + std::make_unique(n); + std::unique_ptr rhs_scales_transposed = + std::make_unique(n); + if (rhs_qparams_stride > 1) { + utils::transpose_scales_and_zero_points( + rhs_zero_points, + rhs_scales, + rhs_zero_points_transposed.get(), + rhs_scales_transposed.get(), + n, + rhs_qparams_stride); + rhs_zero_points = rhs_zero_points_transposed.get(); + rhs_scales = rhs_scales_transposed.get(); + } + + for (int m_idx = 0; m_idx < m; m_idx++) { + // Loop over 16 cols at a time + // Access to partial tiles must be protected:w + constexpr int nr = 16; + constexpr int kr = 16; + assert(n >= nr); + for (int n_idx = 0; n_idx < n; n_idx += nr) { + // If remaining is < nr, that must mean that (nr - remaining) items + // dont need to be computed. + // In order to avoid out-of-bounds access, we need to rewind n_indx a + // bit + // |-------------------|-------------------| + // 0-------------------8-------------------16 + // 0-------------------8-----10 + // If n = 10 and nr = 8 then at n_idx = 8, we need to rewind n_idx to + // 8 - (8 - 10) = 2 + int remaining = std::min(n - n_idx, nr); + n_idx = n_idx - (nr - remaining); + // Set activation_ptr to start of activation qvals for row m_idx + const int8_t* lhs_ptr = (const int8_t*)lhs + m_idx * lhs_stride_m; + const int8_t* rhs_ptr = (const int8_t*)rhs + n_idx; + int32x4_t int32_sums[nr / 4] = {vdupq_n_s32(0)}; + + // Loop k_idx by group + int k_idx = 0; + for (; (k_idx + kr) <= k; k_idx += kr) { + block_mul_1x16x16( + lhs_ptr, + rhs_ptr, + rhs_stride_n, + lhs_zero_points[m_idx], + rhs_zero_points + n_idx, + int32_sums); + lhs_ptr += kr; + rhs_ptr += kr * rhs_stride_n; + } + + int8x16_t b_zero_point_vec = vld1q_s8(rhs_zero_points + n_idx); + for (int ki = 0; ki < (k - k_idx); ++ki) { + // For each of the remaining k values + // Load 1 int8_t from lhs + // Load 16 int8_t from rhs + // And multiply + add into the 16 accumulators + // arranged as int32x4_t[4] + int16_t a_val = static_cast(lhs_ptr[ki]) - + static_cast(lhs_zero_points[m_idx]); + int8x16_t b_vec = vld1q_s8(rhs_ptr + ki * rhs_stride_n); + int16x8_t b_vec_low = + vsubl_s8(vget_low_s8(b_vec), vget_low_s8(b_zero_point_vec)); + int16x8_t b_vec_high = + vsubl_s8(vget_high_s8(b_vec), vget_high_s8(b_zero_point_vec)); + int32_sums[0] = + vmlal_n_s16(int32_sums[0], vget_low_s16(b_vec_low), a_val); + int32_sums[1] = + vmlal_n_s16(int32_sums[1], vget_high_s16(b_vec_low), a_val); + int32_sums[2] = + vmlal_n_s16(int32_sums[2], vget_low_s16(b_vec_high), a_val); + int32_sums[3] = + vmlal_n_s16(int32_sums[3], vget_high_s16(b_vec_high), a_val); + } + + float32x4_t res[4]; + dequantize_1x16_int32_t( + int32_sums, lhs_scales + m_idx, rhs_scales + n_idx, res); + + // Store result + // Because we adjust n_idx, we may end up writing the same location + // twice + float* store_loc = output + m_idx * out_stride_m + n_idx; + vst1q_f32(store_loc, res[0]); + vst1q_f32(store_loc + 4, res[1]); + vst1q_f32(store_loc + 8, res[2]); + vst1q_f32(store_loc + 12, res[3]); + } // n_idx + } // m_idx + } +}; + +} // namespace + +} // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal::internal + +namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal { +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +void kernel( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride) { + torchao::kernels::cpu::aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal::internal:: + KernelImpl::run( + m, + n, + k, + lhs, + lhs_stride_m, + rhs, + rhs_stride_n, + output, + out_stride_m, + lhs_zero_points, + rhs_zero_points, + lhs_scales, + rhs_scales, + lhs_qparams_stride, + rhs_qparams_stride); +} +} // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal +} // namespace torchao::kernels::cpu::aarch64::quantized_matmul + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h new file mode 100644 index 0000000000..123b7723e4 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h @@ -0,0 +1,336 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include + +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::quantized_matmul { +namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::internal { + +/* +This function loads int8x16_t value from a, and 8 int8x16_t values from b, and +computes 8 dot products, resulting in 8 int32x4_t values. +Furthermore the int8x16_t values from a are reduced via summing, resulting in +int32_t row_sum_a. Similar int8x16_t values from b are reduced via summing, +resulting in int32_t row_sum_b. +*/ +TORCHAO_ALWAYS_INLINE static void block_mul_1x8x16( + const int8_t* a, + const int8_t* b, + const size_t ldb, + int32x4_t (&partial_sums)[8], + int32_t& row_sum_a, + int32_t (&row_sum_b)[8]) { + int8x16_t a_vec = vld1q_s8(a); + row_sum_a = row_sum_a + vaddlvq_s8(a_vec); + +// godbolt (https://godbolt.org/z/9vbq1d1qY) shows this loops doesnt quantize +// get optimized by moving all the loads up in the unrolled loop. Just hoping +// OOO machine will take care of things Late replace this with macros so as to +// deconstruct the loop and do manual optimization. Or just write assembly. +#pragma unroll(8) + for (int i = 0; i < 8; ++i) { + int8x16_t b_vec = vld1q_s8(b + i * ldb); + row_sum_b[i] = row_sum_b[i] + vaddlvq_s8(b_vec); + partial_sums[i] = vdotq_s32(partial_sums[i], a_vec, b_vec); + } +} + +TORCHAO_ALWAYS_INLINE static void reduce_1x8_int32x4_t_sums( + const int32x4_t (&partial_sums)[8], + int32_t (&sums)[8]) { +#pragma unroll(8) + for (int i = 0; i < 8; ++i) { + sums[i] = vaddvq_s32(partial_sums[i]); + } +} + +TORCHAO_ALWAYS_INLINE static void dequantize_1x8_int32_t( + const int32_t (&sums)[8], + int32_t& row_sum_lhs, + int32_t (&row_sum_rhs)[8], + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int32_t k, + float32x4x2_t& outputs) { + int32x4_t vec_sum_0123 = vld1q_s32(sums); + int32x4_t vec_sum_4567 = vld1q_s32(sums + 4); + + int32x4_t row_sum_rhs_x_lhs_zp_0123 = + vmulq_n_s32(vld1q_s32(row_sum_rhs), (int32_t)lhs_zero_points[0]); + int32x4_t row_sum_rhs_x_lhs_zp_4567 = + vmulq_n_s32(vld1q_s32(row_sum_rhs + 4), (int32_t)lhs_zero_points[0]); + + // Extract rhs zero point in int8x8_t and convert to int32x4_t + int16x8_t rhs_zero_points_vec_01234567 = vmovl_s8(vld1_s8(rhs_zero_points)); + int32x4_t rhs_zero_points_vec_0123 = + vmovl_s16(vget_low_s16(rhs_zero_points_vec_01234567)); + int32x4_t rhs_zero_points_vec_4567 = + vmovl_s16(vget_high_s16(rhs_zero_points_vec_01234567)); + int32x4_t row_sum_lhs_x_rhs_zp_0123 = + vmulq_n_s32(rhs_zero_points_vec_0123, row_sum_lhs); + int32x4_t row_sum_lhs_x_rhs_zp_4567 = + vmulq_n_s32(rhs_zero_points_vec_4567, row_sum_lhs); + + int32x4_t zp_rhs_x_zp_lhs_0123 = + vmulq_n_s32(rhs_zero_points_vec_0123, k * (int32_t)lhs_zero_points[0]); + int32x4_t zp_rhs_x_zp_lhs_4567 = + vmulq_n_s32(rhs_zero_points_vec_4567, k * (int32_t)lhs_zero_points[0]); + + vec_sum_0123 = vsubq_s32(vec_sum_0123, row_sum_rhs_x_lhs_zp_0123); + vec_sum_0123 = vsubq_s32(vec_sum_0123, row_sum_lhs_x_rhs_zp_0123); + vec_sum_0123 = vaddq_s32(vec_sum_0123, zp_rhs_x_zp_lhs_0123); + + vec_sum_4567 = vsubq_s32(vec_sum_4567, row_sum_rhs_x_lhs_zp_4567); + vec_sum_4567 = vsubq_s32(vec_sum_4567, row_sum_lhs_x_rhs_zp_4567); + vec_sum_4567 = vaddq_s32(vec_sum_4567, zp_rhs_x_zp_lhs_4567); + + float32x4_t scales_0123 = vmulq_n_f32(vld1q_f32(rhs_scales), lhs_scales[0]); + float32x4_t scales_4567 = + vmulq_n_f32(vld1q_f32(rhs_scales + 4), lhs_scales[0]); + + outputs.val[0] = vmulq_f32(vcvtq_f32_s32(vec_sum_0123), scales_0123); + outputs.val[1] = vmulq_f32(vcvtq_f32_s32(vec_sum_4567), scales_4567); +} + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +struct KernelImpl { + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride); +}; + +template <> +struct KernelImpl { + /** + * @brief Executes a quantized matrix multiplication with channelwise + * quantization parameters + * + * This function performs matrix multiplication between two 8-bit quantized + * matrices with per-channel quantization parameters. It handles the following + * operations: + * 1. Transposes quantization parameters if they're not contiguous + * 2. Processes the matrices in blocks of 8 columns at a time + * 3. Uses NEON dot product instructions for efficient computation + * 4. Handles edge cases for remaining elements + * 5. Dequantizes the results to floating point + * + * @param m Number of rows in the output matrix + * @param n Number of columns in the output matrix + * @param k Number of columns in lhs / rows in rhs + * @param lhs Pointer to the left-hand side matrix (quantized int8) + * @param lhs_stride_m Stride between rows of the lhs matrix + * @param rhs Pointer to the right-hand side matrix (quantized int8) + * @param rhs_stride_n Stride between rows of the rhs matrix. Expects matrix + * to be transposed. Thus of size [n x k] + * @param output Pointer to the output matrix (float32) + * @param out_stride_m Stride between rows of the output matrix + * @param lhs_zero_points Zero points for lhs quantization (per-channel) + * @param rhs_zero_points Zero points for rhs quantization (per-channel) + * @param lhs_scales Scales for lhs quantization (per-channel) + * @param rhs_scales Scales for rhs quantization (per-channel) + * @param lhs_qparams_stride Stride for lhs quantization parameters + * @param rhs_qparams_stride Stride for rhs quantization parameters + */ + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride) { + // If lhs_zero_points and rhs_zero_points are not contiguous, transpose + std::unique_ptr lhs_zero_points_transposed = + std::make_unique(m); + std::unique_ptr lhs_scales_transposed = + std::make_unique(m); + if (lhs_qparams_stride > 1) { + utils::transpose_scales_and_zero_points( + lhs_zero_points, + lhs_scales, + lhs_zero_points_transposed.get(), + lhs_scales_transposed.get(), + m, + lhs_qparams_stride); + lhs_zero_points = lhs_zero_points_transposed.get(); + lhs_scales = lhs_scales_transposed.get(); + } + std::unique_ptr rhs_zero_points_transposed = + std::make_unique(n); + std::unique_ptr rhs_scales_transposed = + std::make_unique(n); + if (rhs_qparams_stride > 1) { + utils::transpose_scales_and_zero_points( + rhs_zero_points, + rhs_scales, + rhs_zero_points_transposed.get(), + rhs_scales_transposed.get(), + n, + rhs_qparams_stride); + rhs_zero_points = rhs_zero_points_transposed.get(); + rhs_scales = rhs_scales_transposed.get(); + } + + for (int m_idx = 0; m_idx < m; m_idx++) { + // Loop over 8 cols at a time + // Access to partial tiles must be protected:w + constexpr int nr = 8; + constexpr int kr = 16; + assert(n >= nr); + for (int n_idx = 0; n_idx < n; n_idx += nr) { + // If remaining is < nr, that must mean that (nr - remaining) items + // dont need to be computed. + // In order to avoid out-of-bounds access, we need to rewind n_indx a + // bit + // |-------------------|-------------------| + // 0-------------------8-------------------16 + // 0-------------------8-----10 + // If n = 10 and nr = 8 then at n_idx = 8, we need to rewind n_idx to + // 8 - (8 - 10) = 2 + int remaining = std::min(n - n_idx, nr); + n_idx = n_idx - (nr - remaining); + // Set activation_ptr to start of activation qvals for row m_idx + const int8_t* lhs_ptr = (const int8_t*)lhs + m_idx * lhs_stride_m; + const int8_t* rhs_ptr = (const int8_t*)rhs + n_idx * rhs_stride_n; + int32x4_t int32_sums[nr] = {vdupq_n_s32(0)}; + int32_t row_sum_lhs = 0; + int32_t row_sum_rhs[nr] = {0, 0, 0, 0, 0, 0, 0, 0}; + int32_t sums[nr]; + + // Loop k_idx by group + int k_idx = 0; + for (; (k_idx + kr) <= k; k_idx += kr) { + block_mul_1x8x16( + lhs_ptr, + rhs_ptr, + rhs_stride_n, + int32_sums, + row_sum_lhs, + row_sum_rhs); + lhs_ptr += kr; + rhs_ptr += kr; + } + + reduce_1x8_int32x4_t_sums(int32_sums, sums); + for (int ki = 0; ki < (k - k_idx); ++ki) { + row_sum_lhs += (int32_t)lhs_ptr[ki]; + } + for (int ni = 0; ni < nr; ++ni) { + for (int ki = 0; ki < (k - k_idx); ++ki) { + sums[ni] += (int32_t)lhs_ptr[ki] * + (int32_t)(rhs_ptr + ni * rhs_stride_n)[ki]; + row_sum_rhs[ni] += (int32_t)(rhs_ptr + ni * rhs_stride_n)[ki]; + } + } + + float32x4x2_t res; + dequantize_1x8_int32_t( + sums, + row_sum_lhs, + row_sum_rhs, + lhs_zero_points + m_idx, + rhs_zero_points + n_idx, + lhs_scales + m_idx, + rhs_scales + n_idx, + k, + res); + + // Store result + // Because we adjust n_idx, we may end up writing the same location + // twice + float* store_loc = output + m_idx * out_stride_m + n_idx; + vst1q_f32(store_loc, res.val[0]); + vst1q_f32(store_loc + 4, res.val[1]); + } // n_idx + } // m_idx + } +}; + +} // namespace + // channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::internal + +namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot { +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +void kernel( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride) { + torchao::kernels::cpu::aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::internal:: + KernelImpl::run( + m, + n, + k, + lhs, + lhs_stride_m, + rhs, + rhs_stride_n, + output, + out_stride_m, + lhs_zero_points, + rhs_zero_points, + lhs_scales, + rhs_scales, + lhs_qparams_stride, + rhs_qparams_stride); +} +} // namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot +} // namespace torchao::kernels::cpu::aarch64::quantized_matmul + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h new file mode 100644 index 0000000000..4005dee564 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h @@ -0,0 +1,74 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +// TODO: this file will be deleted and replaced by +// torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/include.h +// It exists now to prevent breaking existing code in the interim. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include + +namespace torchao::kernels::cpu::aarch64::quantized_matmul { +namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot { + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_tranposed> +void kernel( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride); + +} // namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot + +namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal { + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_tranposed> +void kernel( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride); + +} // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal +} // namespace torchao::kernels::cpu::aarch64::quantized_matmul + +#include +#include + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h new file mode 100644 index 0000000000..68ab912705 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h @@ -0,0 +1,70 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include + +namespace torchao::kernels::cpu::aarch64::quantized_matmul { +namespace utils { + +TORCHAO_ALWAYS_INLINE static void transpose_scales_and_zero_points( + const int8_t* zero_points, + const float* scales, + int8_t* zero_points_transposed, + float* scales_transposed, + const int m, + const int stride_m) { + // Process 8 elements at a time using NEON + int i = 0; + for (; i + 8 <= m; i += 8) { + // Load 8 zero points with stride_m + int8x8_t zp = { + zero_points[0 * stride_m], + zero_points[1 * stride_m], + zero_points[2 * stride_m], + zero_points[3 * stride_m], + zero_points[4 * stride_m], + zero_points[5 * stride_m], + zero_points[6 * stride_m], + zero_points[7 * stride_m]}; + zero_points += 8 * stride_m; + // Store contiguously + vst1_s8(zero_points_transposed + i, zp); + + // Load 8 scales with stride_m + float32x4_t scales_lo = { + scales[0 * stride_m], + scales[1 * stride_m], + scales[2 * stride_m], + scales[3 * stride_m]}; + float32x4_t scales_hi = { + scales[4 * stride_m], + scales[5 * stride_m], + scales[6 * stride_m], + scales[7 * stride_m]}; + scales += 8 * stride_m; + // Store contiguously + vst1q_f32(scales_transposed + i, scales_lo); + vst1q_f32(scales_transposed + i + 4, scales_hi); + } + + // Handle remaining elements + for (; i < m; i++) { + zero_points_transposed[i] = zero_points[0]; + scales_transposed[i] = scales[0]; + zero_points += stride_m; + scales += stride_m; + } +} + +} // namespace utils +} // namespace torchao::kernels::cpu::aarch64::quantized_matmul + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp b/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp index 65416fdf1d..3460d67fba 100644 --- a/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include void torchao::quantization::get_qvals_range( @@ -64,8 +65,6 @@ void torchao::kernels::cpu::aarch64::quantization::quantize( int8_t zero, int8_t qmin, int8_t qmax) { - assert(size % 8 == 0); - float32_t invScale = 1.0 / (scale + 1e-16); float32x4_t vec_zero = vdupq_n_f32(zero); float32x4_t vec_invScale = vdupq_n_f32(invScale); @@ -78,7 +77,8 @@ void torchao::kernels::cpu::aarch64::quantization::quantize( int16x4_t vec_qval_s16_0; int16x4_t vec_qval_s16_1; - for (int i = 0; i < size; i += 8) { + int i = 0; + for (; (i + 8) < size; i += 8) { ////////////////////////////////////// // Quantize first 4 element chunk to int16 ////////////////////////////////////// @@ -112,6 +112,23 @@ void torchao::kernels::cpu::aarch64::quantization::quantize( int8x8_t vec_qval_s8_01 = vqmovn_s16(vec_qval_s16_01); vst1_s8(qvals + i, vec_qval_s8_01); } + auto curr_rounding_mode = fegetround(); + fesetround(FE_TONEAREST); + for (; i < size; ++i) { + // Quantize remaining elements using scalar code + float32_t val = vals[i]; + float32_t qval_f32 = zero + val * invScale; + int32_t qval_s32 = static_cast(std::nearbyint(qval_f32)); + + // Clip to qmin and qmax + qval_s32 = std::max( + static_cast(qmin), + std::min(qval_s32, static_cast(qmax))); + + // Store the quantized value + qvals[i] = static_cast(qval_s32); + } + fesetround(int(curr_rounding_mode)); } #endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt index 5b6ba2ab98..a01afac68f 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt @@ -119,6 +119,14 @@ target_link_libraries( dep ) +add_executable(test_qmatmul test_qmatmul.cpp) +target_link_libraries( + test_qmatmul + PRIVATE + GTest::gtest_main + dep +) + include(GoogleTest) gtest_discover_tests(test_quantization) gtest_discover_tests(test_reduction) @@ -127,3 +135,4 @@ gtest_discover_tests(test_linear) gtest_discover_tests(test_valpacking) gtest_discover_tests(test_embedding) gtest_discover_tests(test_weight_packing) +gtest_discover_tests(test_qmatmul) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh index 4b2181d7cc..1898e8b535 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh +++ b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh @@ -61,3 +61,4 @@ ${CMAKE_OUT}/test_linear ${CMAKE_OUT}/test_valpacking ${CMAKE_OUT}/test_embedding ${CMAKE_OUT}/test_weight_packing +${CMAKE_OUT}/test_qmatmul diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp new file mode 100644 index 0000000000..1b3e11156f --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp @@ -0,0 +1,229 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include + +#include +#include +#include + +float kTol = 0.0001; + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +struct test_channelwise_8bit_channelwise_8bit_b { + static void Run(int m, int k, int n); +}; + +template +struct test_channelwise_8bit_channelwise_8bit_b< + a_has_zeros, + b_has_zeros, + false, + true> { + static void Run(int m, int k, int n, int stride = 1) { + auto test_case = + torchao::channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case:: + generate(m, k, n, a_has_zeros, a_has_zeros, false, true, stride); + + using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot; + + std::vector output(m * n); + kernel( + m, + n, + k, + test_case.lhs_qvals.data(), + k * stride /*lsh_stride_m*/, + test_case.rhs_qvals.data(), + k * stride /*rsh_stride_n*/, + output.data(), + n /*out_stride_n*/, + test_case.lhs_zeros.data(), + test_case.rhs_zeros.data(), + test_case.lhs_scales.data(), + test_case.rhs_scales.data(), + stride, /*lhs qparams stride*/ + stride /*rhs qparams stride*/); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } + } +}; + +template +struct test_channelwise_8bit_channelwise_8bit_b< + a_has_zeros, + b_has_zeros, + false, + false> { + static void Run(int m, int k, int n, int stride = 1) { + auto test_case = + torchao::channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case:: + generate(m, k, n, a_has_zeros, a_has_zeros, false, false); + + using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal; + + std::vector output(m * n); + kernel( + m, + n, + k, + test_case.lhs_qvals.data(), + k /*lsh_stride_m*/, + test_case.rhs_qvals.data(), + n /*rsh_stride_n*/, + output.data(), + n /*out_stride_n*/, + test_case.lhs_zeros.data(), + test_case.rhs_zeros.data(), + test_case.lhs_scales.data(), + test_case.rhs_scales.data(), + stride, /*lhs qparams stride*/ + stride /*rhs qparams stride*/); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } + } +}; + +TEST(test_channelwise_8bit_channelwise_8bit_b, TransposedBWithZeroPoints) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/1, /*k=*/128, /*n=*/16); +} + +TEST(test_channelwise_8bit_channelwise_8bit_b, TransposeBWithZeroPointsLargeM) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/128, /*n=*/16); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TransposedBWithZeroPointsOddSizes) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/24); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TransposedBWithZeroPointsOddSizes2) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/19); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TransposedBWithZeroPointsStrided) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/1, /*k=*/128, /*n=*/16, 5); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TransposedBWithZeroPointsOddSizes2Strided) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/19, 10); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TransposedBWithZeroPointsOddSizes2Strided2) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/3, /*k=*/64, /*n=*/24, 7); +} + +TEST(test_channelwise_8bit_channelwise_8bit_b, NoTransposedWithZeroPoints) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + false /*b_transposed*/>:: + Run( + /*m=*/1, /*k=*/128, /*n=*/16); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + NoTransposedWithZeroPointsLargeM) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + false /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/128, /*n=*/16); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + NoTransposedWithZeroPointsOddSizes) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + false /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/24); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + NoTransposedWithZeroPointsOddSizes2) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + false /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/19); +} + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h index 4720b68fb0..80ddcb690d 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h @@ -84,6 +84,59 @@ inline float get_float_from_bf16(uint16_t bf16) { return f; } +namespace { +auto generate_per_token_quantized_tensor( + int m, + int n, + bool transposed = false) { + auto activations = get_random_vector(m * n, -1.0, 1.0); + auto activation_qvals = std::vector(m * n, 0); + auto activation_scales = std::vector(m, 0); + auto activation_zeros = std::vector(m, 0); + + // Quantize activations with 8-bit asymmetric + // TODO: replace with generic function that does not use aarch64 + // quantize method after we combine with torchao + int qmin, qmax, zero; + float vmin, vmax, scale; + torchao::quantization::get_qvals_range( + qmin, qmax, /*nbit=*/8, /*is_symmetric=*/false); + for (int m_idx = 0; m_idx < m; m_idx++) { + torchao::kernels::cpu::aarch64::reduction::find_min_and_max( + vmin, vmax, /*vals=*/activations.data() + m_idx * n, /*size=*/n); + torchao::quantization::get_scale_and_zero( + scale, zero, vmin, vmax, qmin, qmax); + activation_scales[m_idx] = scale; + activation_zeros[m_idx] = zero; + torchao::kernels::cpu::aarch64::quantization::quantize( + /*qvals=*/activation_qvals.data() + m_idx * n, + /*vals=*/activations.data() + m_idx * n, + /*size=*/n, + scale, + zero, + qmin, + qmax); + } + if (transposed) { + auto activations_t = std::vector(m * n, 0); + auto activation_qvals_t = std::vector(m * n, 0); + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + int activation_idx = m_idx * n + n_idx; + int tranposed_idx = n_idx * m + m_idx; + activations_t[tranposed_idx] = activations[activation_idx]; + activation_qvals_t[tranposed_idx] = activation_qvals[activation_idx]; + } + } + activations = activations_t; + activation_qvals = activation_qvals_t; + } + + return std::make_tuple( + activations, activation_qvals, activation_scales, activation_zeros); +} +} // namespace + struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { int m; int k; @@ -182,34 +235,8 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { // weights is k x n (stored in column-major) // Generate activations - auto activations = get_random_vector(m * k, -1.0, 1.0); - auto activation_qvals = std::vector(m * k, 0); - auto activation_scales = std::vector(m, 0); - auto activation_zeros = std::vector(m, 0); - - // Quantize activations with 8-bit asymmetric - // TODO: replace with generic function that does not use aarch64 - // quantize method after we combine with torchao - int qmin, qmax, zero; - float vmin, vmax, scale; - torchao::quantization::get_qvals_range( - qmin, qmax, /*nbit=*/8, /*is_symmetric=*/false); - for (int m_idx = 0; m_idx < m; m_idx++) { - torchao::kernels::cpu::aarch64::reduction::find_min_and_max( - vmin, vmax, /*vals=*/activations.data() + m_idx * k, /*size=*/k); - torchao::quantization::get_scale_and_zero( - scale, zero, vmin, vmax, qmin, qmax); - activation_scales[m_idx] = scale; - activation_zeros[m_idx] = zero; - torchao::kernels::cpu::aarch64::quantization::quantize( - /*qvals=*/activation_qvals.data() + m_idx * k, - /*vals=*/activations.data() + m_idx * k, - /*size=*/k, - scale, - zero, - qmin, - qmax); - } + auto [activations, activation_qvals, activation_scales, activation_zeros] = + generate_per_token_quantized_tensor(m, k); // Generate weights assert(k % weight_group_size == 0); @@ -219,6 +246,8 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { auto weight_scales = std::vector(n_weight_groups, 0.0); auto weight_zeros = std::vector(n_weight_groups, 0); + int qmin, qmax, zero; + float vmin, vmax, scale; // Quantize weights with weight_nbit // TODO: replace with generic function that does not use aarch64 // quantize method after we combine with torchao @@ -322,6 +351,150 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { } }; +struct channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case { + int m; + int k; + int n; + int stride; + + bool lhs_has_zeros; + bool rhs_has_zeros; + bool lhs_is_transposed; + bool rhs_is_transposed; + + std::vector expected_output; + + std::vector lhs; + std::vector lhs_qvals; + std::vector lhs_scales; + std::vector lhs_zeros; + + std::vector rhs; + std::vector rhs_qvals; + std::vector rhs_scales; + std::vector rhs_zeros; + + channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case( + int m_, + int k_, + int n_, + int stride_, + bool lhs_has_zeros_, + bool rhs_has_zeros_, + bool lhs_is_transposed_, + bool rhs_is_transposed_, + std::vector expected_output_, + std::vector lhs_, + std::vector lhs_qvals_, + std::vector lhs_scales_, + std::vector lhs_zeros_, + std::vector rhs_, + std::vector rhs_qvals_, + std::vector rhs_scales_, + std::vector rhs_zeros_) + : m(m_), + k(k_), + n(n_), + stride(stride_), + lhs_has_zeros(lhs_has_zeros_), + rhs_has_zeros(rhs_has_zeros_), + lhs_is_transposed(lhs_is_transposed_), + rhs_is_transposed(rhs_is_transposed_), + expected_output(expected_output_), + lhs(lhs_), + lhs_qvals(lhs_qvals_), + lhs_scales(lhs_scales_), + lhs_zeros(lhs_zeros_), + rhs(rhs_), + rhs_qvals(rhs_qvals_), + rhs_scales(rhs_scales_), + rhs_zeros(rhs_zeros_) { + assert(expected_output.size() == m * n); + assert(lhs.size() == m * stride * k); + assert(lhs_qvals.size() == m * stride * k); + assert(lhs_scales.size() == m * stride); + assert(lhs_zeros.size() == m * stride); + assert(rhs.size() == n * stride * k); + assert(rhs_qvals.size() == n * stride * k); + assert(rhs_scales.size() == n * stride); + assert(rhs_zeros.size() == n * stride); + } + + static channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case generate( + int m, + int k, + int n, + bool lhs_has_zeros, + bool rhs_has_zeros, + bool lhs_is_transposed, + // rhs_is_transposed means generated b matrix is mxk instead of kxm + bool rhs_is_transposed, + int stride = 1) { + assert(!lhs_is_transposed); + assert(lhs_has_zeros); + assert(rhs_has_zeros); + // !Rhs transposed was considered if we were doing quantized(softmax(q@k)) @ + // quantized(v) Since v would have been [B, H, S, D]. And [S, D] would be + // rhs matrix which is not transposed when considered matmul terminology + // because for matmul we would have A[S_q, S] x B[S, D]. + // It would have been transposed if A[S_q, S] x B[D, S]. + assert(rhs_is_transposed || stride == 1); + // Generate activations + auto [lhs, lhs_qvals, lhs_scales, lhs_zeros] = + generate_per_token_quantized_tensor(m * stride, k); + + auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] = + generate_per_token_quantized_tensor(n * stride, k, !rhs_is_transposed); + // Above function produces nxk matrix and to produce kxn you need transposed + // = true. we do !rhs_is_transposed becaues when rhs_is_transposed = true + // the shape should be nxk instead of kxn. + + // Compute expected output + std::vector expected_output(m * n); + + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int lhs_idx = m_idx * stride * k + k_idx; + int rhs_idx = k_idx * stride * n + n_idx * stride; + if (rhs_is_transposed) { + rhs_idx = n_idx * stride * k + k_idx; + } + float lhs_dequant = lhs_scales[m_idx * stride] * + (lhs_qvals[lhs_idx] - lhs_zeros[m_idx * stride]); + + float rhs_dequant = rhs_scales[n_idx * stride] * + (rhs_qvals[rhs_idx] - rhs_zeros[n_idx * stride]); + + res += lhs_dequant * rhs_dequant; + } + expected_output[m_idx * n + n_idx] = res; + } + } + + // Return test case + return channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case( + m, + k, + n, + stride, + lhs_has_zeros, + rhs_has_zeros, + lhs_is_transposed, + rhs_is_transposed, + expected_output, + lhs, + lhs_qvals, + lhs_scales, + lhs_zeros, + rhs, + rhs_qvals, + rhs_scales, + rhs_zeros); + } +}; + template struct lowbit_embedding_test_case { int num_embeddings; From 0304a52085a388eb34605b8010cd867af160b37e Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Wed, 2 Apr 2025 10:25:19 -0700 Subject: [PATCH 07/30] Allow builds on less than sm75 raise runtime failure (#1999) stack-info: PR: https://github.com/pytorch/ao/pull/1999, branch: drisspg/stack/45 --- torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 52 +++++++++++++++------ torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh | 19 ++++++-- torchao/ops.py | 13 ++++++ 3 files changed, 65 insertions(+), 19 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 531e1ba7e6..26f6494220 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -21,6 +21,7 @@ // // MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942): // - Modified the TilingConfig parameters for SM75 to deal with smaller shared memory +// - Added proper architecture check at both host and device level // @@ -98,7 +99,24 @@ void fpx_linear_kernel(cudaStream_t stream, static_assert(std::is_same::value || std::is_same::value, "Type must be 'half' or '__nv_bfloat16'"); assert(M_Global % 256 == 0); assert(K_Global % 64 == 0); - assert(N_Global>0); + assert(N_Global > 0); + + // Check GPU Compute Capability before proceeding + int device, major, minor; + CHECK_CUDA(cudaGetDevice(&device)); + CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); + + // Early exit with error for unsupported architectures + if ((major < 7) || (major == 7 && minor < 5)) { + TORCH_CHECK(false, "Quant-LLM Error: This kernel requires GPU with SM75 (Turing) or higher architecture. " + "Your current device has SM", major, minor, " which is not supported."); + } + + const bool is_sm75_gpu = (major == 7) && (minor == 5); + if (is_sm75_gpu && std::is_same::value) { + TORCH_CHECK(false, "Quant-LLM Error: BFloat16 inputs are not supported on SM75 (Turing) GPUs."); + } // Work around to support more N shapes: size_t N_PowerOf2; @@ -109,17 +127,6 @@ void fpx_linear_kernel(cudaStream_t stream, if(N_Global>64 && N_Global<=128) N_PowerOf2 = 128; if(N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128; - // Check GPU Compute Capability - int device, major, minor; - CHECK_CUDA(cudaGetDevice(&device)); - CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); - CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); - const bool is_sm75_gpu = (major == 7) && (minor == 5); - if (is_sm75_gpu && std::is_same::value) - TORCH_CHECK(false, "Bfloat16 inputs are not supported for SM75"); - if ((major < 7) || (major == 7 && minor < 5)) - TORCH_CHECK(false, "FP6LLM_API Error: FP6LLM requires GPU with SM75 or higher!\n"); - if (is_sm75_gpu && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) { // For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory. if (Split_K == 1) { @@ -136,7 +143,7 @@ void fpx_linear_kernel(cudaStream_t stream, case 64: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; case 128: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { - TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2); + TORCH_CHECK(false, "Quant-LLM Error: Unsupported N dimension ", N_PowerOf2); } Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; } @@ -149,7 +156,7 @@ void fpx_linear_kernel(cudaStream_t stream, case 64: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; case 128: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { - TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2); + TORCH_CHECK(false, "Quant-LLM Error: Unsupported N dimension ", N_PowerOf2); } Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; } @@ -210,6 +217,23 @@ torch::Tensor fp_eXmY_linear_forward_cuda( torch::Tensor _scales, int64_t splitK=1) { + // Check GPU Compute Capability before proceeding + int device, major, minor; + CHECK_CUDA(cudaGetDevice(&device)); + CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); + + // Early exit with error for unsupported architectures + if ((major < 7) || (major == 7 && minor < 5)) { + TORCH_CHECK(false, "Quant-LLM Error: This kernel requires GPU with SM75 (Turing) or higher architecture. " + "Your current device has SM", major, minor, " which is not supported."); + } + + const bool is_sm75_gpu = (major == 7) && (minor == 5); + if (is_sm75_gpu && _in_feats.scalar_type() == at::ScalarType::BFloat16) { + TORCH_CHECK(false, "Quant-LLM Error: BFloat16 inputs are not supported on SM75 (Turing) GPUs."); + } + const int64_t NBITS = 1 + EXPONENT + MANTISSA; int num_in_feats = _in_feats.size(0); int num_in_channels = _in_feats.size(1); diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index d4be92b227..096bdc0d7f 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -51,17 +51,14 @@ * B: col major, FP16 * C: col major, FP16 */ - template +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 +template __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, const half *B, OutputDataType* C, const size_t M_Global, const size_t N_Global, const size_t K_Global, int Split_K) { - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 - static_assert(false, "Quant-LLM kernel: At least Turing generation (sm75) is required."); - // __trap(); // fails at runtime instead of compile time - #endif #ifdef DEBUG_MODE assert(K_Global%TilingConfig::TILE_K==0); assert(M_Global%TilingConfig::TILE_M==0); @@ -233,3 +230,15 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, } } } +#else +// Stub implementation for older architectures +template +__global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, + const half *B, + OutputDataType* C, + const size_t M_Global, const size_t N_Global, const size_t K_Global, + int Split_K) +{ +// NOOP, should never actually be called +} +#endif diff --git a/torchao/ops.py b/torchao/ops.py index 34a97d03f5..5bc71321ac 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -71,6 +71,13 @@ def decorator(func): return decorator +@functools.lru_cache +def cached_compute_capability(): + device_props = torch.cuda.get_device_properties(torch.cuda.current_device()) + compute_capability = device_props.major * 10 + device_props.minor + return compute_capability + + def quant_llm_linear( EXPONENT: int, MANTISSA: int, @@ -93,6 +100,12 @@ def quant_llm_linear( Returns output of linear layer """ + # Check if we're on a supported architecture (sm7.5 or higher) + compute_capability = cached_compute_capability() + torch._check( + compute_capability >= 75, + lambda: f"quant_llm_linear requires sm7.5+ GPU architecture, but current device has sm{compute_capability}", + ) return torch.ops.torchao.quant_llm_linear.default( EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK ) From 3f89080230af4f4d6925914c21ba4c6bcfb4455d Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 2 Apr 2025 11:27:34 -0700 Subject: [PATCH 08/30] Skip galore test if not cuda (#2003) Summary: fixing CI before branch cut Test Plan: python test/quantization/test_galore_quant.py and CI Reviewers: Subscribers: Tasks: Tags: --- test/quantization/test_galore_quant.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py index a67f7775b1..0ebc356114 100644 --- a/test/quantization/test_galore_quant.py +++ b/test/quantization/test_galore_quant.py @@ -38,6 +38,7 @@ @pytest.mark.skip("skipping for now, see comments below") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @pytest.mark.parametrize( "dim1,dim2,dtype,signed,blocksize", TEST_CONFIGS, @@ -89,6 +90,7 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize): TEST_CONFIGS, ) @skip_if_rocm("ROCm enablement in progress") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize): g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01 From 4e8f7f88ead1416e139201536db7b470e9882d6d Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 2 Apr 2025 11:28:07 -0700 Subject: [PATCH 09/30] Fix experimental CI (#2005) * up * up --- .github/workflows/torchao_experimental_test.yml | 11 +++-------- dev-requirements.txt | 3 +++ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml index 8d274b62e7..0cb470901e 100644 --- a/.github/workflows/torchao_experimental_test.yml +++ b/.github/workflows/torchao_experimental_test.yml @@ -37,9 +37,7 @@ jobs: # of torch and torchao, which we do not want to use pip install executorch pip install torch==2.7.0.dev20250311 --index-url "https://download.pytorch.org/whl/nightly/cpu" --force-reinstall - pip install numpy - pip install pytest - pip install parameterized + pip install -r dev-requirements.txt USE_CPP=1 TOCHAO_BUILD_KLEIDIAI=1 pip install . - name: Run python tests run: | @@ -99,11 +97,8 @@ jobs: python -c "import torch; print(torch.__version__)" - name: Install requirements run: | - pip install cmake - pip install parameterized - pip install pyyaml - pip install numpy - pip install importlib-metadata + pip install -r dev-requirements.txt + pip install pyyaml importlib-metadata - name: Print pip freeze run: | pip freeze diff --git a/dev-requirements.txt b/dev-requirements.txt index f5b1599ffa..1982d76795 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -26,6 +26,9 @@ importlib_metadata # Custom CUDA Extensions ninja +# CPU kernels +cmake<4.0.0,>=3.19.0 + # Linting ruff==0.6.8 pre-commit From 97f6618b25053d9a3ac71578f9b9f88ee5e6d382 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Wed, 2 Apr 2025 12:15:41 -0700 Subject: [PATCH 10/30] Add fp32xint8 matmul Differential Revision: D71370597 Pull Request resolved: https://github.com/pytorch/ao/pull/2004 --- ...input_channelwise_8bit_b_1x16x4_f32_impl.h | 275 ++++++++++++++++++ .../kernels/cpu/aarch64/matmul/matmul.h | 21 ++ .../cpu/aarch64/tests/test_qmatmul.cpp | 185 ++++++++++++ .../kernels/cpu/aarch64/tests/test_utils.h | 17 +- 4 files changed, 489 insertions(+), 9 deletions(-) create mode 100644 torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h new file mode 100644 index 0000000000..389abb32a5 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h @@ -0,0 +1,275 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include + +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::quantized_matmul { +namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32::internal { + +namespace { + +/* +This function loads float32x4_t value from a, and 16 int8x16_t values from b. +For each int8x16_t of b: +- 4 float32x4 accumulated values +- load 4 a in float32x4_t +- [The following repeats for each of the 4 lanes of a] +- for i in [0, 4]: + - load b[i] in int8x16_t + - subl to subtract b_zero_point from b, to get b_low, b_high + - vmovl to get b_low_low, b_low_high, b_high_low, b_high_high + - vcvtq to convert to float32x4_t, we will have 4 of these. +- for i in [0, 4]: for each of the 4 float32x4_t of b: + - vfmaq_lane_fp32 to multiply a[lane] and b[i] + - vfmaq_lane_fp32 to multiply a[lane] and b[i] + - vfmaq_lane_fp32 to multiply a[lane] and b[i] + - vfmaq_lane_fp32 to multiply a[lane] and b[i] +- By doing the above 4 times (lane=[0-3]), we used all values along k dim of a + and accumulated 4 float32x4_t values +*/ +TORCHAO_ALWAYS_INLINE void block_mul_1x16x1( + const float32_t a, + const int8x16_t& b_vec, + const int8_t b_zero_point, + const float b_scale, + float32x4_t (&partial_sums)[4]) { + int8x8_t b_zero_point_vec = vdup_n_s8(b_zero_point); + int16x8_t b_vec_low = vsubl_s8(vget_low_s8(b_vec), b_zero_point_vec); + int16x8_t b_vec_high = vsubl_s8(vget_high_s8(b_vec), b_zero_point_vec); + float32x4_t b_vec_low_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_vec_low))); + float32x4_t b_vec_low_high = + vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_vec_low))); + float32x4_t b_vec_high_low = + vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_vec_high))); + float32x4_t b_vec_high_high = + vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_vec_high))); + b_vec_low_low = vmulq_n_f32(b_vec_low_low, b_scale); + b_vec_low_high = vmulq_n_f32(b_vec_low_high, b_scale); + b_vec_high_low = vmulq_n_f32(b_vec_high_low, b_scale); + b_vec_high_high = vmulq_n_f32(b_vec_high_high, b_scale); + + partial_sums[0] = vfmaq_n_f32(partial_sums[0], b_vec_low_low, a); + partial_sums[1] = vfmaq_n_f32(partial_sums[1], b_vec_low_high, a); + partial_sums[2] = vfmaq_n_f32(partial_sums[2], b_vec_high_low, a); + partial_sums[3] = vfmaq_n_f32(partial_sums[3], b_vec_high_high, a); +} + +void block_mul_1x16x4( + const float32_t* a, + const int8_t* b, + const size_t ldb, + const int8_t* b_zero_point, + const float* b_scale, + float32x4_t (&partial_sums)[4]) { + #pragma unroll(8) + for (int i = 0; i < 4; i++) { + int8x16_t b_vec = vld1q_s8(b + i * ldb); + block_mul_1x16x1(a[i], b_vec, b_zero_point[i], b_scale[i], partial_sums); + } +} + +} // namespace + +template +struct KernelImpl { + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* rhs_zero_points, + const float* rhs_scales, + const float beta, + const int rhs_qparams_stride); +}; + +template <> +struct KernelImpl { + static void run( + int m, + int n, + int k, + const float* lhs, + int lhs_stride_m, + const int8_t* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* rhs_zero_points, + const float* rhs_scales, + const float beta, + const int rhs_qparams_stride) { + std::unique_ptr rhs_zero_points_transposed = std::make_unique(k); + std::unique_ptr rhs_scales_transposed = std::make_unique(k); + if (rhs_qparams_stride > 1) { + utils::transpose_scales_and_zero_points( + rhs_zero_points, + rhs_scales, + rhs_zero_points_transposed.get(), + rhs_scales_transposed.get(), + k, + rhs_qparams_stride); + rhs_zero_points = rhs_zero_points_transposed.get(); + rhs_scales = rhs_scales_transposed.get(); + } + + constexpr int nr = 16; + constexpr int kr = 4; + for (int m_idx = 0; m_idx < m; m_idx++) { + // Loop over 16 cols at a time + // Access to partial tiles must be protected:w + assert(n >= nr); + for (int n_idx = 0; n_idx < n; n_idx += nr) { + // If remaining is < nr, that must mean that (nr - remaining) items + // dont need to be computed. + // In order to avoid out-of-bounds access, we need to rewind n_indx a + // bit + // |-------------------|-------------------| + // 0-------------------8-------------------16 + // 0-------------------8-----10 + // If n = 10 and nr = 8 then at n_idx = 8, we need to rewind n_idx to + // 8 - (8 - 10) = 2 + int remaining = std::min(n - n_idx, nr); + n_idx = n_idx - (nr - remaining); + // Set activation_ptr to start of activation qvals for row m_idx + const float* lhs_ptr = lhs + m_idx * lhs_stride_m; + const int8_t* rhs_ptr = rhs + n_idx; + float32x4_t sums[nr / 4] = {vdupq_n_f32(0)}; + + // Loop k_idx by group + int k_idx = 0; + for (; (k_idx + kr) <= k; k_idx += kr) { + block_mul_1x16x4( + lhs_ptr, + rhs_ptr, + rhs_stride_n, + rhs_zero_points + k_idx, + rhs_scales + k_idx, + sums); + lhs_ptr += kr; + rhs_ptr += kr * rhs_stride_n; + } + + for (int ki = 0; ki < (k - k_idx); ++ki) { + // For each of the remaining k values + // Load 1 int8_t from lhs + // Load 16 int8_t from rhs + // And multiply + add into the 16 accumulators + // arranged as int32x4_t[4] + int8x16_t rhs_vec = vld1q_s8(rhs_ptr + ki * rhs_stride_n); + block_mul_1x16x1( + lhs_ptr[ki], + rhs_vec, + rhs_zero_points[k_idx + ki], + rhs_scales[k_idx + ki], + sums); + } + + // Store result + // Because we adjust n_idx, we may end up writing the same location + // twice + // Note that the reason this case is being handled only for this kernel + // and not others in this directory is because only for this kernel + // we support accumulation. + float* store_loc = output + m_idx * out_stride_m + n_idx; + if (remaining < 16) { + // If remaining is < 16, then not all of the 16 accumulators are + // valid. That is not all of float32x4_t[4] are valid. We need to + // find the first valid one, and then store the rest of the + // accumulators in the same order. + // First valid one is at 3 - ((remaining - 1) / 4) because: + // If remaining is say 10 then first 6 are not valid. + // Thus first group of 4 at sums[0] is not valid. + // In the second group of 4, the first 2 are not valid. + // Rest are valid. + int start_sum_idx = 3 - ((remaining - 1) / 4); + // If remaining is 11, then the sums[1] has 3 valid values + // so 3 - (11 -1) % 4 = 3 - 10 % 4 = 3 - 2 = 1 + // Thus there is 1 invalid value in the first group of 4 + int invalid_values_in_32x4_reg = 3 - (remaining - 1) % 4; + store_loc += start_sum_idx * 4; + store_loc += invalid_values_in_32x4_reg; + if (invalid_values_in_32x4_reg > 0) { + for (int val_idx = invalid_values_in_32x4_reg; val_idx < 4; + ++val_idx) { + *store_loc = sums[start_sum_idx][val_idx] + (*store_loc) * beta; + store_loc += 1; + } + start_sum_idx++; + } + for (int out_idx = 0, sum_idx = start_sum_idx; sum_idx < nr / 4; + out_idx += 4, ++sum_idx) { + float32x4_t sum_val = vld1q_f32(store_loc + out_idx); + sums[sum_idx] = vfmaq_n_f32(sums[sum_idx], sum_val, beta); + vst1q_f32(store_loc + out_idx, sums[sum_idx]); + } + } else { + for (int out_idx = 0, sum_idx = 0; out_idx < nr; + out_idx += 4, ++sum_idx) { + float32x4_t sum_val = vld1q_f32(store_loc + out_idx); + sums[sum_idx] = vfmaq_n_f32(sums[sum_idx], sum_val, beta); + vst1q_f32(store_loc + out_idx, sums[sum_idx]); + } + } + } // n_idx + } // m_idx + } +}; + +} // namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32::internal + +namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 { +template +void kernel( + int m, + int n, + int k, + const float* lhs, + int lhs_stride_m, + const int8_t* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* rhs_zero_points, + const float* rhs_scales, + const float beta, + const int rhs_qparams_stride) { + torchao::kernels::cpu::aarch64::quantized_matmul:: + fp32_a_input_channelwise_8bit_b_1x16x4_f32::internal:: + KernelImpl::run( + m, + n, + k, + lhs, + lhs_stride_m, + rhs, + rhs_stride_n, + output, + out_stride_m, + rhs_zero_points, + rhs_scales, + beta, + rhs_qparams_stride); +} +} // namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 +} // namespace torchao::kernels::cpu::aarch64::quantized_matmul + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h index 4005dee564..43f3dd4bce 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h @@ -66,9 +66,30 @@ void kernel( const int rhs_qparams_stride); } // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal + +namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 { + +template +void kernel( + int m, + int n, + int k, + const float* lhs, + int lhs_stride_m, + const int8_t* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* rhs_zero_points, + const float* rhs_scales, + const float beta, + const int rhs_qparams_stride); + +} // namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 } // namespace torchao::kernels::cpu::aarch64::quantized_matmul #include #include +#include #endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp index 1b3e11156f..e7e2d09c64 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp @@ -226,4 +226,189 @@ TEST( /*m=*/4, /*k=*/37, /*n=*/19); } +class FP32A_QuantizedB_FP32C_Test : public ::testing::TestWithParam { + public: + int m; + int k; + int n; + int stride; + + bool rhs_has_zeros; + bool lhs_is_transposed; + bool rhs_is_transposed; + + std::vector init_output; + std::vector expected_output; + + std::vector lhs; + + std::vector rhs; + std::vector rhs_qvals; + std::vector rhs_scales; + std::vector rhs_zeros; + + void generate( + int m_, + int k_, + int n_, + bool rhs_has_zeros_, + bool lhs_is_transposed_, + bool rhs_is_transposed_, + int stride_ = 1) { + // Here stride is only applicable to rhs + // and it means that k elements are stride * napart for k x n matrix + // and stride apart for n x k matrix + assert(!lhs_is_transposed_); + assert(rhs_has_zeros_); + m = m_; + k = k_; + n = n_; + stride = stride_; + rhs_has_zeros = rhs_has_zeros_; + lhs_is_transposed = lhs_is_transposed_; + rhs_is_transposed = rhs_is_transposed_; + + assert(!rhs_is_transposed || stride == 1); + + // Generate activations + lhs = torchao::get_random_vector(m * k, -1.0, 1.0); + + // The strange thing this is doing is that instead of quantizing + // each output channel separately, we are quantizing each input channel + // Reason why we do !rhs_is_transposed is because + // we actually want k x n matrix not n x k matrix + // because each input channel is quantized separately + std::tie(rhs, rhs_qvals, rhs_scales, rhs_zeros) = + torchao::test_utils::generate_per_token_quantized_tensor( + k * stride, n, rhs_is_transposed); + + // Compute expected output + init_output = torchao::get_random_vector(m * n, -1.0, 1.0); + + assert(init_output.size() == m * n); + assert(lhs.size() == m * k); + assert(rhs.size() == n * stride * k); + assert(rhs_qvals.size() == n * stride * k); + assert(rhs_scales.size() == k * stride); + assert(rhs_zeros.size() == k * stride); + } + + void execute(float beta) { + // Compute expected output + expected_output = init_output; + + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int lhs_idx = m_idx * k + k_idx; + int rhs_idx = k_idx * stride * n + n_idx; + if (rhs_is_transposed) { + rhs_idx = n_idx * k * stride + k_idx * stride; + } + float rhs_dequant = rhs_scales[k_idx * stride] * + (static_cast(rhs_qvals[rhs_idx]) - + static_cast(rhs_zeros[k_idx * stride])); + + res += lhs[lhs_idx] * rhs_dequant; + } + expected_output[m_idx * n + n_idx] = + expected_output[m_idx * n + n_idx] * beta + res; + } + } + } + + float beta() const { + return GetParam(); + } +}; + +static void test_fp32_a_input_channelwise_8bit_b( + int m, + int k, + int n, + float beta, + FP32A_QuantizedB_FP32C_Test& test_case, + int stride = 1) { + test_case.execute(beta); + + using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: + fp32_a_input_channelwise_8bit_b_1x16x4_f32; + + std::vector output(test_case.init_output); + kernel( + m, + n, + k, + test_case.lhs.data(), + k /*lhs_stride_m*/, + test_case.rhs_qvals.data(), + n * stride /*rhs_stride_n*/, + output.data(), + n /*out_stride_n*/, + test_case.rhs_zeros.data(), + test_case.rhs_scales.data(), + beta, + stride /*rhs qparams stride*/); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPoints) { + generate(1, 128, 16, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/1, /*k=*/128, /*n=*/16, beta(), *this); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsLargeM) { + generate(4, 128, 16, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/128, /*n=*/16, beta(), *this); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsOddSizes) { + generate(4, 37, 24, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/37, /*n=*/24, beta(), *this); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsOddSizes2) { + generate(4, 37, 19, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/37, /*n=*/19, beta(), *this); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsOddSizes3) { + generate(4, 27, 21, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/27, /*n=*/21, beta(), *this); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsAlpha) { + generate(1, 128, 16, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/1, /*k=*/128, /*n=*/16, beta(), *this); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsWithStrides) { + stride = 5; + generate(1, 128, 16, true, false, false, stride); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/1, /*k=*/128, /*n=*/16, beta(), *this, stride); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsOddSizes2Strides) { + stride = 11; + generate(7, 37, 19, true, false, false, stride); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/7, /*k=*/37, /*n=*/19, beta(), *this, stride); +} + +INSTANTIATE_TEST_SUITE_P( + F32AInt8BFP32CTest, + FP32A_QuantizedB_FP32C_Test, + ::testing::Values(0.0, 1.0, 2.69)); + #endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h index 80ddcb690d..e411211eb4 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h @@ -84,11 +84,9 @@ inline float get_float_from_bf16(uint16_t bf16) { return f; } -namespace { -auto generate_per_token_quantized_tensor( - int m, - int n, - bool transposed = false) { +namespace test_utils { +auto generate_per_token_quantized_tensor(int m, int n, bool transposed = false); +auto generate_per_token_quantized_tensor(int m, int n, bool transposed) { auto activations = get_random_vector(m * n, -1.0, 1.0); auto activation_qvals = std::vector(m * n, 0); auto activation_scales = std::vector(m, 0); @@ -135,7 +133,7 @@ auto generate_per_token_quantized_tensor( return std::make_tuple( activations, activation_qvals, activation_scales, activation_zeros); } -} // namespace +} // namespace test_utils struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { int m; @@ -236,7 +234,7 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { // Generate activations auto [activations, activation_qvals, activation_scales, activation_zeros] = - generate_per_token_quantized_tensor(m, k); + test_utils::generate_per_token_quantized_tensor(m, k); // Generate weights assert(k % weight_group_size == 0); @@ -441,10 +439,11 @@ struct channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case { assert(rhs_is_transposed || stride == 1); // Generate activations auto [lhs, lhs_qvals, lhs_scales, lhs_zeros] = - generate_per_token_quantized_tensor(m * stride, k); + test_utils::generate_per_token_quantized_tensor(m * stride, k); auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] = - generate_per_token_quantized_tensor(n * stride, k, !rhs_is_transposed); + test_utils::generate_per_token_quantized_tensor( + n * stride, k, !rhs_is_transposed); // Above function produces nxk matrix and to produce kxn you need transposed // = true. we do !rhs_is_transposed becaues when rhs_is_transposed = true // the shape should be nxk instead of kxn. From 8f9bd0aeb300104a18f979d0f27ed9f31af365b4 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Wed, 2 Apr 2025 13:56:54 -0700 Subject: [PATCH 11/30] Add quantized q @ k test for intented used in quantized attention Differential Revision: D71370604 Pull Request resolved: https://github.com/pytorch/ao/pull/2006 --- .../cpu/aarch64/tests/test_qmatmul.cpp | 98 ++++++++ .../kernels/cpu/aarch64/tests/test_utils.h | 1 + .../tests/test_utils_quantized_attention.h | 235 ++++++++++++++++++ 3 files changed, 334 insertions(+) create mode 100644 torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp index e7e2d09c64..344b2c4915 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp @@ -12,6 +12,7 @@ #include #include #include +#include float kTol = 0.0001; @@ -411,4 +412,101 @@ INSTANTIATE_TEST_SUITE_P( FP32A_QuantizedB_FP32C_Test, ::testing::Values(0.0, 1.0, 2.69)); +static void test_8bit_per_token_q_at_k_matmul_attention( + int b, + int s_q, + int s_k, + int h, + int d, + bool transpose = true) { + auto test_case = torchao:: + channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case:: + generate(b, s_q, s_k, h, d, transpose); + + using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot; + + size_t q_b_stride = test_case.b_q_stride; + size_t q_h_stride = test_case.h_q_stride; + size_t q_s_q_stride = test_case.s_q_stride; + size_t q_scale_zp_b_stride = test_case.b_q_qparams_stride; + size_t q_scale_zp_h_stride = test_case.h_q_qparams_stride; + size_t q_scale_zp_s_stride = test_case.s_q_qparams_stride; + + size_t k_b_stride = test_case.b_k_stride; + size_t k_h_stride = test_case.h_k_stride; + size_t k_s_k_stride = test_case.s_k_stride; + size_t k_scale_zp_b_stride = test_case.b_k_qparams_stride; + size_t k_scale_zp_h_stride = test_case.h_k_qparams_stride; + size_t k_scale_zp_s_stride = test_case.s_k_qparams_stride; + + std::vector output(b * h * s_q * s_k); + size_t output_b_stride = h * s_q * s_k; + size_t output_h_stride = s_q * s_k; + size_t output_s_q_stride = s_k; + + for (int b_idx = 0; b_idx < b; b_idx++) { + for (int h_idx = 0; h_idx < h; h_idx++) { + kernel( + s_q, + s_k, + d, + test_case.q_qvals.data() + b_idx * q_b_stride + h_idx * q_h_stride, + q_s_q_stride /*lhs_stride_m*/, + test_case.k_qvals.data() + b_idx * k_b_stride + h_idx * k_h_stride, + k_s_k_stride /*rhs_stride_n*/, + output.data() + b_idx * output_b_stride + h_idx * output_h_stride, + output_s_q_stride /*out_stride_n*/, + test_case.q_zeros.data() + b_idx * q_scale_zp_b_stride + + h_idx * q_scale_zp_h_stride, + test_case.k_zeros.data() + b_idx * k_scale_zp_b_stride + + h_idx * k_scale_zp_h_stride, + test_case.q_scales.data() + b_idx * q_scale_zp_b_stride + + h_idx * q_scale_zp_h_stride, + test_case.k_scales.data() + b_idx * k_scale_zp_b_stride + + h_idx * k_scale_zp_h_stride, + q_scale_zp_s_stride /*lhs qparams stride*/, + k_scale_zp_s_stride /*rhs qparams stride*/); + } + } + + for (int i = 0; i < b * h * s_q * s_k; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST(test_8bit_per_token_q_at_k_matmul_attention, Basic) { + test_8bit_per_token_q_at_k_matmul_attention(1, 16, 16, 8, 16); +} + +TEST(test_8bit_per_token_q_at_k_matmul_attention, PrimeHeadsAndHeadDim) { + test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 33); +} + +TEST( + test_8bit_per_token_q_at_k_matmul_attention, + PrimeHeadsAndHeadDimDiffSqSk) { + test_8bit_per_token_q_at_k_matmul_attention(1, 7, 16, 7, 33); +} + +TEST(test_8bit_per_token_q_at_k_matmul_attention, PrimeHeadsAndSmallHeadDim) { + test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 3); +} + +TEST(test_8bit_per_token_q_at_k_matmul_attention, BasicNoTransposed) { + test_8bit_per_token_q_at_k_matmul_attention(1, 16, 16, 8, 16, false); +} + +TEST( + test_8bit_per_token_q_at_k_matmul_attention, + PrimeHeadsAndHeadDimDiffSqSkNoTranspose) { + test_8bit_per_token_q_at_k_matmul_attention(1, 7, 16, 7, 33, false); +} + +TEST( + test_8bit_per_token_q_at_k_matmul_attention, + PrimeHeadsAndSmallHeadDimNoTranspose) { + test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 3, false); +} + #endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h index e411211eb4..4f96f8bf96 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h @@ -86,6 +86,7 @@ inline float get_float_from_bf16(uint16_t bf16) { namespace test_utils { auto generate_per_token_quantized_tensor(int m, int n, bool transposed = false); + auto generate_per_token_quantized_tensor(int m, int n, bool transposed) { auto activations = get_random_vector(m * n, -1.0, 1.0); auto activation_qvals = std::vector(m * n, 0); diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h new file mode 100644 index 0000000000..9ca86ece76 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h @@ -0,0 +1,235 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include +#include +#include +#include +#include + +namespace torchao { +struct channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case { + int b; + int s_q; + int s_k; + int h; + int d; + bool tranposed; + + size_t b_q_stride; + size_t h_q_stride; + size_t s_q_stride; + + size_t b_k_stride; + size_t h_k_stride; + size_t s_k_stride; + + size_t b_q_qparams_stride; + size_t h_q_qparams_stride; + size_t s_q_qparams_stride; + + size_t b_k_qparams_stride; + size_t h_k_qparams_stride; + size_t s_k_qparams_stride; + + std::vector expected_output; + + std::vector q; + std::vector q_qvals; + std::vector q_scales; + std::vector q_zeros; + + std::vector k; + std::vector k_qvals; + std::vector k_scales; + std::vector k_zeros; + + channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case( + int b_, + int s_q_, + int s_k_, + int h_, + int d_, + int transposed_, + size_t b_q_stride_, + size_t h_q_stride_, + size_t s_q_stride_, + size_t b_k_stride_, + size_t h_k_stride_, + size_t s_k_stride_, + size_t b_q_qparams_stride_, + size_t h_q_qparams_stride_, + size_t s_q_qparams_stride_, + size_t b_k_qparams_stride_, + size_t h_k_qparams_stride_, + size_t s_k_qparams_stride_, + std::vector expected_output_, + std::vector q_, + std::vector q_qvals_, + std::vector q_scales_, + std::vector q_zeros_, + std::vector k_, + std::vector k_qvals_, + std::vector k_scales_, + std::vector k_zeros_) + : b(b_), + s_q(s_q_), + s_k(s_k_), + h(h_), + d(d_), + tranposed(transposed_), + b_q_stride(b_q_stride_), + h_q_stride(h_q_stride_), + s_q_stride(s_q_stride_), + b_k_stride(b_k_stride_), + h_k_stride(h_k_stride_), + s_k_stride(s_k_stride_), + b_q_qparams_stride(b_q_qparams_stride_), + h_q_qparams_stride(h_q_qparams_stride_), + s_q_qparams_stride(s_q_qparams_stride_), + b_k_qparams_stride(b_k_qparams_stride_), + h_k_qparams_stride(h_k_qparams_stride_), + s_k_qparams_stride(s_k_qparams_stride_), + expected_output(expected_output_), + q(q_), + q_qvals(q_qvals_), + q_scales(q_scales_), + q_zeros(q_zeros_), + k(k_), + k_qvals(k_qvals_), + k_scales(k_scales_), + k_zeros(k_zeros_) { + assert(expected_output.size() == b * s_q * h * s_k); + assert(q.size() == b * s_q * h * d); + assert(q_qvals.size() == b * s_q * h * d); + assert(q_scales.size() == b * s_q * h); + assert(q_zeros.size() == b * s_q * h); + assert(k.size() == b * s_k * h * d); + assert(k_qvals.size() == b * s_k * h * d); + assert(k_scales.size() == b * s_k * h); + assert(k_zeros.size() == b * s_k * h); + } + + static channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case + generate(int b, int s_q, int s_k, int h, int d, bool transposed = true) { + // Generate activations + auto [lhs, lhs_qvals, lhs_scales, lhs_zeros] = + torchao::test_utils::generate_per_token_quantized_tensor( + b * s_q * h, d); + + auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] = + torchao::test_utils::generate_per_token_quantized_tensor( + b * s_k * h, d); + // Above function produces nxk matrix and to produce kxn you need transposed + // = true. we do !rhs_is_transposed becaues when rhs_is_transposed = true + // the shape should be nxk instead of kxn. + + size_t b_q_stride = h * s_q * d; + size_t h_q_stride = s_q * d; + size_t s_q_stride = d; + + size_t b_k_stride = h * s_k * d; + size_t h_k_stride = s_k * d; + size_t s_k_stride = d; + + size_t b_q_qparams_stride = h * s_q; + size_t h_q_qparams_stride = s_q; + size_t s_q_qparams_stride = 1; + + size_t b_k_qparams_stride = h * s_k; + size_t h_k_qparams_stride = s_k; + size_t s_k_qparams_stride = 1; + + if (!transposed) { + h_q_stride = d; + s_q_stride = h * d; + h_k_stride = d; + s_k_stride = h * d; + + s_q_qparams_stride = h; + h_q_qparams_stride = 1; + + s_k_qparams_stride = h; + h_k_qparams_stride = 1; + } + + // Compute expected output + std::vector expected_output(b * h * s_q * s_k); + size_t b_out_stride = h * s_q * s_k; + size_t h_out_stride = s_q * s_k; + size_t s_q_out_stride = s_k; + + for (int b_idx = 0; b_idx < b; b_idx++) { + for (int s_q_idx = 0; s_q_idx < s_q; s_q_idx++) { + for (int h_idx = 0; h_idx < h; h_idx++) { + for (int s_k_idx = 0; s_k_idx < s_k; s_k_idx++) { + float res = 0.0; + for (int d_idx = 0; d_idx < d; d_idx++) { + int lhs_idx = b_idx * b_q_stride + s_q_idx * s_q_stride + + h_idx * h_q_stride + d_idx; + int rhs_idx = b_idx * b_k_stride + s_k_idx * s_k_stride + + h_idx * h_k_stride + d_idx; + int lhs_scales_zp_idx = b_idx * b_q_qparams_stride + + h_idx * h_q_qparams_stride + s_q_idx * s_q_qparams_stride; + int rhs_scales_zp_idx = b_idx * b_k_qparams_stride * h + + h_idx * h_k_qparams_stride + s_k_idx * s_k_qparams_stride; + float lhs_dequant = lhs_scales[lhs_scales_zp_idx] * + (lhs_qvals[lhs_idx] - lhs_zeros[lhs_scales_zp_idx]); + + float rhs_dequant = rhs_scales[rhs_scales_zp_idx] * + (rhs_qvals[rhs_idx] - rhs_zeros[rhs_scales_zp_idx]); + + res += lhs_dequant * rhs_dequant; + } + expected_output + [b_idx * b_out_stride + s_q_idx * s_q_out_stride + + h_idx * h_out_stride + s_k_idx] = res; + } + } + } + } + + // Return test case + return channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case( + b, + s_q, + s_k, + h, + d, + transposed, + b_q_stride, + h_q_stride, + s_q_stride, + b_k_stride, + h_k_stride, + s_k_stride, + b_q_qparams_stride, + h_q_qparams_stride, + s_q_qparams_stride, + b_k_qparams_stride, + h_k_qparams_stride, + s_k_qparams_stride, + expected_output, + lhs, + lhs_qvals, + lhs_scales, + lhs_zeros, + rhs, + rhs_qvals, + rhs_scales, + rhs_zeros); + } +}; + +} // namespace torchao + +#endif // defined(__aarch64__) || defined(__ARM_NEON) From 2f62e01a89900b93a6bf97c329ae8fcb87b37eb8 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 2 Apr 2025 15:57:03 -0700 Subject: [PATCH 12/30] Update version.txt (#2009) --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index 78bc1abd14..d9df1bbc0c 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.10.0 +0.11.0 From 49705d90431dc7a9a5114612477d7933a64a7fe9 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 2 Apr 2025 16:20:53 -0700 Subject: [PATCH 13/30] Initial prototype of differentiable _scaled_grouped_mm function (#1969) --- .../prototype/scaled_grouped_mm/__init__.py | 3 + .../scaled_grouped_mm/scaled_grouped_mm.py | 361 ++++++++++++++++++ .../test_scaled_grouped_mm.py | 196 ++++++++++ 3 files changed, 560 insertions(+) create mode 100644 torchao/prototype/scaled_grouped_mm/__init__.py create mode 100644 torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py create mode 100644 torchao/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py diff --git a/torchao/prototype/scaled_grouped_mm/__init__.py b/torchao/prototype/scaled_grouped_mm/__init__.py new file mode 100644 index 0000000000..9c6278884a --- /dev/null +++ b/torchao/prototype/scaled_grouped_mm/__init__.py @@ -0,0 +1,3 @@ +from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import ( + _scaled_grouped_mm as _scaled_grouped_mm, +) diff --git a/torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py b/torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py new file mode 100644 index 0000000000..a431288c07 --- /dev/null +++ b/torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py @@ -0,0 +1,361 @@ +from typing import Optional, Tuple + +import torch + +from torchao.float8.config import ScalingGranularity +from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated + + +def _scaled_grouped_mm( + A: torch.Tensor, + B_t: torch.Tensor, + offs: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """ + This function performs dynamic float8 quantization with row-wise scaling + on the input tensors A and B, then performs a scaled grouped GEMM and returns the results. + + Args: + A (bf16/float32 torch.Tensor): The first high-precision input tensor, which must be a 2D tensor of shape (M * num_groups, K) + and in row-major memory layout. + B_t (bf16/float32 torch.Tensor): The second high-precision input tensor which must be 3D, which must be shape (B, K, N) + and in column-major memory layout. + offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor. + out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported. + """ + return _Float8GroupedMM.apply( + A, + B_t, + offs, + out_dtype, + ) + + +class _Float8GroupedMM(torch.autograd.Function): + """Differentiable implementation of grouped GEMM with dynamic float8 quantization.""" + + @staticmethod + def forward( + ctx, + A: torch.Tensor, + B_t: torch.Tensor, + offs: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + # torchao _scaled_grouped_mm only supports A=2D, B=3D. + assert A.ndim == 2, "A must be 2D" + assert B_t.ndim == 3, "B must be 3D" + + assert ( + A.size(-1) % 16 == 0 + ), f"A must have a last dim divisible by 16, but got shape: {A.shape}" + assert ( + B_t.size(-2) % 16 == 0 and B_t.size(-1) % 16 == 0 + ), f"B must have last 2 dims divisible by 16, but got shape: {B_t.shape}" + + # Assert input tensors are in high-precision dtypes. + assert ( + A.dtype == torch.float32 or A.dtype == torch.bfloat16 + ), "A must be float32 or bfloat16" + assert ( + B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16 + ), "B must be float32 or bfloat16" + assert offs.dtype == torch.int32, "offs must be int32" + + # Assert A and B dims are compatible for a scaled grouped GEMM. + assert A.size(-1) == B_t.size( + -2 + ), f"shape {A.shape} and {B_t.shape} are not compatible for _scaled_grouped_mm" + + # The left operand in the scaled grouped GEMM must be row-major due to hardware requirements. + assert not _is_column_major(A), "A must be row-major" + + # Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major. + assert _is_column_major(B_t), "B must be column-major" + + # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM. + # A shape: (M, K) + # A_scales shape: (M,1) + A_scales = tensor_to_scale( + A, + torch.float8_e4m3fn, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-1, + round_scales_to_power_of_2=True, + ) + A_scaled = A.to(torch.float32) * A_scales + A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn) + + # Convert B to float8, column-major for right operand of grouped GEMM. + # B shape: (B, K, N) + # B scales must be computed rowwise keeping the outer/final dim, so: + # B_scales shape: (B, 1, N) + B_t_scales = tensor_to_scale( + B_t, + torch.float8_e4m3fn, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-2, + round_scales_to_power_of_2=True, + ) + B_t_scaled = B_t.to(torch.float32) * B_t_scales + B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn) + + # Precompute non-transposed B column-major for backward, to save memory by storing the + # low precision B tensor instead of the high precision B tensor. + # In the backward this is needed for grad_A: grad_output @ B. + B = B_t.contiguous().transpose(-2, -1) + + # - B shape: (B, K, N) + # - B scales must be computed rowwise keeping the outer/final dim, so: + # - B_scale shape: (B, 1, N) + B_scales = tensor_to_scale( + B, + torch.float8_e4m3fn, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-2, + round_scales_to_power_of_2=True, + ) + B_scaled = B.to(torch.float32) * B_scales + B_fp8_col_major = to_fp8_saturated(B_scaled, torch.float8_e4m3fn) + + # Store what we need for backward. + ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs) + ctx.out_dtype = out_dtype + + # Perform scaled grouped GEMM and return result. + # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N) + return torch._scaled_grouped_mm( + A_fp8_row_major, + B_t_fp8_col_major, + A_scales.squeeze().reciprocal(), + B_t_scales.squeeze().reciprocal(), + offs, + out_dtype=out_dtype, + use_fast_accum=True, + ) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + A, B_fp8_col_major, B_scales, offs = ctx.saved_tensors + out_dtype = ctx.out_dtype + + # Convert grad_output to float8, row-major for left operand of grouped GEMM + # needed for grad_A: grad_output @ B + # + # grad_output shape: (M, N) + # grad_output_scale shape: (M, 1) + grad_output_scales = tensor_to_scale( + grad_output, + torch.float8_e4m3fn, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-1, + round_scales_to_power_of_2=True, + ) + grad_output_scaled = grad_output.to(torch.float32) * grad_output_scales + grad_output_fp8_row_major = to_fp8_saturated( + grad_output_scaled, torch.float8_e4m3fn + ) + + # Compute grad_A. + # + # grad_A = grad_output @ B + # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K) + grad_A = torch._scaled_grouped_mm( + grad_output_fp8_row_major, + B_fp8_col_major, + grad_output_scales.squeeze().reciprocal(), + B_scales.squeeze().reciprocal(), + offs, + out_dtype=out_dtype, + use_fast_accum=True, + ) + + # Convert tranpose of grad_output to float8, row-major for left operand of grouped GEMM + # needed for grad_B: grad_output_t @ A + grad_output_t_row_major = grad_output.transpose(-2, -1).contiguous() + + # Convert A to float8, column-major for right operand of grouped GEMM: + # needed for grad_B: grad_output @ A + A_col_major = A.transpose(-2, -1).contiguous().transpose(-2, -1) + + # grad_B is a special case. both operands of the grouped gemm will be 2D with offsets determing the "groups." + # Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups. + grad_output_t_fp8_row_major, grad_output_t_scales = ( + _to_2d_jagged_float8_tensor_rowwise( + grad_output_t_row_major, + offs, + target_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=True, + ) + ) + A_fp8_col_major, A_scales = _to_2d_jagged_float8_tensor_colwise( + A_col_major, + offs, + target_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=True, + ) + + # Compute grad_B = grad_output_t @ A. + # grad_B = grad_output_t @ A + # grad_B = (N,M) @ (M,K) = (N,K) + grad_B = torch._scaled_grouped_mm( + grad_output_t_fp8_row_major, + A_fp8_col_major, + grad_output_t_scales.reciprocal(), + A_scales.reciprocal(), + offs, + out_dtype=out_dtype, + use_fast_accum=True, + ) + return grad_A, grad_B.transpose(-2, -1), None, None, None, None + + +def _to_2d_jagged_float8_tensor_colwise( + A_col_major: torch.Tensor, + offs: torch.Tensor, + target_dtype: torch.dtype = torch.float8_e4m3fn, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This function converts the 2D input tensor A to a jagged float8 tensor, + with scales computed along *logical columns* for each group individually, + where groups are determined based on the offsets. + + For the right operand of a normal scaled GEMM, the rowwise scales are computed over logical columns. + (i.e., a tensor of (K,N) will have scales of shape (1,N). + + However, for a 2D right operand of a grouped GEMM, these logical columns go through multiple distinct + groups/subtensors, for which we want to compute scales individually. So we cannot take one set of scales + along the logical columns and apply it to the entire tensor. + + Instead, we need to compute scales for each subtensor individually. For a tensor of shape (K,N) this results + in scales of shape (1,N * num_groups). + + Args: + A (torch.Tensor): The input tensor to be converted to a jagged float8 tensor. + + Returns: + A tuple containing the jagged float8 tensor and the scales used for the conversion. + """ + assert A_col_major.ndim == 2, "A must be 2D" + + num_groups = offs.numel() + A_fp8_col_major = torch.empty_like(A_col_major, dtype=target_dtype) + A_scales = torch.empty( + A_fp8_col_major.size(1) * num_groups, + dtype=torch.float32, + device=A_fp8_col_major.device, + ) + + start_idx = 0 + next_scale_idx = 0 + for end_idx in offs.tolist(): + # Get the subtensor of A for this group, fetching the next group of rows, with all columns for each. + subtensor = A_col_major[start_idx:end_idx, :] # (local_group_size, K) + + # Compute local rowwise scales for this subtensor, which are along logical columns for the right operand. + subtensor_scales = tensor_to_scale( + subtensor, + target_dtype, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=0, + round_scales_to_power_of_2=round_scales_to_power_of_2, + ) + + # Apply scales to subtensor and convert to float8. + tensor_scaled = subtensor.to(torch.float32) * subtensor_scales + float8_subtensor = to_fp8_saturated(tensor_scaled, target_dtype) + + # Store this portion of the resulting float8 tensor and scales. + A_fp8_col_major[start_idx:end_idx, :] = float8_subtensor + A_scales[next_scale_idx : next_scale_idx + subtensor_scales.numel()] = ( + subtensor_scales.squeeze() + ) + + # Update start index for next group. + start_idx = end_idx + next_scale_idx += subtensor_scales.numel() + + return A_fp8_col_major, A_scales + + +def _to_2d_jagged_float8_tensor_rowwise( + x: torch.Tensor, + offs: torch.Tensor, + target_dtype: torch.dtype, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This function converts the 2D input tensor to a jagged float8 tensor, + with scales computed along *logical rows* for each group individually, + where groups are determined based on the offsets. + + For a 2D *left* operand of a normal scaled GEMM, the rowwise scales are computed over logical rows. + (i.e., a tensor of (M,K) will have scales of shape (M,1). + + However, for a 2D left operand of a grouped GEMM, these logical rows go through multiple distinct + groups/subtensors, for which we want to compute scales individually. So we cannot take one set of scales + along the logical rows and apply it to the entire tensor. + + Instead, we need to compute scales for each subtensor individually. For a tensor of shape (M,K) this results + in scales of shape (M * num_groups, 1). + + Args: + A (torch.Tensor): The input tensor to be converted to a jagged float8 tensor. + + Returns: + A tuple containing the jagged float8 tensor and the scales used for the conversion. + """ + assert x.ndim == 2, "input tensor must be 2D" + + num_groups = offs.numel() + x_fp8 = torch.empty_like(x, dtype=target_dtype) + x_scales = torch.empty( + x_fp8.size(0) * num_groups, dtype=torch.float32, device=x_fp8.device + ) + + start_idx = 0 + next_scale_idx = 0 + for end_idx in offs.tolist(): + # Get the subtensor of A for this group, fetching all rows with the next group of rows. + subtensor = x[:, start_idx:end_idx] # (M, local_group_size) + + # Compute local rowwise scales for this subtensor, which are along logical rows for the left operand. + subtensor_scales = tensor_to_scale( + subtensor, + target_dtype, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-1, + round_scales_to_power_of_2=round_scales_to_power_of_2, + ) + + # Apply scales to subtensor and convert to float8. + tensor_scaled = subtensor.to(torch.float32) * subtensor_scales + float8_subtensor = to_fp8_saturated(tensor_scaled, target_dtype) + + # Store this portion of the resulting float8 tensor and scales. + x_fp8[:, start_idx:end_idx] = float8_subtensor + x_scales[next_scale_idx : next_scale_idx + subtensor_scales.numel()] = ( + subtensor_scales.squeeze() + ) + + # Update start index for next group. + start_idx = end_idx + next_scale_idx += subtensor_scales.numel() + + return x_fp8, x_scales + + +def _is_column_major(x: torch.Tensor) -> bool: + """ + This function checks if the input tensor is column-major. + + Args: + x (torch.Tensor): The input tensor to be checked. + + Returns: + A boolean indicating whether the input tensor is column-major. + """ + assert x.ndim == 2 or x.ndim == 3, "input tensor must be 2D or 3D" + return x.stride(-2) == 1 and x.stride(-1) > 1 diff --git a/torchao/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py b/torchao/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py new file mode 100644 index 0000000000..cd347c3d9d --- /dev/null +++ b/torchao/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py @@ -0,0 +1,196 @@ +import pytest +import torch + +from torchao.float8.config import ( + Float8LinearConfig, + Float8LinearRecipeName, +) +from torchao.float8.float8_linear import matmul_with_hp_or_float8_args +from torchao.float8.float8_tensor import LinearMMConfig +from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated +from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import ( + _scaled_grouped_mm, +) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_valid_scaled_grouped_mm_2d_3d(): + out_dtype = torch.bfloat16 + device = "cuda" + m, n, k, n_groups = 16, 32, 16, 4 + a = torch.randn( + m * n_groups, + k, + device=device, + requires_grad=True, + dtype=torch.bfloat16, + ) + b = torch.randn( + n_groups, + n, + k, + device=device, + dtype=torch.bfloat16, + ) + offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32) + + # b must be transposed and in column major format. + b_t = b.contiguous().transpose(-2, -1).requires_grad_(True) + + # Compute output. + out = _scaled_grouped_mm( + a, + b_t, + offs=offs, + out_dtype=out_dtype, + ) + + # Validate result. + ref_a = a.detach().clone().requires_grad_(True) + ref_b_t = b_t.detach().clone().requires_grad_(True) + ref_out = compute_reference_forward( + out, + ref_a, + ref_b_t, + n_groups, + out_dtype, + offs, + ) + assert torch.equal(out, ref_out) + + # Run backward pass. + out.sum().backward() + ref_out.sum().backward() + + # Validate gradients. + assert torch.equal(a.grad, ref_a.grad) + assert torch.equal(b_t.grad, ref_b_t.grad) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("m", [16, 17]) +@pytest.mark.parametrize("k", [16, 18]) +@pytest.mark.parametrize("n", [32, 33]) +def test_K_or_N_dim_not_multiple_of_16(m, n, k): + # - Leading dim of A doesn't have to be divisible by 16, since it will be + # divided up into groups based on offset anyway. + # - Trailing dim of A must be divisible by 16. + # - Leading dim of B (n_groups) doesn't need to be divisible by 16. + # - Last 2 dims of B must be divisible by 16. + if n % 16 == 0 and k % 16 == 0: + return + out_dtype = torch.bfloat16 + device = "cuda" + n_groups = 4 + a = torch.randn( + m * n_groups, + k, + device=device, + requires_grad=True, + dtype=torch.bfloat16, + ) + b = torch.randn( + n_groups, + n, + k, + device=device, + requires_grad=True, + dtype=torch.bfloat16, + ) + + # b must be transposed and in column major format. + b_t = b.transpose(-2, -1) + b_t = b_t.transpose(-2, -1).contiguous().transpose(-2, -1) + + offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32) + + # Compute output. + with pytest.raises(AssertionError): + _scaled_grouped_mm( + a, + b_t, + offs=offs, + out_dtype=out_dtype, + ) + + +def compute_reference_forward( + result: torch.Tensor, + A: torch.Tensor, + B_t: torch.Tensor, + n_groups: int, + out_dtype: torch.dtype, + offs: torch.Tensor, +): + assert result.dtype == out_dtype + + # Use official rowwise recipe as reference to ensure implementation is correct. + float8_config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE) + + # Convert A to fp8. + A_scales = tensor_to_scale( + A, + float8_config.cast_config_input.target_dtype, + scaling_granularity=float8_config.cast_config_input.scaling_granularity, + axiswise_dim=-1, + round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2, + ) + A_scaled = A.to(torch.float32) * A_scales + A_fp8 = to_fp8_saturated(A_scaled, torch.float8_e4m3fn) + + # Convert B^t to fp8. + B_t_scales = tensor_to_scale( + B_t, + float8_config.cast_config_weight.target_dtype, + scaling_granularity=float8_config.cast_config_weight.scaling_granularity, + axiswise_dim=-2, + round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2, + ) + B_t_scaled = B_t.to(torch.float32) * B_t_scales + B_t_fp8 = to_fp8_saturated( + B_t_scaled, + torch.float8_e4m3fn, + ) + + # Split A and result into chunks, one for each group. + offs_cpu = offs.cpu() + A_list, A_list_fp8, A_scale_list, result_list = [], [], [], [] + start = 0 + for i in range(n_groups): + A_list.append(A[start : offs_cpu[i]]) + A_list_fp8.append(A_fp8[start : offs_cpu[i]]) + A_scale_list.append(A_scales[start : offs_cpu[i]]) + result_list.append(result[start : offs_cpu[i]]) + start = offs_cpu[i] + + # Validate each actual result group from the _scaled_grouped_mm is equal to: + # 1. A manual _scaled_mm for the group. + # 2. A matmul_with_hp_or_float8_args for the group (which is differentiable, and thus used to validate gradients). + outputs = [] + list1 = list(zip(A_list_fp8, B_t_fp8, A_scale_list, B_t_scales, result_list)) + list2 = list(zip(A_list, B_t, result_list)) + for i in range(len(list1)): + a1, b1, a1scale, b1scale, result1 = list1[i] + ref_group_result1 = torch._scaled_mm( + a1, + b1, + a1scale.reciprocal(), + b1scale.reciprocal(), + out_dtype=out_dtype, + bias=None, + use_fast_accum=float8_config.gemm_config_output.use_fast_accum, + ) + a2, b2, result2 = list2[i] + ref_group_result2 = matmul_with_hp_or_float8_args.apply( + a2, + b2, + LinearMMConfig(), + float8_config, + ) + assert torch.equal(result1, ref_group_result1) + assert torch.equal(result2, ref_group_result2) + outputs.append(ref_group_result2) + + # Concatenate the outputs and verify the full result is correct. + output_ref = torch.cat(outputs, dim=0) + return output_ref From 71a3d960b1f9831aa0fefff6bd71f1f3bc8ab109 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Wed, 2 Apr 2025 17:17:56 -0700 Subject: [PATCH 14/30] Add quantized attn_scores @ v test for intented used in quantized attention Differential Revision: D71370603 Pull Request resolved: https://github.com/pytorch/ao/pull/2008 --- ...input_channelwise_8bit_b_1x16x4_f32_impl.h | 6 + .../cpu/aarch64/tests/test_qmatmul.cpp | 87 +++++++++ .../tests/test_utils_quantized_attention.h | 168 ++++++++++++++++++ 3 files changed, 261 insertions(+) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h index 389abb32a5..bdad1b4a47 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h @@ -101,6 +101,12 @@ struct KernelImpl { const int rhs_qparams_stride); }; +/* +Document param meaning +rhs_stride_n: Since rhs transposed == false, the expected shape of rhs is k x n. +Thus rhs_stride_n is the stride of k dim, that how many bytes aparts elements +in k dim are. +*/ template <> struct KernelImpl { static void run( diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp index 344b2c4915..05dbf13aac 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp @@ -509,4 +509,91 @@ TEST( test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 3, false); } +static void test_fp32_attn_scores_at_v_matmul_attention( + int b, + int s_attn, + int s_v, + int h, + int d, + bool transpose_v = true) { + auto test_case = + torchao::fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case::generate( + b, s_attn, s_v, h, d, transpose_v); + + using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: + fp32_a_input_channelwise_8bit_b_1x16x4_f32; + + size_t attn_b_stride = test_case.b_attn_stride; + size_t attn_h_stride = test_case.h_attn_stride; + size_t attn_s_q_stride = test_case.s_attn_stride; + + size_t v_b_stride = test_case.b_v_stride; + size_t v_h_stride = test_case.h_v_stride; + size_t v_s_v_stride = test_case.s_v_stride; + size_t v_scale_zp_b_stride = test_case.b_v_qparams_stride; + size_t v_scale_zp_h_stride = test_case.h_v_qparams_stride; + size_t v_scale_zp_s_stride = test_case.s_v_qparams_stride; + + std::vector output(b * s_attn * h * d); + size_t output_b_stride = s_attn * h * d; + size_t output_s_attn_stride = h * d; + size_t output_h_stride = d; + + for (int b_idx = 0; b_idx < b; b_idx++) { + for (int h_idx = 0; h_idx < h; h_idx++) { + kernel( + s_attn, + d, + s_v, + test_case.attn_scores.data() + b_idx * attn_b_stride + + h_idx * attn_h_stride, + attn_s_q_stride /*lhs_stride_m*/, + test_case.v_qvals.data() + b_idx * v_b_stride + h_idx * v_h_stride, + v_s_v_stride /*rhs_stride_n*/, + output.data() + b_idx * output_b_stride + h_idx * output_h_stride, + output_s_attn_stride /*out_stride_n*/, + test_case.v_zeros.data() + b_idx * v_scale_zp_b_stride + + h_idx * v_scale_zp_h_stride, + test_case.v_scales.data() + b_idx * v_scale_zp_b_stride + + h_idx * v_scale_zp_h_stride, + 0.0 /*beta*/, + v_scale_zp_s_stride /*rhs qparams stride*/); + } + } + + for (int i = 0; i < b * s_attn * h * d; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST(test_fp32_attn_scores_at_v_matmul_attention, Basic) { + test_fp32_attn_scores_at_v_matmul_attention(1, 16, 16, 8, 16); +} + +TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeHeadsAndHeadDim) { + test_fp32_attn_scores_at_v_matmul_attention(1, 8, 8, 7, 33); +} + +TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeSequenceDim) { + test_fp32_attn_scores_at_v_matmul_attention(1, 7, 9, 7, 33); +} + +TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeHeadsAndSmallHeadDim) { + test_fp32_attn_scores_at_v_matmul_attention(1, 8, 8, 7, 17); +} + +TEST(test_fp32_attn_scores_at_v_matmul_attention, BasicNoTranspose) { + test_fp32_attn_scores_at_v_matmul_attention(1, 16, 16, 8, 16, false); +} + +TEST( + test_fp32_attn_scores_at_v_matmul_attention, + PrimeHeadsAndSmallHeadDimNoTranspose) { + test_fp32_attn_scores_at_v_matmul_attention(1, 8, 8, 7, 17, false); +} + +TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeSequenceDimNoTranspose) { + test_fp32_attn_scores_at_v_matmul_attention(1, 7, 9, 7, 33, false); +} + #endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h index 9ca86ece76..52fb0851bc 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h @@ -230,6 +230,174 @@ struct channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case { } }; +struct fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case { + int b; + int s_attn; + int s_v; + int h; + int d; + size_t b_attn_stride; + size_t h_attn_stride; + size_t s_attn_stride; + size_t b_v_stride; + size_t h_v_stride; + size_t s_v_stride; + size_t b_v_qparams_stride; + size_t h_v_qparams_stride; + size_t s_v_qparams_stride; + + std::vector expected_output; + + std::vector attn_scores; + + std::vector v; + std::vector v_qvals; + std::vector v_scales; + std::vector v_zeros; + + fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case( + int b_, + int s_attn_, + int s_v_, + int h_, + int d_, + size_t b_attn_stride_, + size_t h_attn_stride_, + size_t s_attn_stride_, + size_t b_v_stride_, + size_t h_v_stride_, + size_t s_v_stride_, + size_t b_v_qparams_stride_, + size_t h_v_qparams_stride_, + size_t s_v_qparams_stride_, + std::vector expected_output_, + std::vector attn_scores_, + std::vector v_, + std::vector v_qvals_, + std::vector v_scales_, + std::vector v_zeros_) + : b(b_), + s_attn(s_attn_), + s_v(s_v_), + h(h_), + d(d_), + b_attn_stride(b_attn_stride_), + h_attn_stride(h_attn_stride_), + s_attn_stride(s_attn_stride_), + b_v_stride(b_v_stride_), + h_v_stride(h_v_stride_), + s_v_stride(s_v_stride_), + b_v_qparams_stride(b_v_qparams_stride_), + h_v_qparams_stride(h_v_qparams_stride_), + s_v_qparams_stride(s_v_qparams_stride_), + expected_output(expected_output_), + attn_scores(attn_scores_), + v(v_), + v_qvals(v_qvals_), + v_scales(v_scales_), + v_zeros(v_zeros_) { + assert(expected_output.size() == b * s_attn * h * d); + assert(attn_scores.size() == b * h * s_attn * s_v); + assert(v.size() == b * h * s_v * d); + assert(v_qvals.size() == b * h * s_v * d); + assert(v_scales.size() == b * h * s_v); + assert(v_zeros.size() == b * h * s_v); + } + + static fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case + generate(int b, int s_attn, int s_v, int h, int d, bool transposed_v = true) { + // Generate activations + auto lhs = get_random_vector(b * h * s_attn * s_v, -1.0, 1.0); + + auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] = + torchao::test_utils::generate_per_token_quantized_tensor( + b * h * s_v, d); + // Above function produces nxk matrix and to produce kxn you need transposed + // = true. we do !rhs_is_transposed becaues when rhs_is_transposed = true + // the shape should be nxk instead of kxn. + + size_t b_attn_stride = h * s_attn * s_v; + size_t h_attn_stride = s_attn * s_v; + size_t s_attn_stride = s_v; + + size_t b_v_stride = h * s_v * d; + size_t h_v_stride = s_v * d; + size_t s_v_stride = d; + + size_t b_v_qparams_stride = h * s_v; + size_t h_v_qparams_stride = s_v; + size_t s_v_qparams_stride = 1; + + if (!transposed_v) { + h_v_stride = d; + s_v_stride = h * d; + + s_v_qparams_stride = h; + h_v_qparams_stride = 1; + } + + // Compute expected output + // Note that while the inputs can be in shape b x h x s_attn x s_v, + // and b x h x s_v x d the output is not in b x h x s_attn x s_v + // but rather b x s_attn x h x d. This is because the output of + // SDPA will normally be in b x h x s_attn x d, but we want to + // avoid any tranposes. Thus just aim to output in b x s_attn x h x d + // This is just for testing purposes. Kernel can actually write output + // in [B, H, S, D] if needed. + std::vector expected_output(b * s_attn * h * d); + size_t b_out_stride = s_attn * h * d; + size_t s_attn_out_stride = h * d; + size_t h_out_stride = d; + + for (int b_idx = 0; b_idx < b; b_idx++) { + for (int s_attn_idx = 0; s_attn_idx < s_attn; s_attn_idx++) { + for (int h_idx = 0; h_idx < h; h_idx++) { + for (int d_idx = 0; d_idx < d; d_idx++) { + float res = 0.0; + for (int s_v_idx = 0; s_v_idx < s_v; s_v_idx++) { + int lhs_idx = b_idx * b_attn_stride + s_attn_idx * s_attn_stride + + h_idx * h_attn_stride + s_v_idx; + int rhs_idx = b_idx * b_v_stride + h_idx * h_v_stride + d_idx + + s_v_idx * s_v_stride; + int rhs_scales_zp_idx = b_idx * b_v_qparams_stride + + h_idx * h_v_qparams_stride + s_v_idx * s_v_qparams_stride; + float rhs_dequant = rhs_scales[rhs_scales_zp_idx] * + (rhs_qvals[rhs_idx] - rhs_zeros[rhs_scales_zp_idx]); + + res += lhs[lhs_idx] * rhs_dequant; + } + expected_output + [b_idx * b_out_stride + s_attn_idx * s_attn_out_stride + + h_idx * h_out_stride + d_idx] = res; + } + } + } + } + + // Return test case + return fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case( + b, + s_attn, + s_v, + h, + d, + b_attn_stride, + h_attn_stride, + s_attn_stride, + b_v_stride, + h_v_stride, + s_v_stride, + b_v_qparams_stride, + h_v_qparams_stride, + s_v_qparams_stride, + expected_output, + lhs, + rhs, + rhs_qvals, + rhs_scales, + rhs_zeros); + } +}; } // namespace torchao #endif // defined(__aarch64__) || defined(__ARM_NEON) From 50d133aa7394aa498ad16f8dad6175aef3a8bdac Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Thu, 3 Apr 2025 09:08:58 -0700 Subject: [PATCH 15/30] add fallback kernel and interface Differential Revision: D71370598 Pull Request resolved: https://github.com/pytorch/ao/pull/2010 --- .../cpu/aarch64/tests/test_qmatmul.cpp | 1 + .../channelwise_8bit_a_channelwise_8bit_b.h | 133 ++++++ .../kernels/cpu/interface/quantized_matmul.h | 88 ++++ .../cpu/interface/test_qmatmul_interface.cpp | 448 ++++++++++++++++++ 4 files changed, 670 insertions(+) create mode 100644 torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h create mode 100644 torchao/experimental/kernels/cpu/interface/quantized_matmul.h create mode 100644 torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp index 05dbf13aac..ff4f915b2d 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp @@ -70,6 +70,7 @@ struct test_channelwise_8bit_channelwise_8bit_b< false, false> { static void Run(int m, int k, int n, int stride = 1) { + // TODO: make use of stride for this kernel auto test_case = torchao::channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case:: generate(m, k, n, a_has_zeros, a_has_zeros, false, false); diff --git a/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h b/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h new file mode 100644 index 0000000000..3b070eb2b3 --- /dev/null +++ b/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h @@ -0,0 +1,133 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +namespace torchao::kernels::cpu::fallback::quantized_matmul { +namespace channelwise_8bit_a_channelwise_8bit_b::internal { + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_tranposed> +struct KernelImpl { + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride); +}; + +template +struct KernelImpl { + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride) { + const int8_t* lhs_qvals = static_cast(lhs); + const int8_t* rhs_qvals = static_cast(rhs); + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int lhs_idx = m_idx * lhs_stride_m + k_idx; + int rhs_idx = k_idx * rhs_stride_n + n_idx; + if (b_transposed) { + rhs_idx = n_idx * rhs_stride_n + k_idx; + } + + float lhs_dequant = lhs_scales[m_idx * lhs_qparams_stride] * + (static_cast(lhs_qvals[lhs_idx]) - + static_cast( + lhs_zero_points[m_idx * lhs_qparams_stride])); + + float rhs_dequant = rhs_scales[n_idx * rhs_qparams_stride] * + (static_cast(rhs_qvals[rhs_idx]) - + static_cast( + rhs_zero_points[n_idx * rhs_qparams_stride])); + + res += lhs_dequant * rhs_dequant; + } + output[m_idx * n + n_idx] = res; + } + } + } +}; + +} // namespace + // channelwise_8bit_a_channelwise_8bit_b::internal +} // namespace torchao::kernels::cpu::fallback::quantized_matmul + +// TODO: Remove all ::kernels. No need for extra namespace. +namespace torchao::kernels::cpu::fallback::quantized_matmul { +namespace channelwise_8bit_a_channelwise_8bit_b { +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +void kernel( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride) { + channelwise_8bit_a_channelwise_8bit_b::internal:: + KernelImpl::run( + m, + n, + k, + lhs, + lhs_stride_m, + rhs, + rhs_stride_n, + output, + out_stride_m, + lhs_zero_points, + rhs_zero_points, + lhs_scales, + rhs_scales, + lhs_qparams_stride, + rhs_qparams_stride); +} +} // namespace channelwise_8bit_a_channelwise_8bit_b +} // namespace torchao::kernels::cpu::fallback::quantized_matmul diff --git a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h new file mode 100644 index 0000000000..01a4c704c5 --- /dev/null +++ b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h @@ -0,0 +1,88 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +#include +#if defined(__aarch64__) || defined(__ARM_NEON) +#include +#include +#include +#endif // defined(__aarch64__) || defined(__ARM_NEON) + +namespace torchao::kernels::cpu::quantized_matmul { + +/* +a_stride_m: stride of a in memory to indiciate how far apart each row is. +b_stride_n: stride of b in memory to indiciate how far apart each row is. +If b is transposed (n x k), then this is how many bytes to skip to get to the +next row. If b is not transposed (k x n), then this is how many bytes to skip to +get to the next row. + +It also returns the stride of a and b, that should be used in the kernel. + +Will need to think of a better way to find the right +ukernel. Perhaps via ukernelconfig + registry?. +*/ +using int8_a_int8_b_channelwise_fp32_c_qmatmul_type = void (*)( + int, + int, + int, + const void*, + int, + const void*, + int, + float*, + int, + const int8_t*, + const int8_t*, + const float*, + const float*, + const int, + const int); + +int8_a_int8_b_channelwise_fp32_c_qmatmul_type +get_int8_a_int8_b_channelwise_qmatmul( + int m, + int n, + int k, + bool a_transposed, + bool b_transposed, + int& a_stride_m, + int& b_stride_n); + +int8_a_int8_b_channelwise_fp32_c_qmatmul_type +get_int8_a_int8_b_channelwise_qmatmul( + int m, + int n, + int k, + bool a_transposed, + bool b_transposed, + int& a_stride_m, + int& b_stride_n) { +#if defined(__aarch64__) || defined(__ARM_NEON) + if (!a_transposed && b_transposed && n >= 8) { + a_stride_m = k; + b_stride_n = k; + return aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot:: + kernel; + } +#endif // defined(__aarch64__) || defined(__ARM_NEON) + assert(!a_transposed); + if (b_transposed) { + a_stride_m = k; + b_stride_n = k; + return torchao::kernels::cpu::fallback::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b::kernel; + } else { + return torchao::kernels::cpu::fallback::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b::kernel; + } +} +} // namespace torchao::kernels::cpu::quantized_matmul diff --git a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp new file mode 100644 index 0000000000..3629f0960b --- /dev/null +++ b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp @@ -0,0 +1,448 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include + +#include +#include + +float kTol = 0.0001; + +// This is unfortunately had to be copied over because code in test_utils.h +// depends on quantization kernels which are only buildable for ARM. +// I would like the testing code in this folder to be independent of the arch. +namespace { +void get_qvals_range(int& qmin, int& qmax, int nbit, bool is_symmetric) { + if (is_symmetric) { + qmin = -(1 << (nbit - 1)) + 1; + qmax = -qmin; + } else { + qmin = -(1 << (nbit - 1)); + qmax = (1 << (nbit - 1)) - 1; + } +} + +void get_scale_and_zero( + float& scale, + int& zero, + float vmin, + float vmax, + int qmin, + int qmax) { + assert(qmin < qmax); + assert(vmin < vmax); + scale = (vmax - vmin) / (qmax - qmin); + zero = qmin - std::round(vmin / scale); +} + +inline std::vector +get_random_vector(int size, float min = -1.0, float max = 1.0) { + assert(min < max); + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto dist = std::bind(std::uniform_real_distribution(min, max), rng); + std::vector res(size); + std::generate(res.begin(), res.end(), std::ref(dist)); + return res; +} + +void quantize( + // Output + int8_t* qvals, + // Inputs + const float* vals, + int size, + float scale, + int8_t zero, + int8_t qmin, + int8_t qmax) { + float invScale = 1.0 / (scale + 1e-16); + int i = 0; + auto curr_rounding_mode = fegetround(); + fesetround(FE_TONEAREST); + for (; i < size; ++i) { + // Quantize remaining elements using scalar code + float val = vals[i]; + float qval_f32 = zero + val * invScale; + int32_t qval_s32 = static_cast(std::nearbyint(qval_f32)); + + // Clip to qmin and qmax + qval_s32 = std::max( + static_cast(qmin), + std::min(qval_s32, static_cast(qmax))); + + // Store the quantized value + qvals[i] = static_cast(qval_s32); + } + fesetround(int(curr_rounding_mode)); +} + +auto generate_per_token_quantized_tensor( + int m, + int n, + bool transposed = false) { + auto activations = get_random_vector(m * n, -1.0, 1.0); + auto activation_qvals = std::vector(m * n, 0); + auto activation_scales = std::vector(m, 0); + auto activation_zeros = std::vector(m, 0); + + // Quantize activations with 8-bit asymmetric + // TODO: replace with generic function that does not use aarch64 + // quantize method after we combine with torchao + int qmin, qmax, zero; + float vmin, vmax, scale; + get_qvals_range(qmin, qmax, /*nbit=*/8, /*is_symmetric=*/false); + for (int m_idx = 0; m_idx < m; m_idx++) { + auto minmax = std::minmax_element( + activations.data() + m_idx * n, activations.data() + (m_idx + 1) * n); + vmin = *minmax.first; + vmax = *minmax.second; + get_scale_and_zero(scale, zero, vmin, vmax, qmin, qmax); + activation_scales[m_idx] = scale; + activation_zeros[m_idx] = zero; + quantize( + /*qvals=*/activation_qvals.data() + m_idx * n, + /*vals=*/activations.data() + m_idx * n, + /*size=*/n, + scale, + zero, + qmin, + qmax); + } + + if (transposed) { + auto activations_t = std::vector(m * n, 0); + auto activation_qvals_t = std::vector(m * n, 0); + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + int activation_idx = m_idx * n + n_idx; + int tranposed_idx = n_idx * m + m_idx; + activations_t[tranposed_idx] = activations[activation_idx]; + activation_qvals_t[tranposed_idx] = activation_qvals[activation_idx]; + } + } + activations = activations_t; + activation_qvals = activation_qvals_t; + } + + return std::make_tuple( + activations, activation_qvals, activation_scales, activation_zeros); +} + +struct channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case { + int m; + int k; + int n; + int stride; + + bool lhs_has_zeros; + bool rhs_has_zeros; + bool lhs_is_transposed; + bool rhs_is_transposed; + + std::vector expected_output; + + std::vector lhs; + std::vector lhs_qvals; + std::vector lhs_scales; + std::vector lhs_zeros; + + std::vector rhs; + std::vector rhs_qvals; + std::vector rhs_scales; + std::vector rhs_zeros; + + channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case( + int m_, + int k_, + int n_, + int stride_, + bool lhs_has_zeros_, + bool rhs_has_zeros_, + bool lhs_is_transposed_, + bool rhs_is_transposed_, + std::vector expected_output_, + std::vector lhs_, + std::vector lhs_qvals_, + std::vector lhs_scales_, + std::vector lhs_zeros_, + std::vector rhs_, + std::vector rhs_qvals_, + std::vector rhs_scales_, + std::vector rhs_zeros_) + : m(m_), + k(k_), + n(n_), + stride(stride_), + lhs_has_zeros(lhs_has_zeros_), + rhs_has_zeros(rhs_has_zeros_), + lhs_is_transposed(lhs_is_transposed_), + rhs_is_transposed(rhs_is_transposed_), + expected_output(expected_output_), + lhs(lhs_), + lhs_qvals(lhs_qvals_), + lhs_scales(lhs_scales_), + lhs_zeros(lhs_zeros_), + rhs(rhs_), + rhs_qvals(rhs_qvals_), + rhs_scales(rhs_scales_), + rhs_zeros(rhs_zeros_) { + assert(expected_output.size() == m * n); + assert(lhs.size() == m * stride * k); + assert(lhs_qvals.size() == m * stride * k); + assert(lhs_scales.size() == m * stride); + assert(lhs_zeros.size() == m * stride); + assert(rhs.size() == n * stride * k); + assert(rhs_qvals.size() == n * stride * k); + assert(rhs_scales.size() == n * stride); + assert(rhs_zeros.size() == n * stride); + } + + static channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case generate( + int m, + int k, + int n, + bool lhs_has_zeros, + bool rhs_has_zeros, + bool lhs_is_transposed, + // rhs_is_transposed means generated b matrix is mxk instead of kxm + bool rhs_is_transposed, + int stride = 1) { + assert(!lhs_is_transposed); + assert(lhs_has_zeros); + assert(rhs_has_zeros); + assert(rhs_is_transposed || stride == 1); + // Generate activations + auto [lhs, lhs_qvals, lhs_scales, lhs_zeros] = + generate_per_token_quantized_tensor(m * stride, k); + + auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] = + generate_per_token_quantized_tensor(n * stride, k, !rhs_is_transposed); + // Above function produces nxk matrix and to produce kxn you need transposed + // = true. we do !rhs_is_transposed becaues when rhs_is_transposed = true + // the shape should be nxk instead of kxn. + + // Compute expected output + std::vector expected_output(m * n); + + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int lhs_idx = m_idx * stride * k + k_idx; + int rhs_idx = k_idx * stride * n + n_idx * stride; + if (rhs_is_transposed) { + rhs_idx = n_idx * stride * k + k_idx; + } + float lhs_dequant = lhs_scales[m_idx * stride] * + (lhs_qvals[lhs_idx] - lhs_zeros[m_idx * stride]); + + float rhs_dequant = rhs_scales[n_idx * stride] * + (rhs_qvals[rhs_idx] - rhs_zeros[n_idx * stride]); + + res += lhs_dequant * rhs_dequant; + } + expected_output[m_idx * n + n_idx] = res; + } + } + + // Return test case + return channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case( + m, + k, + n, + stride, + lhs_has_zeros, + rhs_has_zeros, + lhs_is_transposed, + rhs_is_transposed, + expected_output, + lhs, + lhs_qvals, + lhs_scales, + lhs_zeros, + rhs, + rhs_qvals, + rhs_scales, + rhs_zeros); + } +}; +} // namespace + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +struct test_channelwise_8bit_channelwise_8bit_b { + static void Run(int m, int k, int n); +}; + +template +struct test_channelwise_8bit_channelwise_8bit_b< + a_has_zeros, + b_has_zeros, + false, + true> { + static void Run(int m, int k, int n, int stride = 1) { + auto test_case = + channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case::generate( + m, k, n, a_has_zeros, a_has_zeros, false, true, stride); + + int a_stride_m, b_stride_n; + auto kernel = torchao::kernels::cpu::quantized_matmul:: + get_int8_a_int8_b_channelwise_qmatmul( + m, n, k, false, true, a_stride_m, b_stride_n); + a_stride_m = a_stride_m * stride; + b_stride_n = b_stride_n * stride; + + std::vector output(m * n); + kernel( + m, + n, + k, + test_case.lhs_qvals.data(), + a_stride_m /*lsh_stride_m*/, + test_case.rhs_qvals.data(), + b_stride_n /*rsh_stride_n*/, + output.data(), + n /*out_stride_n*/, + test_case.lhs_zeros.data(), + test_case.rhs_zeros.data(), + test_case.lhs_scales.data(), + test_case.rhs_scales.data(), + stride, /*lhs qparams stride*/ + stride /*rhs qparams stride*/); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } + } +}; + +TEST(test_channelwise_8bit_channelwise_8bit_b, TranposedBWithZeroPoints) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/1, /*k=*/128, /*n=*/16); +} + +TEST(test_channelwise_8bit_channelwise_8bit_b, TranposeBWithZeroPointsLargeM) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/128, /*n=*/16); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizes) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/24); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizes2) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/19); +} + +// Test shapes for which we have to use fallback kernel +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizesFallback) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/5); +} + +// Test shapes for which we have to use fallback kernel +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizesFallback2) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/2, /*n=*/1); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposeBWithZeroPointsLargeMStrided) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/128, /*n=*/16, 5); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizes2Strided) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/19, 16); +} + +// Test shapes for which we have to use fallback kernel +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizesFallbackStrided) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/5, 7); +} + +// Test shapes for which we have to use fallback kernel +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizesFallback2Strided) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/2, /*n=*/1, 32); +} From d741ff00f6284467a1a12967429e50aed32013d4 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Thu, 3 Apr 2025 10:23:32 -0700 Subject: [PATCH 16/30] Add fallback kernel and interface for rhs only quantized matmul Differential Revision: D71370602 Pull Request resolved: https://github.com/pytorch/ao/pull/2011 --- .../matmul/fp32_a_channelwise_8bit_b_fp32_c.h | 50 +++++ .../kernels/cpu/interface/quantized_matmul.h | 70 +++++++ .../cpu/interface/test_qmatmul_interface.cpp | 182 ++++++++++++++++++ 3 files changed, 302 insertions(+) create mode 100644 torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h diff --git a/torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h b/torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h new file mode 100644 index 0000000000..58e2853617 --- /dev/null +++ b/torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h @@ -0,0 +1,50 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +// TODO: Remove all ::kernels. No need for extra namespace. +namespace torchao::kernels::cpu::fallback::quantized_matmul { +namespace fp32_a_input_channelwise_8bit_b_fp32 { +template +void kernel( + int m, + int n, + int k, + const float* lhs, + int lhs_stride_m, + const int8_t* rhs, + int rhs_stride_n, + float* output, + int out_stride_m, + const int8_t* rhs_zero_points, + const float* rhs_scales, + const float beta, + const int rhs_qparams_stride) { + assert(a_transposed == false); + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int lhs_idx = m_idx * lhs_stride_m + k_idx; + int rhs_idx = k_idx * rhs_stride_n + n_idx; + if (b_transposed) { + rhs_idx = n_idx * rhs_stride_n + k_idx; + } + float rhs_dequant = rhs_scales[k_idx * rhs_qparams_stride] * + (static_cast(rhs[rhs_idx]) - + static_cast(rhs_zero_points[k_idx * rhs_qparams_stride])); + + res += lhs[lhs_idx] * rhs_dequant; + } + output[m_idx * n + n_idx] = output[m_idx * n + n_idx] * beta + res; + } + } +} +} // namespace fp32_a_input_channelwise_8bit_b_fp32 +} // namespace torchao::kernels::cpu::fallback::quantized_matmul diff --git a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h index 01a4c704c5..718f7eaad9 100644 --- a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h +++ b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h @@ -9,6 +9,8 @@ #include #include +#include + #if defined(__aarch64__) || defined(__ARM_NEON) #include #include @@ -85,4 +87,72 @@ get_int8_a_int8_b_channelwise_qmatmul( channelwise_8bit_a_channelwise_8bit_b::kernel; } } + +/* +a_stride_m: stride of a in memory to indiciate how far apart each row is. +b_stride_n: stride of b in memory to indiciate how far apart each row is. +If b is transposed (n x k), then this is how many bytes to skip to get to the +next row. If b is not transposed (k x n), then this is how many bytes to skip to +get to the next row. + +It also returns the stride of a and b, that should be used in the kernel. + +Will need to think of a better way to find the right +ukernel. Perhaps via ukernelconfig + registry?. +*/ +using fp32_a_input_channelwise_8bit_b_f32_c_matmul_type = void (*)( + int, + int, + int, + const float*, + int, + const int8_t*, + int, + float*, + int, + const int8_t*, + const float*, + const float, + const int); + +fp32_a_input_channelwise_8bit_b_f32_c_matmul_type +get_fp32_a_input_channelwise_8bit_b_f32_c_matmul( + int m, + int n, + int k, + bool a_transposed, + bool b_transposed, + int& a_stride_m, + int& b_stride_n); + +fp32_a_input_channelwise_8bit_b_f32_c_matmul_type +get_fp32_a_input_channelwise_8bit_b_f32_c_matmul( + int m, + int n, + int k, + bool a_transposed, + bool b_transposed, + int& a_stride_m, + int& b_stride_n) { +#if defined(__aarch64__) || defined(__ARM_NEON) + if (!a_transposed && !b_transposed && n >= 16) { + a_stride_m = k; + b_stride_n = n; + return aarch64::quantized_matmul:: + fp32_a_input_channelwise_8bit_b_1x16x4_f32::kernel; + } +#endif // defined(__aarch64__) || defined(__ARM_NEON) + assert(!a_transposed); + if (b_transposed) { + a_stride_m = k; + b_stride_n = k; + return torchao::kernels::cpu::fallback::quantized_matmul:: + fp32_a_input_channelwise_8bit_b_fp32::kernel; + } else { + a_stride_m = k; + b_stride_n = n; + return torchao::kernels::cpu::fallback::quantized_matmul:: + fp32_a_input_channelwise_8bit_b_fp32::kernel; + } +} } // namespace torchao::kernels::cpu::quantized_matmul diff --git a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp index 3629f0960b..4024f3f1de 100644 --- a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp +++ b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp @@ -446,3 +446,185 @@ TEST( Run( /*m=*/4, /*k=*/2, /*n=*/1, 32); } + +class FP32A_QuantizedB_FP32C_Interface_Test + : public ::testing::TestWithParam { + public: + int m; + int k; + int n; + int stride; + + bool rhs_has_zeros; + bool lhs_is_transposed; + bool rhs_is_transposed; + + std::vector init_output; + std::vector expected_output; + + std::vector lhs; + + std::vector rhs; + std::vector rhs_qvals; + std::vector rhs_scales; + std::vector rhs_zeros; + + void generate( + int m_, + int k_, + int n_, + bool rhs_has_zeros_, + bool lhs_is_transposed_, + bool rhs_is_transposed_, + int stride_ = 1) { + assert(!lhs_is_transposed_); + assert(rhs_has_zeros_); + m = m_; + k = k_; + n = n_; + stride = stride_; + rhs_has_zeros = rhs_has_zeros_; + lhs_is_transposed = lhs_is_transposed_; + rhs_is_transposed = rhs_is_transposed_; + + assert(!rhs_is_transposed || stride == 1); + + // Generate activations + lhs = get_random_vector(m * k, -1.0, 1.0); + + // The strange thing this is doing is that instead of quantizing + // each output channel separately, we are quantizing each input channel + // Reason why we do !rhs_is_transposed is because + // we actually want k x n matrix not n x k matrix + // because each input channel is quantized separately + std::tie(rhs, rhs_qvals, rhs_scales, rhs_zeros) = + generate_per_token_quantized_tensor(k * stride, n, rhs_is_transposed); + + // Compute expected output + init_output = get_random_vector(m * n, -1.0, 1.0); + + assert(init_output.size() == m * n); + assert(lhs.size() == m * k); + assert(rhs.size() == n * stride * k); + assert(rhs_qvals.size() == n * stride * k); + assert(rhs_scales.size() == k * stride); + assert(rhs_zeros.size() == k * stride); + } + + void execute(float beta) { + // Compute expected output + expected_output = init_output; + + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int lhs_idx = m_idx * k + k_idx; + int rhs_idx = k_idx * stride * n + n_idx; + if (rhs_is_transposed) { + rhs_idx = n_idx * k * stride + k_idx * stride; + } + float rhs_dequant = rhs_scales[k_idx * stride] * + (static_cast(rhs_qvals[rhs_idx]) - + static_cast(rhs_zeros[k_idx * stride])); + + res += lhs[lhs_idx] * rhs_dequant; + } + expected_output[m_idx * n + n_idx] = + expected_output[m_idx * n + n_idx] * beta + res; + } + } + } + + float beta() const { + return GetParam(); + } +}; + +static void test_fp32_a_input_channelwise_8bit_b( + int m, + int k, + int n, + float beta, + FP32A_QuantizedB_FP32C_Interface_Test& test_case, + int stride = 1) { + test_case.execute(beta); + + int a_stride_m, b_stride_n; + auto kernel = torchao::kernels::cpu::quantized_matmul:: + get_fp32_a_input_channelwise_8bit_b_f32_c_matmul( + m, n, k, false, false, a_stride_m, b_stride_n); + b_stride_n = b_stride_n * stride; + + std::vector output(test_case.init_output); + kernel( + m, + n, + k, + test_case.lhs.data(), + a_stride_m /*lhs_stride_m*/, + test_case.rhs_qvals.data(), + b_stride_n /*rhs_stride_n*/, + output.data(), + n /*out_stride_n*/, + test_case.rhs_zeros.data(), + test_case.rhs_scales.data(), + beta, + stride /*rhs qparams stride*/); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST_P(FP32A_QuantizedB_FP32C_Interface_Test, BTranposedWithZeroPoints) { + generate(3, 128, 16, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/3, /*k=*/128, /*n=*/16, beta(), *this); +} + +TEST_P( + FP32A_QuantizedB_FP32C_Interface_Test, + BTranposedWithZeroPointsOddSizes) { + generate(4, 37, 19, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/37, /*n=*/19, beta(), *this); +} + +// Test shapes for which we have to use fallback kernel +TEST_P( + FP32A_QuantizedB_FP32C_Interface_Test, + BTranposedWithZeroPointsOddSizesFallback) { + generate(4, 37, 3, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/37, /*n=*/3, beta(), *this); +} + +TEST_P( + FP32A_QuantizedB_FP32C_Interface_Test, + BTranposedWithZeroPointsOddSizes2Fallback) { + generate(4, 1, 3, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/1, /*n=*/3, beta(), *this); +} + +TEST_P( + FP32A_QuantizedB_FP32C_Interface_Test, + BTranposedWithZeroPointsOddSizesStrided) { + generate(4, 37, 19, true, false, false, 32); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/37, /*n=*/19, beta(), *this, 32); +} + +TEST_P( + FP32A_QuantizedB_FP32C_Interface_Test, + BTranposedWithZeroPointsOddSizes2FallbackStrided) { + generate(4, 5, 3, true, false, false, 32); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/5, /*n=*/3, beta(), *this, 32); +} + +INSTANTIATE_TEST_SUITE_P( + F32AInt8BFP32CTest, + FP32A_QuantizedB_FP32C_Interface_Test, + ::testing::Values(0.0, 1.0, 3.1)); From e190329c397164486170a456ac278c2e74daad04 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Thu, 3 Apr 2025 11:16:06 -0700 Subject: [PATCH 17/30] Add KleidiAI gemm kernels (#2000) Add KleidiAI gemm kernels (#2000) Summary: This PR pulls in two new KleidiAI kernels: * kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod (GEMV) * kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod (GEMM) and adds them for automatic mr-based kernel selection when TORCHAO_ENABLE_ARM_NEON_DOT is set. It also adds new tests for these kernels, and refactors the kleidiai testing code so that in future new kleidiai kernels can be tested with a one line addition: ``` TEST( test_linear_8bit_act_xbit_weight, matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod) { test_linear_8bit_act_xbit_weight_kleidiai< matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod>(); } ``` The exisitng testing code (still exists for more coverage) depended on code generation. Reviewed By: Jack-Khuu Differential Revision: D72179835 --- .../workflows/torchao_experimental_test.yml | 2 +- torchao/experimental/CMakeLists.txt | 1 + .../kernels/cpu/aarch64/CMakeLists.txt | 2 +- .../kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h | 11 ++- .../kernels/cpu/aarch64/tests/CMakeLists.txt | 1 + .../kernel_selector.h | 65 ++++++++++----- torchao/experimental/ops/tests/CMakeLists.txt | 1 + .../test_linear_8bit_act_xbit_weight.cpp | 82 ++++++++++++++++++- ...est_int8_dynamic_activation_intx_weight.py | 70 +++++++++++++++- 9 files changed, 207 insertions(+), 28 deletions(-) diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml index 0cb470901e..2187eed8e3 100644 --- a/.github/workflows/torchao_experimental_test.yml +++ b/.github/workflows/torchao_experimental_test.yml @@ -38,7 +38,7 @@ jobs: pip install executorch pip install torch==2.7.0.dev20250311 --index-url "https://download.pytorch.org/whl/nightly/cpu" --force-reinstall pip install -r dev-requirements.txt - USE_CPP=1 TOCHAO_BUILD_KLEIDIAI=1 pip install . + USE_CPP=1 TORCHAO_BUILD_KLEIDIAI=1 pip install . - name: Run python tests run: | conda activate venv diff --git a/torchao/experimental/CMakeLists.txt b/torchao/experimental/CMakeLists.txt index f05e6b392f..e6b2a6aff0 100644 --- a/torchao/experimental/CMakeLists.txt +++ b/torchao/experimental/CMakeLists.txt @@ -40,6 +40,7 @@ include_directories(${TORCHAO_INCLUDE_DIRS}) if(TORCHAO_BUILD_CPU_AARCH64) message(STATUS "Building with cpu/aarch64") add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64) + add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) # Defines torchao_kernels_aarch64 add_subdirectory(kernels/cpu/aarch64) diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt index 3cca338cbf..f38794d4a8 100644 --- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt @@ -19,7 +19,7 @@ if (TORCHAO_BUILD_CPU_AARCH64) # intelligence (AI) workloads tailored for Arm® CPUs. FetchContent_Declare(kleidiai GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git - GIT_TAG v1.2.0) + GIT_TAG v1.5.0) FetchContent_MakeAvailable(kleidiai) target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai) diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h index 2a8e668fa7..aa338fc165 100644 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h @@ -14,9 +14,14 @@ #include #include +#include + +#ifdef TORCHAO_ENABLE_ARM_NEON_DOT +#include #include #include -#include +#include +#endif // TORCHAO_ENABLE_ARM_NEON_DOT #ifdef TORCHAO_ENABLE_ARM_I8MM #include @@ -297,10 +302,14 @@ size_t get_preferred_alignement() { } \ } +#ifdef TORCHAO_ENABLE_ARM_NEON_DOT DEFINE_KERNEL_STRUCT( matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod); DEFINE_KERNEL_STRUCT( matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod); +DEFINE_KERNEL_STRUCT(matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod); +DEFINE_KERNEL_STRUCT(matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod); +#endif // TORCHAO_ENABLE_ARM_NEON_DOT #ifdef TORCHAO_ENABLE_ARM_I8MM DEFINE_KERNEL_STRUCT(matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm); diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt index a01afac68f..db736d84a3 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt @@ -42,6 +42,7 @@ add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 $ if(TORCHAO_BUILD_KLEIDIAI) add_compile_definitions(TORCHAO_ENABLE_KLEIDI) + add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) endif() if(TORCHAO_BUILD_ARM_I8MM) diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h index 719c2e01e4..ffdd62f7a7 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h @@ -8,18 +8,19 @@ #include #include #include - -#if defined(TORCHAO_BUILD_CPU_AARCH64) -#include -#endif // TORCHAO_BUILD_CPU_AARCH64 - #include #include #include +#if defined(TORCHAO_BUILD_CPU_AARCH64) +#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) +#include +#endif // TORCHAO_ENABLE_ARM_NEON_DOT + #if defined(TORCHAO_ENABLE_KLEIDI) #include #endif // TORCHAO_ENABLE_KLEIDI +#endif // TORCHAO_BUILD_CPU_AARCH64 namespace torchao::ops::linear_8bit_act_xbit_weight { @@ -110,7 +111,7 @@ void register_ukernel_config_universal( constexpr int mr = 1; constexpr int m_step = 1; -#if defined(TORCHAO_BUILD_CPU_AARCH64) +#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) if (cpuinfo_has_arm_neon_dot()) { log_registration(format, "universal: kernel_1x8x16_f32_neondot"); auto uk = UKernelConfig::make( @@ -159,7 +160,7 @@ void register_ukernel_config_universal( return; } } -#endif // TORCHAO_BUILD_CPU_AARCH64 +#endif // TORCHAO_ENABLE_ARM_NEON_DOT } } @@ -213,18 +214,24 @@ void register_ukernel_config_kleidi( #if defined(TORCHAO_ENABLE_ARM_I8MM) if (cpuinfo_has_arm_i8mm()) { - /*m_step=4*/ + log_registration( + format, + "kleidiai: matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm"); + /*m_step=1*/ uk.linear_configs[0] = get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod>( + uk.n_step, uk.nr, uk.kr, uk.sr); + + /*m_step=4*/ + uk.linear_configs[1] = get_linear_config_kleidi< op::matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm>( uk.n_step, uk.nr, uk.kr, uk.sr); - log_registration( - format, - "kleidiai: matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm"); table.register_ukernel_config(format, uarch, std::move(uk)); return; } #endif // TORCHAO_ENABLE_ARM_I8MM +#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) if (cpuinfo_has_arm_neon_dot()) { log_registration( format, @@ -236,22 +243,27 @@ void register_ukernel_config_kleidi( table.register_ukernel_config(format, uarch, std::move(uk)); return; } +#endif // TORCHAO_ENABLE_ARM_NEON_DOT } - if (format.nr == 4 && format.kr == 16 && format.sr == 2) { - uk.n_step = 4; + if (format.nr == 8 && format.kr == 8 && format.sr == 2) { +#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) if (cpuinfo_has_arm_neon_dot()) { - /*m_step=1*/ - uk.linear_configs[0] = get_linear_config_kleidi< - op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod>( - uk.n_step, uk.nr, uk.kr, uk.sr); - log_registration( format, - "kleidiai: matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod"); + "kleidiai: matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod, matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod"); + // m_step 1 + uk.linear_configs[0] = get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod>( + uk.n_step, uk.nr, uk.kr, uk.sr); + // m_step 4 + uk.linear_configs[1] = get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod>( + uk.n_step, uk.nr, uk.kr, uk.sr); table.register_ukernel_config(format, uarch, std::move(uk)); return; } +#endif // TORCHAO_ENABLE_ARM_NEON_DOT } } #endif // TORCHAO_ENABLE_KLEIDI @@ -325,8 +337,7 @@ PackedWeightsFormat select_packed_weights_format( #if defined(TORCHAO_ENABLE_KLEIDI) if (!target || *target == "kleidiai") { if (weight_nbit == 4 && (!has_weight_zeros)) { - // KleidiAI will pack bias with weights always, - // even if bias is not provided 0s will be packed +#if defined(TORCHAO_ENABLE_ARM_I8MM) return PackedWeightsFormat( torchao::ops::PackedWeightsType::kleidi_ai, weight_nbit, @@ -335,12 +346,23 @@ PackedWeightsFormat select_packed_weights_format( /*nr*/ 8, /*kr*/ 16, /*sr*/ 2); +#elif defined(TORCHAO_ENABLE_ARM_NEON_DOT) + return PackedWeightsFormat( + torchao::ops::PackedWeightsType::kleidi_ai, + weight_nbit, + has_weight_zeros, + has_bias, + /*nr*/ 8, + /*kr*/ 8, + /*sr*/ 2); +#endif } } #endif // defined(TORCHAO_ENABLE_KLEIDI) // Select universal format if (!target || *target == "universal") { +#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) return PackedWeightsFormat( torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal, weight_nbit, @@ -349,6 +371,7 @@ PackedWeightsFormat select_packed_weights_format( /*nr*/ 8, /*kr*/ 16, /*sr*/ 2); +#endif // defined(TORCHAO_ENABLE_ARM_NEON_DOT) } throw std::runtime_error("No packed_weights_format was selected"); diff --git a/torchao/experimental/ops/tests/CMakeLists.txt b/torchao/experimental/ops/tests/CMakeLists.txt index 8a9ad08f23..8245fdd746 100644 --- a/torchao/experimental/ops/tests/CMakeLists.txt +++ b/torchao/experimental/ops/tests/CMakeLists.txt @@ -24,6 +24,7 @@ enable_testing() if(TORCHAO_BUILD_CPU_AARCH64) add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64=1) + add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) endif() if(TORCHAO_BUILD_KLEIDIAI) diff --git a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp index 980228a1a8..1d4127a43e 100644 --- a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp +++ b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp @@ -21,7 +21,7 @@ using namespace torchao::kernels::cpu::aarch64::kleidi:: #endif // TORCHAO_ENABLE_KLEIDI const float kTol = 1.0e-5; -const float kTolKleidiAI = 1.0e-2; +const float kTolKleidiAI = 5.0e-2; using namespace torchao::ops::linear_8bit_act_xbit_weight; @@ -208,6 +208,86 @@ UKernelConfig get_ukernel_config_kleidi_impl() { return ukernel_config; } +template +void test_linear_8bit_act_xbit_weight_kleidiai() { + constexpr int weight_nbit = 4; + constexpr bool has_kleidi = true; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = true; + auto uk = get_ukernel_config_kleidi_impl(); + + for (auto m : {1, 3, 4, 8, 9, 13, 21, 43, 101}) { + for (auto n : + {1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 4 * 13, + 4 * 13 + 3, + 8 * 13, + 8 * 13 + 3, + 16 * 13, + 16 * 13 + 3}) { + for (auto k : {32, 64, 128}) { + int group_size = 32; + test_linear_8bit_act_xbit_weight< + weight_nbit, + has_weight_zeros, + has_bias, + /*has_clamp*/ true, + has_kleidi>(m, n, k, group_size, &uk); + test_linear_8bit_act_xbit_weight< + weight_nbit, + has_weight_zeros, + has_bias, + /*has_clamp*/ false, + has_kleidi>(m, n, k, group_size, &uk); + + if (k >= 64) { + group_size = 64; + test_linear_8bit_act_xbit_weight< + weight_nbit, + has_weight_zeros, + has_bias, + /*has_clamp*/ true, + has_kleidi>(m, n, k, group_size, &uk); + } + } + } + } +} + +#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) +TEST( + test_linear_8bit_act_xbit_weight, + matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod) { + test_linear_8bit_act_xbit_weight_kleidiai< + matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod>(); +} +TEST( + test_linear_8bit_act_xbit_weight, + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod) { + test_linear_8bit_act_xbit_weight_kleidiai< + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod>(); +} +TEST( + test_linear_8bit_act_xbit_weight, + matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod) { + test_linear_8bit_act_xbit_weight_kleidiai< + matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod>(); +} +TEST( + test_linear_8bit_act_xbit_weight, + matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod) { + test_linear_8bit_act_xbit_weight_kleidiai< + matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod>(); +} +#endif // TORCHAO_ENABLE_ARM_NEON_DOT + template UKernelConfig get_ukernel_config_kleidi() { #if defined(TORCHAO_ENABLE_ARM_I8MM) diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index 098fc09696..dcd8eb74d5 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -105,6 +105,58 @@ def test_accuracy(self, layout, weight_dtype, has_weight_zeros, granularity): expected_result = quantized_model_reference(activations) self._assert_close(result, expected_result) + def test_accuracy_kleidiai(self): + n = 1071 + k = 2048 + model = torch.nn.Sequential( + *[torch.nn.Linear(k, k, bias=False), torch.nn.Linear(k, n, bias=True)] + ) + weight_dtype = torch.int4 + granularity = PerGroup(128) + has_weight_zeros = False + + # We set round_weight_scale_to_bf16 to True for accuracy testing because + # some KleidiAI kernels do this internally + round_weight_scale_to_bf16 = True + + quantized_model = copy.deepcopy(model) + quantize_( + quantized_model, + int8_dynamic_activation_intx_weight( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + layout=PackedLinearInt8DynamicActivationIntxWeightLayout( + target="kleidiai" + ), + round_weight_scale_to_bf16=round_weight_scale_to_bf16, + ), + ) + + quantized_model_reference = copy.deepcopy(model) + quantize_( + quantized_model_reference, + int8_dynamic_activation_intx_weight( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + layout=self._reference_layout(), + round_weight_scale_to_bf16=round_weight_scale_to_bf16, + ), + ) + + with torch.no_grad(): + for m in [1, 3, 5, 9, 13]: + activations = torch.randn(m, k) + result = quantized_model(activations) + expected_result = quantized_model_reference(activations) + + # KleidiAI kernels require much higher tolerance when comparing to reference, + # especially for GEMM kernels + self._assert_close( + result, expected_result, mse_tol=1e-2, atol=1e-2, rtol=1 + ) + def test_accuracy_aten(self): m = 3 n = 1024 @@ -151,9 +203,21 @@ def test_accuracy_aten(self): self._assert_close(result, expected_result) - def _assert_close(self, result, expected_result): - self.assertTrue(torch.nn.functional.mse_loss(result, expected_result) <= 1e-6) - self.assertTrue(torch.allclose(result, expected_result, atol=1e-2)) + def _assert_close( + self, result, expected_result, mse_tol=1e-6, atol=1e-2, rtol=1e-5 + ): + mse_loss = torch.nn.functional.mse_loss(result, expected_result) + self.assertTrue( + mse_loss <= mse_tol, + f"Got mse_loss={mse_loss}, above mse tolerance {mse_tol}", + ) + + n_rand_idxs = 5 + rand_idxs = torch.randint(0, result.numel(), (n_rand_idxs,)) + self.assertTrue( + torch.allclose(result, expected_result, atol=atol, rtol=rtol), + f"Failed allclose at atol={atol}, rtol={rtol}. On {n_rand_idxs} random indices, we have result={result.reshape(-1)[rand_idxs]} vs expected_result={expected_result.reshape(-1)[rand_idxs]}.", + ) def _reference_layout(self): return PlainLayout() From 540951589431ca5ae044925cfcea707e63b6bee4 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 3 Apr 2025 17:00:59 -0700 Subject: [PATCH 18/30] Update float8nocompile test code to use new float8 matmul function (#2013) --- .../prototype/float8nocompile/float8nocompile_linear_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/float8nocompile/float8nocompile_linear_test.py b/torchao/prototype/float8nocompile/float8nocompile_linear_test.py index f62569cbb4..7df5ce768c 100644 --- a/torchao/prototype/float8nocompile/float8nocompile_linear_test.py +++ b/torchao/prototype/float8nocompile/float8nocompile_linear_test.py @@ -7,7 +7,7 @@ import torch from torchao.float8.config import Float8LinearConfig -from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_hp +from torchao.float8.float8_linear import matmul_with_hp_or_float8_args from torchao.float8.float8_tensor import LinearMMConfig, ScaledMMConfig from torchao.prototype.float8nocompile.float8nocompile_linear import ( matmul_with_args_in_hp, @@ -72,7 +72,7 @@ def test_matmul_with_args_in_hp(input_shape: tuple[int, int]): ) # prod forward. expects transposed weight. - out_prod = manual_float8_matmul_with_args_in_hp.apply( + out_prod = matmul_with_hp_or_float8_args.apply( prod_input_bf16, prod_weight_bf16.t(), linear_mm_config, config ) From 916f9d75eccb5c0d69c9e096e4a12919ed45060b Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 3 Apr 2025 20:41:37 -0700 Subject: [PATCH 19/30] Remove float8nocompile CI (#1976) remove float8nocompmile CI since it's flaky on sm89 --- .github/workflows/float8nocompile_test.yaml | 53 --------------------- 1 file changed, 53 deletions(-) delete mode 100644 .github/workflows/float8nocompile_test.yaml diff --git a/.github/workflows/float8nocompile_test.yaml b/.github/workflows/float8nocompile_test.yaml deleted file mode 100644 index b8707c148e..0000000000 --- a/.github/workflows/float8nocompile_test.yaml +++ /dev/null @@ -1,53 +0,0 @@ -name: Run Float8nocompile Tests - -on: - push: - branches: - - main - - 'gh/**' - paths: - - 'torchao/prototype/float8nocompile/**' - pull_request: - branches: - - main - - 'gh/**' - paths: - - 'torchao/prototype/float8nocompile/**' - -concurrency: - group: floatnocompile_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} - cancel-in-progress: true - -env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - -# jobs: -# test: -# strategy: -# fail-fast: false -# matrix: -# include: -# - name: H100 -# runs-on: linux.aws.h100 -# torch-spec: '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124' -# gpu-arch-type: "cuda" -# gpu-arch-version: "12.4" - -# uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main -# with: -# timeout: 300 -# runner: ${{ matrix.runs-on }} -# gpu-arch-type: ${{ matrix.gpu-arch-type }} -# gpu-arch-version: ${{ matrix.gpu-arch-version }} -# submodules: recursive -# script: | -# conda create -n venv python=3.9 -y -# conda activate venv -# export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH -# python -m pip install --upgrade pip -# pip install ${{ matrix.torch-spec }} -# pip install -r dev-requirements.txt -# pip install . -# cd torchao/prototype/float8nocompile -# pytest kernels/ --verbose -s -# pytest test/train_test.py --verbose -s From 0436d359af491aba4d5e52ddbf9c4c9382f385a3 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 4 Apr 2025 06:20:39 -0700 Subject: [PATCH 20/30] Update clean_release_notes.py (#2014) --- scripts/clean_release_notes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/clean_release_notes.py b/scripts/clean_release_notes.py index 2caef0735b..92ce5996cc 100644 --- a/scripts/clean_release_notes.py +++ b/scripts/clean_release_notes.py @@ -223,7 +223,7 @@ def format_commit(commit_line: str) -> str: After: * Commit title (https://github.com/pytorch/ao/pull/123) """ # Remove author, put PR link in parentheses - commit_line = re.sub(" by @.* in (.*)", r" (\\g<1>)", commit_line) + commit_line = re.sub(" by @.* in (.*)", r" (\g<1>)", commit_line) # Capitalize first letter commit_line = commit_line.lstrip("* ") commit_line = "* " + commit_line[0].upper() + commit_line[1:] From 8ae4b6ab5bb009135ef2a374b9bb7aa8fd49af5b Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 4 Apr 2025 10:56:28 -0400 Subject: [PATCH 21/30] Match QAT prepare and convert numerics exactly (#1964) **Summary:** Previously, `Int8DynActInt4QATQuantizer` had slightly diverging numerics between the prepare and convert steps. This is because the prepare step uses quantization primitives shared with AQT (specifically `quantize_affine` and `dequantize_affine`), while the convert step relies on old ops from the `torch.ops.quantized_decomposed` namespace. The diverging numerics is negligible for small models, but the quantization errors begin to compound for larger models with many linear layers. More specifically, there are three different places where the divergence occurs during activation quantization: 1. **Choose qparams.** The prepare step casts the qparams to `torch.float32`, whereas the convert step casts the scales to `torch.float64` and zero points to `torch.int64`. 2. **Quantize.** The prepare step performs round before adding zero points and uses torch functions, while the convert step adds before rounding and uses torch tensor methods. ``` x = torch.clamp( torch.round(x * (1.0 / scale)) + zero_point, qmin, qmax, ) x = ( x.mul(1.0 / scale) .add(zero_point) .round() .clamp(qmin, qmax) .to(quantize_dtype) ) ``` 3. **Dequantize.** The prepare step casts to `torch.int32` before adding the zero points, and casts back to the original dtype before multiplying the scale. The convert step only casts at the very end. ``` x = x.to(torch.int32) - zero_point.to(torch.int32) x = x.to(orig_dtype) x = x * scale x = x - zero_point x = x * scale x = x.to(orig_dtype) ``` This commit makes the convert path use the same torchao quantization primitives as the prepare path, thereby resolving the 3 above differences. Now, the prepare and convert steps match exactly in terms of numerics over many trials. **Test Plan:** python test/quantization/test_qat.py -k test_fake_quantize_per_token_vs_convert python test/quantization/test_qat.py -k test_qat_8da4w_prepare_vs_convert --- test/quantization/test_qat.py | 71 +++++++++++++++++++++++++++++++++++ torchao/_executorch_ops.py | 2 + torchao/quantization/GPTQ.py | 17 +++++---- torchao/quantization/utils.py | 60 +++++++++++++++++------------ 4 files changed, 118 insertions(+), 32 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 3c29028898..fcd4969bbf 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -133,6 +133,18 @@ def forward(self, x): return x +class M4(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(512, 256, bias=False).to(torch.float) + + def example_inputs(self): + return (torch.randn(1, 512).to(torch.float),) + + def forward(self, x): + return self.linear(x) + + class ModelWithLinearBias(torch.nn.Module): def __init__(self): super().__init__() @@ -1389,6 +1401,65 @@ def test_qat_linear_bias(self): example_inputs = m.example_inputs() m(*example_inputs) + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) + def test_fake_quantize_per_token_vs_convert(self): + """ + Test that the following produce the exact same numerics: + 1. FakeQuantizer with asymmetric per_token config + 2. torchao.quantization.utils.per_token_dynamic_quant + """ + from torchao.quantization.utils import per_token_dynamic_quant + + torch.manual_seed(self.SEED) + x = torch.randn(1, 235, 2048) + config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) + fake_quantizer = FakeQuantizer(config) + fake_quantizer_out = fake_quantizer(x) + baseline_out = per_token_dynamic_quant(x) + torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0) + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) + def test_qat_8da4w_prepare_vs_convert(self): + """ + Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces + numerics that match exactly over N trials. + """ + from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer + from torchao.quantization.utils import compute_error + + num_trials = 1000 + group_size = 16 + non_inf_sqnr = [] + + for seed in range(self.SEED, self.SEED + num_trials): + torch.manual_seed(seed) + m = M4() + torch.manual_seed(seed) + x = m.example_inputs() + + quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) + prepared = quantizer.prepare(m) + prepared_out = prepared(*x) + converted = quantizer.convert(prepared) + converted_out = converted(*x) + sqnr = compute_error(prepared_out, converted_out).item() + if sqnr != float("inf"): + non_inf_sqnr.append(sqnr) + + avg_sqnr = ( + sum(non_inf_sqnr) / len(non_inf_sqnr) if len(non_inf_sqnr) > 0 else -1 + ) + fail_message = "%s/%s trials did not match exactly, average sqnr = %s" % ( + len(non_inf_sqnr), + num_trials, + avg_sqnr, + ) + self.assertEqual(len(non_inf_sqnr), 0, fail_message) + if __name__ == "__main__": unittest.main() diff --git a/torchao/_executorch_ops.py b/torchao/_executorch_ops.py index 29339bba8c..4b761ad725 100644 --- a/torchao/_executorch_ops.py +++ b/torchao/_executorch_ops.py @@ -5,6 +5,8 @@ # LICENSE file in the root directory of this source tree. import torch +# TODO: delete these ops + def _quantized_decomposed_quantize_per_channel_group_wrapper(*args, **kwargs): """ diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index 6c63937051..63b1da440d 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -24,7 +24,10 @@ find_multiple, ) -from .quant_primitives import MappingType +from .quant_primitives import ( + MappingType, + dequantize_affine, +) from .unified import Quantizer from .utils import ( _MultiInput, @@ -940,19 +943,17 @@ def linear_forward_8da4w( n_bit = 4 quant_min = -(2 ** (n_bit - 1)) quant_max = 2 ** (n_bit - 1) - 1 - from torchao._executorch_ops import ( - _quantized_decomposed_dequantize_per_channel_group_wrapper, - ) + block_size = (1, groupsize) - w_dq = _quantized_decomposed_dequantize_per_channel_group_wrapper( + w_dq = dequantize_affine( weight_int8, + block_size, scales, zeros, + torch.int8, quant_min, quant_max, - torch.int8, - groupsize, - precision, + output_dtype=precision, ) # x = x.to(torch.float16) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 74c136ad00..b23f39c6d7 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -539,36 +539,48 @@ def group_quantize_tensor_symmetric( return w_int8, scales, zeros -def per_token_dynamic_quant(input: torch.Tensor) -> torch.Tensor: - orig_dtype = input.dtype - # TODO: we may need to make the choose_qparams op configurable - from torchao._executorch_ops import ( - _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper, - ) - - ( - scales, - zero_points, - ) = _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper( - input, torch.int8 - ) - - # TODO: get these from torch.int8 +def per_token_dynamic_quant( + input: torch.Tensor, + scale_dtype: torch.dtype = torch.float32, + zero_point_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + mapping_type = MappingType.ASYMMETRIC + block_size = _get_per_token_block_size(input) quant_min = -128 quant_max = 127 - from torchao._executorch_ops import _quantized_decomposed_quantize_per_token_wrapper + quant_dtype = torch.int8 + output_dtype = input.dtype - input = _quantized_decomposed_quantize_per_token_wrapper( - input, scales, zero_points, quant_min, quant_max, torch.int8 + scales, zero_points = choose_qparams_affine( + input, + mapping_type, + block_size, + quant_dtype, + quant_min, + quant_max, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, ) - from torchao._executorch_ops import ( - _quantized_decomposed_dequantize_per_token_wrapper, + q = quantize_affine( + input, + block_size, + scales, + zero_points, + quant_dtype, + quant_min, + quant_max, ) - - input = _quantized_decomposed_dequantize_per_token_wrapper( - input, scales, zero_points, quant_min, quant_max, torch.int8, orig_dtype + dq = dequantize_affine( + q, + block_size, + scales, + zero_points, + quant_dtype, + quant_min, + quant_max, + output_dtype=output_dtype, ) - return input.to(orig_dtype) + return dq def recommended_inductor_config_setter(): From 90bff95ca991170730541de49f201c910ab17e7a Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Fri, 4 Apr 2025 12:38:38 -0700 Subject: [PATCH 22/30] Skip failing tests for rowwise-scaled (#2022) stack-info: PR: https://github.com/pytorch/ao/pull/2022, branch: drisspg/stack/46 --- test/test_ops_rowwise_scaled_linear_cutlass.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_ops_rowwise_scaled_linear_cutlass.py b/test/test_ops_rowwise_scaled_linear_cutlass.py index f9b9c6a7f9..72bb201b3f 100644 --- a/test/test_ops_rowwise_scaled_linear_cutlass.py +++ b/test/test_ops_rowwise_scaled_linear_cutlass.py @@ -16,6 +16,7 @@ _int4_symm_cutlass_quant, _int8_symm_cutlass_quant, ) +from torchao.testing.utils import get_compute_capability DTYPES = [torch.float16, torch.bfloat16] BATCH_SIZE = [1, 4, 8, 16, 32, 64] @@ -87,6 +88,7 @@ def run_test_for_op(op, dtype, batch_size, size_mnk, use_bias): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(get_compute_capability() != 8.0, reason="Only supported on A100") @pytest.mark.parametrize("dtype, batch_size, size_mnk, use_bias", TEST_PARAMS) def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias): run_test_for_op( @@ -99,6 +101,7 @@ def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bia @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(get_compute_capability() != 8.0, reason="Only supported on A100") @pytest.mark.parametrize("dtype, batch_size, size_mnk, use_bias", TEST_PARAMS) def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias): run_test_for_op( From 711d58435ca172c454cbe93b18b1e70a9099c06e Mon Sep 17 00:00:00 2001 From: Lisa Jin Date: Fri, 4 Apr 2025 17:42:02 -0400 Subject: [PATCH 23/30] Update torchao.prototype.parq and add 4-bit Llama 3.2 1B benchmark (#2017) Replace torchao.prototype.parq with facebookresearch/parq submodule --- torchao/prototype/parq/__init__.py | 6 + torchao/prototype/parq/optim/__init__.py | 6 + torchao/prototype/parq/optim/binarelax.py | 2 +- torchao/prototype/parq/optim/parq.py | 21 ++- torchao/prototype/parq/optim/proxmap.py | 1 + torchao/prototype/parq/optim/quantopt.py | 122 +++++++++----- torchao/prototype/parq/quant/__init__.py | 12 +- torchao/prototype/parq/quant/lsbq.py | 126 +++++++++++--- torchao/prototype/parq/quant/quantizer.py | 5 + torchao/prototype/parq/quant/uniform.py | 196 +++++++++++++++++++--- torchao/prototype/parq/utils.py | 14 +- 11 files changed, 420 insertions(+), 91 deletions(-) diff --git a/torchao/prototype/parq/__init__.py b/torchao/prototype/parq/__init__.py index 07353d2461..2239139495 100644 --- a/torchao/prototype/parq/__init__.py +++ b/torchao/prototype/parq/__init__.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + from .optim import ( # noqa: F401 ProxBinaryRelax, ProxHardQuant, diff --git a/torchao/prototype/parq/optim/__init__.py b/torchao/prototype/parq/optim/__init__.py index 627bedb4dd..237a058a12 100644 --- a/torchao/prototype/parq/optim/__init__.py +++ b/torchao/prototype/parq/optim/__init__.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + from .binarelax import ProxBinaryRelax # noqa: F401 from .parq import ProxPARQ # noqa: F401 from .proxmap import ProxHardQuant, ProxMap # noqa: F401 diff --git a/torchao/prototype/parq/optim/binarelax.py b/torchao/prototype/parq/optim/binarelax.py index 0ce88d8ccb..2cc6611f6b 100644 --- a/torchao/prototype/parq/optim/binarelax.py +++ b/torchao/prototype/parq/optim/binarelax.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. + from typing import Optional import torch @@ -41,7 +42,6 @@ def apply_( if step_count >= self.anneal_end: p.copy_(q) - return else: # linear annealing of relaxation coefficient theta = (step_count - self.anneal_start) / ( diff --git a/torchao/prototype/parq/optim/parq.py b/torchao/prototype/parq/optim/parq.py index b756efd3ec..ade403a87d 100644 --- a/torchao/prototype/parq/optim/parq.py +++ b/torchao/prototype/parq/optim/parq.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. + import math from functools import partial from typing import Optional @@ -23,14 +24,16 @@ def amp_custom_fwd(cast_inputs: Optional[torch.types._dtype] = None): return partial(torch.cuda.amp.custom_fwd, cast_inputs=cast_inputs) -def normalized_mirror_sigmoid(t: float, t1: float, t2: float, s: float) -> float: +def normalized_mirror_sigmoid( + t: float, t1: float, t2: float, s: float, c: float +) -> float: """Sigmoid-like function decreasing from 1 to 0 over interval [t1, t2). s is steepness of the sigmoid-like function, almost linear for s < 1. 'mirror' means decreasing instead of increasing as true sigmoid, 'normalized' means value 1 at starting point t1 and 0 at end point t2.""" assert t >= t1 and t < t2, "Normalized sigmoid: ensure t1 <= t < t2" ft = (t - t1) / (t2 - t1) # fraction of progress from t1 to t2 - st = 1 / (1 + math.exp(s * (ft - 0.5))) # scaled and shifted mirror sigmoid + st = 1 / (1 + math.exp(s * (ft - c))) # scaled and shifted mirror sigmoid s1 = 1 / (1 + math.exp(-0.5 * s)) # st value when t = t1 -> ft = 0 s2 = 1 / (1 + math.exp(0.5 * s)) # st value when t = t2 -> ft = 1 return (st - s2) / (s1 - s2) # shift and scale to range (0, 1] @@ -38,13 +41,18 @@ def normalized_mirror_sigmoid(t: float, t1: float, t2: float, s: float) -> float class ProxPARQ(ProxMap): def __init__( - self, anneal_start: int, anneal_end: int, steepness: float = 10 + self, + anneal_start: int, + anneal_end: int, + steepness: float = 10, + anneal_center: float = 0.5, ) -> None: assert anneal_start < anneal_end, "PARQ annealing: start before end." assert steepness > 0, "PARQ annealing steepness should be positive." self.anneal_start = anneal_start self.anneal_end = anneal_end self.steepness = steepness + self.anneal_center = anneal_center @torch.no_grad() @amp_custom_fwd(cast_inputs=torch.float32) @@ -72,8 +80,13 @@ def apply_( p.copy_(q) else: inv_slope = normalized_mirror_sigmoid( - step_count, self.anneal_start, self.anneal_end, self.steepness + step_count, + self.anneal_start, + self.anneal_end, + self.steepness, + self.anneal_center, ) + inv_slope = max(torch.finfo(p.dtype).tiny, inv_slope) # it is important to clamp idx-1 and then clamping idx itself # idx_1[k] == idx[k] iff p[k] > Q.max() or p[k] <= Q.min() if dim is None: diff --git a/torchao/prototype/parq/optim/proxmap.py b/torchao/prototype/parq/optim/proxmap.py index 0bfbf57498..da867cc5db 100644 --- a/torchao/prototype/parq/optim/proxmap.py +++ b/torchao/prototype/parq/optim/proxmap.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. + from abc import ABC, abstractmethod from typing import Optional diff --git a/torchao/prototype/parq/optim/quantopt.py b/torchao/prototype/parq/optim/quantopt.py index 016aea28bc..7ebc1a80a0 100644 --- a/torchao/prototype/parq/optim/quantopt.py +++ b/torchao/prototype/parq/optim/quantopt.py @@ -3,6 +3,8 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. + +from collections import defaultdict from collections.abc import Callable from functools import partial from typing import Any, Optional @@ -11,15 +13,14 @@ from torch import Tensor from torch.optim import Optimizer -from ..quant import LSBQuantizer, Quantizer +from ..quant import Quantizer +from ..utils import HAS_DTENSOR, is_dtensor from .proxmap import ProxMap -try: - from torch.distributed.tensor import DTensor - - HAS_DTENSOR = True -except ImportError: - HAS_DTENSOR = False +if HAS_DTENSOR: + from torch.distributed.tensor import distribute_tensor + from torch.distributed.tensor.experimental import local_map + from torch.distributed.tensor.placement_types import Shard class QuantOptimizer(Optimizer): @@ -31,7 +32,7 @@ class QuantOptimizer(Optimizer): a proximal mapping (e.g, HardQuant/STE, PARQ, BinaryRelax) - update model parameters based on the above two updates Other parameters: - - warmup_steps: int > 0 + - warmup_steps: int >= 0 - quant_period: int > 0 - quant_per_channel: True or False - quant_shrink: True or False @@ -86,23 +87,23 @@ def __repr__(self) -> str: extra_repr = "\n ".join(("(", base_optimizer, f"{quantizer=}", f"{prox_map=}")) return f"{self.__class__.__name__} {extra_repr}\n)" + @property + def state(self) -> defaultdict[Tensor, Any]: # pyre-ignore[3] + return self._state if hasattr(self, "_state") else self.base_optimizer.state + @staticmethod def quantize_( p: Tensor, quants: Tensor, quantizer: Quantizer, b: int, - quant_update: bool, dim: Optional[int] = None, ) -> Optional[Tensor]: """Optionally update the quantization targets `quants` in place. Return the quantized `p` as a by-product if `quant_update=True`. """ - if quant_update: # update Q for each channel - q, Q = quantizer.quantize(p, b, dim=dim) # pyre-ignore[28] - quants.copy_(Q) - else: - q = None + q, Q = quantizer.quantize(p, b, dim=dim) # pyre-ignore[28] + quants.copy_(Q) return q def regularized_param_groups(self): # pyre-ignore[3] @@ -122,12 +123,13 @@ def state_dict(self) -> dict[str, Any]: def load_state_dict( self, state_dict: dict[str, Any], start_step: Optional[int] = None ) -> None: - qat_state = state_dict.pop("qat_state") + qat_state = state_dict.get("qat_state") # resume from check points usually not corresponds to saved num_steps # so allow explicit start_step computed from epochs * steps_per_epoc if start_step is not None: self.num_steps = start_step - else: # hope discrepancy in num_steps does not cause major problem! + elif qat_state is not None: + # hope discrepancy in num_steps does not cause major problem! self.num_steps = qat_state["num_steps"] self.base_optimizer.load_state_dict(state_dict) @@ -144,9 +146,6 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] self.num_steps += 1 return loss - # call base optimizer step() method to update latent parameters - loss = self.base_optimizer.step(closure=closure) # pyre-ignore[6] - if self.num_steps == self.warmup_steps: # first step of qat, save latent params, instead of restore self.save_latent_params() @@ -154,6 +153,16 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] # qat: restore latent params for update by the base optimizer self.restore_latent_params() + # call base optimizer step() method to update latent parameters + loss = self.base_optimizer.step(closure=closure) # pyre-ignore[6] + + if hasattr(self, "_state"): + assert self.warmup_steps == 0 + # restore the temporary state to the base optimizer's state + for p in self._state.keys(): + self.base_optimizer.state[p]["latent"] = self._state[p]["latent"] + del self._state + # check if it is time to update set of quantization values Q if (self.num_steps - self.warmup_steps) % self.quant_period == 0: quant_update = True @@ -165,6 +174,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] group["cumu_lr"] += group["lr"] gamma = max(1.0, group["cumu_lr"]) b = group["quant_bits"] + block_size = group.get("quant_block_size") inv_slope = 0.0 for p in group["params"]: if not p.requires_grad: @@ -177,44 +187,66 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] if self.quant_shrink: p.div_(gamma) + # reshape p according to block size if specified + if block_size is not None: + assert ( + p.size(-1) % block_size == 0 + ), f"{p.size(-1)=} is not divisible by {block_size=}" + assert p.dim() <= 2, f"Invalid {p.dim()=} for {block_size=}" + if p.dim() == 1: + p = p.unsqueeze(0) + + # row-major ordering ensures this is correct + p = p.view(-1, block_size) + # quantization by channel or by layer # update quantization targets periodically per_channel = self.quant_per_channel and p.dim() > 1 if quant_update: - quants_size = 3 if b == 0 else 2**b - if per_channel: - quants_size = (p.size(0), quants_size) - state["quants"] = torch.empty( - quants_size, device=p.device - ) # pyre-ignore[6] + quant_size = self.quantizer.get_quant_size(b) - # avoid type mismatch between sharded and full tensors - if HAS_DTENSOR and isinstance(p, DTensor): - p = p.full_tensor() + if per_channel: + quant_size = (p.size(0), quant_size) + state["quants"] = torch.empty(quant_size, device=p.device) + if is_dtensor(p): + state["quants"] = distribute_tensor( + state["quants"], + device_mesh=p.device_mesh, + placements=p.placements, + ) dim = -1 if per_channel else None if per_channel and p.dim() > 2: p = p.flatten(start_dim=1) - # NOTE: for LSBQ and optimal=False, use faster per-channel - # implementation instead of vmap - if isinstance(self.quantizer, LSBQuantizer) and self.quantizer.optimal: + q = None + if quant_update: qfunc = partial( - self.quantize_, - quantizer=self.quantizer, - b=b, - quant_update=quant_update, - ) - q = torch.vmap(qfunc, in_dims=0, out_dims=0)(p, state["quants"]) - else: - q = self.quantize_( - p, state["quants"], self.quantizer, b, quant_update, dim=dim + self.quantize_, quantizer=self.quantizer, b=b, dim=dim ) + if is_dtensor(p): + qfunc = local_map( + qfunc, + out_placements=[*p.placements], + in_placements=([Shard(0)], [Shard(0)]), + ) + q = qfunc(p, state["quants"]) # apply (step-dependent) proximal mapping in place - inv_slope = self.prox_map.apply_( # pyre-ignore[28] - p, q, state["quants"], self.num_steps, dim=dim + pfunc = partial( + self.prox_map.apply_, step_count=self.num_steps, dim=dim ) + if is_dtensor(p): + pfunc = local_map( + pfunc, + out_placements=None, + in_placements=( + [Shard(0)], + None if q is None else [Shard(0)], + [Shard(0)], + ), + ) + inv_slope = pfunc(p, q, state["quants"]) # quantized parameters share the same PARQ inverse slope if inv_slope: @@ -239,6 +271,12 @@ def restore_latent_params(self) -> None: @torch._disable_dynamo def save_latent_params(self) -> None: """Save updated latent parameters before applying prox-map""" + if self.warmup_steps == 0: + assert len(self.state) == 0, "Expected empty state at first step()" + # Maintain the invariant that `len(self.state) == 0` before first + # self.base_optimizer.step() call by using a temporary state buffer + self._state = defaultdict(dict) + for group in self.regularized_param_groups(): for p in group["params"]: if p.requires_grad: diff --git a/torchao/prototype/parq/quant/__init__.py b/torchao/prototype/parq/quant/__init__.py index b7251f2df1..8835740975 100644 --- a/torchao/prototype/parq/quant/__init__.py +++ b/torchao/prototype/parq/quant/__init__.py @@ -1,3 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + from .lsbq import LSBQuantizer # noqa: F401 from .quantizer import Quantizer # noqa: F401 -from .uniform import UnifQuantizer # noqa: F401 +from .uniform import ( # noqa: F401 + MaxUnifQuantizer, + TernaryUnifQuantizer, + UnifQuantizer, +) diff --git a/torchao/prototype/parq/quant/lsbq.py b/torchao/prototype/parq/quant/lsbq.py index e821b8f460..2d9f4e4c1e 100644 --- a/torchao/prototype/parq/quant/lsbq.py +++ b/torchao/prototype/parq/quant/lsbq.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. + import itertools from collections.abc import Iterable from typing import Optional @@ -23,10 +24,58 @@ def binary_quant_residue(u: Tensor, vs: Iterable[float]) -> Tensor: """Return residue for foldable binary quantization""" r = u.detach().clone() for v in vs: - r -= v * binary_sign(r) + r.sub_(v * binary_sign(r)) return r +def compute_v_per_channel(p: Tensor, dim: Optional[int] = None, ternary: bool = False): + """Vectorized computation of optimal `v` for ternary/2-bit algorithm.""" + v_cands = p.abs().sort(dim=dim).values + cumsum = v_cands.cumsum(dim=dim) + cumsum, total_sum = cumsum[:, 1:-1], cumsum[:, -1:] + + # compute cumulative mean from right to left + counts = torch.arange(1, p.size(dim=dim), device=p.device) + counts_r2l = counts[:-1].flip((-1,)) + cmean_r2l = (total_sum - cumsum).div_(counts_r2l.mul_(2)) + v_cands, v_cands2 = v_cands[:, 1:-1], v_cands[:, 2:] + + # mask to estimate conditional expectation + mask = (v_cands <= cmean_r2l).logical_and_(v_cands2 >= cmean_r2l) + if ternary: + # detect and fix any edge cases + optimal_v = p.mean(dim=dim, keepdim=True).div_(2) + row_invalid = optimal_v < p.min(dim=dim, keepdim=True).values + if row_invalid.any(): + extra_col = row_invalid.to(p.dtype).mul(optimal_v) + v_cands = torch.cat((v_cands, extra_col), -1) + mask = torch.cat((mask, row_invalid), -1) + else: + # compute cumulative mean from left to right + cmean_l2r = cumsum.div_(counts[1:].mul_(2)).add_(cmean_r2l) + mask.logical_or_((v_cands <= cmean_l2r).logical_and_(v_cands2 >= cmean_l2r)) + + # handle variable number of candidates per channel + split_sizes = mask.sum(dim=dim).tolist() + v_cands = v_cands[mask].split(split_sizes) + v_cands = torch.nested.nested_tensor(list(v_cands)) + v_cands = torch.nested.to_padded_tensor(v_cands, 0.0) + + # update residual for each candidate `v` + r = p.unsqueeze(dim - 1) + v = v_cands.unsqueeze(-1) + r = r.sub(v * binary_sign(r)) + if not ternary: + v = v.mean(dim=dim, keepdim=True) + r = r.sub(v * binary_sign(r)) + + # compute least squares error, then select the `v` minimizes it + costs = r.norm(dim=dim) + indices = costs.argmin(dim=dim, keepdim=True) + v_best = v_cands.gather(1, indices) + return v_best + + class LSBQuantizer(Quantizer): """Least-Square Binary Quantizer, using greedy algorithm by default. Optimal solution available for three cases: b=1, b=2 and ternary. @@ -44,25 +93,31 @@ def __init__( self.optimal = optimal self.ternary_multiplier = ternary_multiplier + def get_quant_size(self, b: int) -> int: + return 2**b if b > 0 else 3 + def quantize( self, p: Tensor, b: int, dim: Optional[int] = None ) -> tuple[Tensor, Tensor]: """Instantiation of Quantizer.quantize(), with b=0 for ternary""" - assert b >= 0 # b==0 means ternary + if b < 0: + raise ValueError(f"Invalid {b=}; must be nonnegative") + if self.optimal and b > 2: + raise NotImplementedError(f"Unsupported {self.optimal=} for {b=}") + if self.center: q, mean = super().remove_mean(p.detach(), dim=dim) else: q = p.detach().clone() mean = torch.zeros(1, dtype=p.dtype, device=p.device) - # b == 0 means ternary; b == 1 optimal same as greedy - if b == 0: - if self.optimal: - q, Q = self.quantize_optimal_ternary(q) - else: - q, Q = self.quantize_simple_ternary(q, self.ternary_multiplier, dim=dim) - elif b == 2 and self.optimal: - q, Q = self.quantize_optimal_2bits(q) + if self.optimal and b != 1: # b == 1 optimal is the same as greedy + if b == 0: + q, Q = self.quantize_optimal_ternary(q, dim=dim) + elif b == 2: + q, Q = self.quantize_optimal_2bits(q, dim=dim) + elif b == 0: + q, Q = self.quantize_simple_ternary(q, self.ternary_multiplier, dim=dim) else: q, Q = self.quantize_greedy(q, b, dim=dim) @@ -81,7 +136,7 @@ def quantize_greedy( keepdim = dim is not None for _ in range(b): v = r.abs().mean(dim=dim, keepdim=keepdim) - r -= v * binary_sign(r) + r.sub_(binary_sign(r).mul_(v)) vs.append(v) q = p - r @@ -90,16 +145,32 @@ def quantize_greedy( B = torch.tensor(basis, dtype=p.dtype, device=p.device) if dim is not None: V = torch.concat(vs, dim=1) # [dim0, b] - Q = torch.sort(V @ B.T, dim=dim)[0] # [dim0, 2^b] + Q = torch.sort(V @ B.T, dim=dim).values # [dim0, 2^b] else: V = torch.tensor(vs, dtype=p.dtype, device=p.device) Q = torch.msort(B.matmul(V)) # [2^b] return q, Q @staticmethod - def quantize_optimal_2bits(p: Tensor) -> tuple[Tensor, Tensor]: + def quantize_optimal_2bits( + p: Tensor, dim: Optional[int] = None + ) -> tuple[Tensor, Tensor]: + # generate 4 x 2 basis tensor B, sorted lexicographically along dim 0 + basis = list(itertools.product((-1, 1), repeat=2)) + B = torch.tensor(basis, dtype=p.dtype, device=p.device) + if dim is not None: + v1 = compute_v_per_channel(p, dim=dim, ternary=False) + s = binary_sign(p).mul_(v1) + r = p.sub(s) + v2 = r.abs().mean(dim=dim, keepdim=True) + q = s.add_(binary_sign(r).mul_(v2)) + + V = torch.cat((v1, v2), dim=-1) # [dim0, b] + Q = V @ B.T # [dim0, 2^b] + return q, Q + # first form the cumulative sum of sorted absolute values of p - p_abs_sorted = torch.msort(torch.flatten(p.abs())) + p_abs_sorted = p.abs().flatten().sort().values cumsum = torch.cumsum(p_abs_sorted, dim=0) n = cumsum.numel() # find all solutions v1 to an inclusion problem (after sorting |p|) @@ -133,18 +204,31 @@ def quantize_optimal_2bits(p: Tensor) -> tuple[Tensor, Tensor]: min_error = error q = p - r v1, v2 = v1v2 - # generate 4 x 2 basis tensor B, sorted lexicographically along dim 0 - basis = list(itertools.product((-1, 1), repeat=2)) - B = torch.tensor(basis, dtype=p.dtype, device=p.device) - # vmap workaround: calling torch.tensor on v1, v2 raises an error - Q = v1 * B[:, 0] + v2 * B[:, 1] + + V = torch.tensor((v1, v2), dtype=p.dtype, device=p.device) + Q = B @ V return q, Q @staticmethod - def quantize_optimal_ternary(p: Tensor) -> tuple[Tensor, Tensor]: + def quantize_optimal_ternary( + p: Tensor, dim: Optional[int] = None + ) -> tuple[Tensor, Tensor]: """Formula look reasonable, but derivation in reference incorrect?""" + if dim is not None: + v = compute_v_per_channel(p, dim=dim, ternary=True) + p_sign = binary_sign(p) + r = p.sub(p_sign.mul(v)) + + # 0 if sign(p) != sign(r), else sign(p) * 2v + q = p_sign.add_(binary_sign(r)).mul_(v) + + # each channel can take values [-2v, 0, 2v] + v.mul_(2) + Q = torch.cat((-v, torch.zeros_like(v), v), dim=-1) # [dim0, 3] + return q, Q + # first form the cumulative sum of sorted absolute values of p - p_abs_sorted = torch.msort(torch.flatten(p.abs())) + p_abs_sorted = p.abs().flatten().sort().values cumsum = torch.cumsum(p_abs_sorted, dim=0) n = cumsum.numel() # find all solutions v1 to an inclusion problem (after sorting |p|) diff --git a/torchao/prototype/parq/quant/quantizer.py b/torchao/prototype/parq/quant/quantizer.py index 22dd2a1bb7..b44050e773 100644 --- a/torchao/prototype/parq/quant/quantizer.py +++ b/torchao/prototype/parq/quant/quantizer.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. + from abc import ABC, abstractmethod from typing import Optional @@ -15,6 +16,10 @@ class Quantizer(ABC): def __init__(self, center: bool = False) -> None: self.center = center + @abstractmethod + def get_quant_size(self, b: int) -> int: + """Given number of bits b, return total number of quantization values""" + @abstractmethod def quantize(self, p: Tensor, b: int) -> tuple[Tensor, Tensor]: """Provide interface for quantization: diff --git a/torchao/prototype/parq/quant/uniform.py b/torchao/prototype/parq/quant/uniform.py index de9e465bc0..f264894115 100644 --- a/torchao/prototype/parq/quant/uniform.py +++ b/torchao/prototype/parq/quant/uniform.py @@ -3,6 +3,8 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. + +import math from typing import Optional import torch @@ -11,50 +13,202 @@ from .quantizer import Quantizer +def get_q_max( + q: Tensor, b: int, dim: Optional[int] = None, scale_method: str = "mean" +) -> Tensor: + if scale_method == "mean": + # set range of quantization: min(b * |q|.mean(), |q|.max()) + q_abs = q.abs() + if dim is not None: + q_max = torch.minimum( + b * q_abs.mean(dim=dim, keepdim=True), # pyre-ignore[6,9] + torch.max(q_abs, dim=dim, keepdim=True).values, # pyre-ignore[6] + ) + else: + q_max = torch.minimum(b * q_abs.mean(), torch.max(q_abs)) # pyre-ignore[6] + elif scale_method == "max": + q_max = ( + q.abs().max(dim=dim, keepdim=True).values + if dim is not None + else q.abs().max() + ) + else: + raise NotImplementedError(f"Invalid {scale_method=}, choices=('mean','max')") + return q_max + + class UnifQuantizer(Quantizer): - """Uniform quantizer, range determined by multiples of |p|.mean()""" + """Uniform and symmetric quantizer""" + + def __init__( + self, + center: bool = False, + scale_method: str = "mean", + int_shift: float = 0.5, + zero_point: float = 0.5, + ): + """Set quantization function parameters. + + Args: + center: whether to subtract p.mean() prior to quantization + scale_method: compute scale based 'mean', multiples of |p|.mean(), + or 'max', |p|.max() (default: 'mean') + int_shift: float value to shift the lower bound of integer range by: + -2^{b - 1} + int_shift (default: 0.5). Using 0.5 results in 2^b + values. E.g., [-1.5, -0.5, 0.5, 1.5] for b=2. + zero_point: float value to shift p by after scale and round. + """ + assert scale_method in ("max", "mean"), f"Invalid {scale_method=}" + super().__init__(center=center) + + self.scale_method = scale_method + self.int_shift = int_shift + self.zero_point = zero_point + + def get_quant_size(self, b: int) -> int: + """Levels in [-2^{b-1} + self.int_shift, 2^{b-1} - self.int_shift]. - def __init__(self, center: bool = False) -> None: - super().__init__(center) + Note that range_absmax = 2^{b-1} - self.int_shift on both ends of the + boundary and the interval is closed.""" + return math.floor(2**b - 2 * self.int_shift) + 1 def quantize( self, p: Tensor, b: int, dim: Optional[int] = None ) -> tuple[Tensor, Tensor]: """Instantiation of Quantizer.quantize() method""" - assert b >= 1 + assert b != 0, "Please use TernaryUnifQuantizer instead" + if self.center: q, mean = super().remove_mean(p.detach(), dim=dim) else: q = p.detach().clone() mean = torch.zeros(1, dtype=p.dtype, device=p.device) - - # set range of quantization: min( b * |q|.mean(), |q|.max()) - q_abs = q.abs() - if dim is not None: - q_max = torch.minimum( - b * q_abs.mean(dim=dim, keepdim=True), # pyre-ignore[6,9] - torch.max(q_abs, dim=dim, keepdim=True)[0], # pyre-ignore[6] - ) - else: - q_max = torch.minimum(b * q_abs.mean(), torch.max(q_abs)) # pyre-ignore[6] + q_max = get_q_max(q, b, dim=dim, scale_method=self.scale_method) + q_max.clamp_(min=torch.finfo(q.dtype).tiny) # clamp to quantization range q.copy_(torch.minimum(torch.maximum(q, -q_max), q_max)) - # compute scale from [-2^{b-1}+0.5, 2^{b-1}-0.5] to [-q_max, q_max] - s = q_max / (2 ** (b - 1) - 0.5) + # scale from [-2^{b-1}+int_shift, 2^{b-1}-int_shift] to [-q_max, q_max] + range_absmax = 2 ** (b - 1) - self.int_shift + s = q_max / range_absmax - # scale by 1/s -> shift -0.5 -> round -> shift +0.5 -> scale by s - # where shift ensures rounding to integers 2^{b-1}, ..., 2^{b-1}-1 - q.div_(s).sub_(0.5).round_().add_(0.5).mul_(s) + # scale by 1/s -> shift -zero_point -> round -> shift +zero_point -> + # scale by s, where shift ensures rounding to integers + q.div_(s).sub_(self.zero_point).round_().add_(self.zero_point).mul_(s) # set of all target quantization values - Q = s * ( - torch.arange(-(2 ** (b - 1)) + 0.5, 2 ** (b - 1), step=1, device=q.device) + Q = torch.arange( + -range_absmax, range_absmax + 1e-5, dtype=p.dtype, device=p.device ) + if dim is not None: + Q = Q.unsqueeze(0).mul(s) # broadcasted multiply requires copy + else: + Q.mul_(s) + + # return quantized tensor and set of possible quantization values + if self.center: + q += mean + Q += mean + return q, Q + + +class MaxUnifQuantizer(UnifQuantizer): + def __init__( + self, + center: bool = False, + scale_method: str = "max", + int_shift: float = 1.0, + zero_point: float = 0.0, + ): + """Set quantization function with int_shift=1.0. + + The final quantization range includes 2^b - 1 quantized values. E.g., + [-1, 0, 1] for b=2. The quantization scale is determined by |p|.max() + by default and zero point is 0.0. + """ + super().__init__( + center=center, + scale_method=scale_method, + int_shift=int_shift, + zero_point=zero_point, + ) + + +class AsymUnifQuantizer(Quantizer): + def get_quant_size(self, b: int) -> int: + """Equivalent to int_max - int_min + 1, where int_min = -2^{b-1} and + int_max = 2^{b-1} - 1.""" + return 2**b + + def quantize( + self, p: Tensor, b: int, dim: Optional[int] = None + ) -> tuple[Tensor, Tensor]: + assert b != 0, "Please use TernaryUnifQuantizer instead" + + if self.center: + q, mean = super().remove_mean(p.detach(), dim=dim) + else: + q = p.detach().clone() + mean = torch.zeros(1, dtype=p.dtype, device=p.device) + + if dim is not None: + q_min = q.min(dim=dim, keepdim=True).values + q_max = q.max(dim=dim, keepdim=True).values + else: + q_min = q.min() + q_max = q.max() + + int_min = -(2 ** (b - 1)) + int_max = 2 ** (b - 1) - 1 + s = (q_max - q_min) / (int_max - int_min) + s.clamp_(min=torch.finfo(q.dtype).tiny) + + zero_point = q_min.div_(s).round_() + q.div_(s).round_().sub_(zero_point).add_(zero_point).mul_(s) + + Q = torch.arange(int_min, int_max + 1, dtype=p.dtype, device=p.device) + if dim is not None: + Q = Q.unsqueeze(0).mul(s) # broadcasted multiply requires copy + else: + Q.mul_(s) # return quantized tensor and set of possible quantization values if self.center: q += mean Q += mean return q, Q + + +class TernaryUnifQuantizer(Quantizer): + """Uniform quantizer for ternary bit case. Quantization range is [-1, 1].""" + + def get_quant_size(self, b: int) -> int: + return 3 + + def quantize( + self, p: Tensor, b: int, dim: Optional[int] = None + ) -> tuple[Tensor, Tensor]: + assert b == 0, f"Unexpected {b=} for ternary case" + + if self.center: + q, mean = super().remove_mean(p.detach(), dim=dim) + else: + q = p.detach().clone() + mean = torch.zeros(1, dtype=p.dtype, device=p.device) + + q_max = get_q_max(q, b, dim=dim, scale_method="max") + q_max.clamp_(min=torch.finfo(q.dtype).tiny) + s = q_max / 1.5 + q.div_(s).round_().clamp_(min=-1, max=1).mul_(s) + + Q = torch.tensor([-1, 0, 1], dtype=p.dtype, device=p.device) + if dim is not None: + Q = Q.unsqueeze(0).mul(s) + else: + Q.mul_(s) + + if self.center: + q += mean + Q += mean + return q, Q diff --git a/torchao/prototype/parq/utils.py b/torchao/prototype/parq/utils.py index d18257574f..ac5024fb5d 100644 --- a/torchao/prototype/parq/utils.py +++ b/torchao/prototype/parq/utils.py @@ -3,9 +3,21 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. + import torch from torch import Tensor +try: + from torch.distributed.tensor import DTensor + + HAS_DTENSOR = True +except ImportError: + HAS_DTENSOR = False + + +def is_dtensor(x): + return HAS_DTENSOR and isinstance(x, DTensor) + def channel_bucketize(input: Tensor, boundaries: Tensor, right: bool = False) -> Tensor: """Generalizes torch.bucketize to run on 2-D boundaries.""" @@ -18,4 +30,4 @@ def channel_bucketize(input: Tensor, boundaries: Tensor, right: bool = False) -> boundaries = boundaries.unsqueeze(1) input = input.unsqueeze(-1) mask = input.ge(boundaries) if right else input.le(boundaries) - return mask.int().argmax(dim=-1) + return mask.to(torch.uint8).argmax(dim=-1) From ee2b9c7818ab15f8762e416ee4a82de005705cc9 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Fri, 4 Apr 2025 15:30:31 -0700 Subject: [PATCH 24/30] Use quantized gemm only on aarch64 Differential Revision: D72413684 Pull Request resolved: https://github.com/pytorch/ao/pull/2023 --- ...bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h | 4 ++-- ...it_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h | 4 ++-- ...fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h | 4 ++-- .../experimental/kernels/cpu/aarch64/matmul/matmul.h | 4 ++-- .../kernels/cpu/interface/quantized_matmul.h | 12 ++++++------ 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h index b83c28143f..ecdc40f880 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h @@ -6,7 +6,7 @@ #pragma once -#if defined(__aarch64__) || defined(__ARM_NEON) +#if defined(__aarch64__) && defined(__ARM_NEON) #include #include @@ -381,4 +381,4 @@ void kernel( } // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal } // namespace torchao::kernels::cpu::aarch64::quantized_matmul -#endif // defined(__aarch64__) || defined(__ARM_NEON) +#endif // defined(__aarch64__) && defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h index 123b7723e4..898fa30b18 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h @@ -6,7 +6,7 @@ #pragma once -#if defined(__aarch64__) || defined(__ARM_NEON) +#if defined(__aarch64__) && defined(__ARM_NEON) #include #include @@ -333,4 +333,4 @@ void kernel( } // namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot } // namespace torchao::kernels::cpu::aarch64::quantized_matmul -#endif // defined(__aarch64__) || defined(__ARM_NEON) +#endif // defined(__aarch64__) && defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h index bdad1b4a47..663132b35d 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h @@ -6,7 +6,7 @@ #pragma once -#if defined(__aarch64__) || defined(__ARM_NEON) +#if defined(__aarch64__) && defined(__ARM_NEON) #include #include @@ -278,4 +278,4 @@ void kernel( } // namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 } // namespace torchao::kernels::cpu::aarch64::quantized_matmul -#endif // defined(__aarch64__) || defined(__ARM_NEON) +#endif // defined(__aarch64__) && defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h index 43f3dd4bce..da3a41fdb9 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h @@ -10,7 +10,7 @@ #pragma once -#if defined(__aarch64__) || defined(__ARM_NEON) +#if defined(__aarch64__) && defined(__ARM_NEON) #include @@ -92,4 +92,4 @@ void kernel( #include #include -#endif // defined(__aarch64__) || defined(__ARM_NEON) +#endif // defined(__aarch64__) && defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h index 718f7eaad9..d9c9d23271 100644 --- a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h +++ b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h @@ -11,11 +11,11 @@ #include #include -#if defined(__aarch64__) || defined(__ARM_NEON) +#if defined(__aarch64__) && defined(__ARM_NEON) #include #include #include -#endif // defined(__aarch64__) || defined(__ARM_NEON) +#endif // defined(__aarch64__) && defined(__ARM_NEON) namespace torchao::kernels::cpu::quantized_matmul { @@ -67,7 +67,7 @@ get_int8_a_int8_b_channelwise_qmatmul( bool b_transposed, int& a_stride_m, int& b_stride_n) { -#if defined(__aarch64__) || defined(__ARM_NEON) +#if defined(__aarch64__) && defined(__ARM_NEON) if (!a_transposed && b_transposed && n >= 8) { a_stride_m = k; b_stride_n = k; @@ -75,7 +75,7 @@ get_int8_a_int8_b_channelwise_qmatmul( channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot:: kernel; } -#endif // defined(__aarch64__) || defined(__ARM_NEON) +#endif // defined(__aarch64__) && defined(__ARM_NEON) assert(!a_transposed); if (b_transposed) { a_stride_m = k; @@ -134,14 +134,14 @@ get_fp32_a_input_channelwise_8bit_b_f32_c_matmul( bool b_transposed, int& a_stride_m, int& b_stride_n) { -#if defined(__aarch64__) || defined(__ARM_NEON) +#if defined(__aarch64__) && defined(__ARM_NEON) if (!a_transposed && !b_transposed && n >= 16) { a_stride_m = k; b_stride_n = n; return aarch64::quantized_matmul:: fp32_a_input_channelwise_8bit_b_1x16x4_f32::kernel; } -#endif // defined(__aarch64__) || defined(__ARM_NEON) +#endif // defined(__aarch64__) && defined(__ARM_NEON) assert(!a_transposed); if (b_transposed) { a_stride_m = k; From 05ae22cdf56babde0d1642856ddbc713f4d83733 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Sat, 5 Apr 2025 16:35:28 -0700 Subject: [PATCH 25/30] Adds utility to replace Q/DQ ops with torchao quantized linear ops (#1967) * up * up * up * up --- .../workflows/torchao_experimental_test.yml | 1 + torchao/experimental/quant_passes.py | 217 ++++++++++++++++++ .../experimental/tests/test_quant_passes.py | 82 +++++++ 3 files changed, 300 insertions(+) create mode 100644 torchao/experimental/quant_passes.py create mode 100644 torchao/experimental/tests/test_quant_passes.py diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml index 2187eed8e3..4d0a1eaaf6 100644 --- a/.github/workflows/torchao_experimental_test.yml +++ b/.github/workflows/torchao_experimental_test.yml @@ -44,6 +44,7 @@ jobs: conda activate venv pytest torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py python torchao/experimental/tests/test_embedding_xbit_quantizer.py + python torchao/experimental/tests/test_quant_passes.py - name: Run kernels/cpu/aarch64/tests run: | conda activate venv diff --git a/torchao/experimental/quant_passes.py b/torchao/experimental/quant_passes.py new file mode 100644 index 0000000000..9a744643c8 --- /dev/null +++ b/torchao/experimental/quant_passes.py @@ -0,0 +1,217 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import itertools +from collections import defaultdict +from typing import Callable, Optional + +import torch +from torch._export.passes.constant_folding import ( + ConstantFolder, + replace_node_with_constant, +) +from torch.fx import subgraph_rewriter + + +def constant_fold( + gm: torch.fx.GraphModule, + constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None, + skip_constructors: bool = False, +): + with torch.utils._python_dispatch._disable_current_modes(): + # The ConstantFolder has a bug where it throws if dequantize_affine is not defined + # TODO: fix upstream + try: + getattr(torch.ops.pt2e_quant, "dequantize_affine") + except AttributeError: + setattr(torch.ops.pt2e_quant, "dequantize_affine", None) + + cf = ConstantFolder(gm, skip_constructors) + cf.run() + + for node, constant in cf.node_replacements.items(): + if constraint_fn is not None and not constraint_fn(node): + continue + replace_node_with_constant(gm, node, constant) + + erased_params = [] + # Get all attr users by looking up the graph instead from node.users, because in this case + # _tensor_constant0 and _tensor_constant0_1 are actually refereing to the same tensor. + + # opcode name target args kwargs + # ------------- ------------------- ---------------- --------------------------- -------- + # placeholder arg0_1 arg0 () {} + # get_attr _tensor_constant0 state () {} + # call_function add aten.add.Tensor (arg0_1, _tensor_constant0) {} + # get_attr _tensor_constant0_1 state () {} + # call_function add_ aten.add_.Tensor (_tensor_constant0_1, 1) {} + # output output output ([add],) {} + + get_attr_node_users = defaultdict(list) + for node in gm.graph.nodes: + if node.op == "get_attr": + get_attr_node_users[node.target].extend(node.users.keys()) + for node in gm.graph.find_nodes(op="get_attr"): + if node.op == "get_attr" and len(get_attr_node_users[node.target]) == 0: + if hasattr(gm, node.target): + delattr(gm, node.target) + erased_params.append(node) + for node in erased_params: + gm.graph.erase_node(node) + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + +def _get_q_dq_linear_patterns_replacements_and_filters( + weight_bit_width, has_weight_zeros, target +): + glbs = globals() + glbs["weight_bit_width"] = weight_bit_width + glbs["target"] = target + glbs["w_quant_min"] = -(1 << (weight_bit_width - 1)) + glbs["w_quant_max"] = (1 << (weight_bit_width - 1)) - 1 + glbs["a_quant_min"] = -128 + glbs["a_quant_max"] = 127 + glbs["a_mapping_type"] = "ASYMMETRIC" + glbs["a_scale_dtype"] = torch.float32 + glbs["a_eps"] = None + + lcls = {} + + pattern_str = f""" +def pattern( + a, a_block_size, a_target_dtype, a_zero_point_dtype, + w_int_data, w_block_size, w_scale, w_zero_point, w_target_dtype, + bias): + a_scale, a_zero_point = torch.ops.quant.choose_qparams_affine.default( + a, + a_mapping_type, + a_block_size, + a_target_dtype, + a_quant_min, + a_quant_max, + a_eps, + a_scale_dtype, + a_zero_point_dtype, + ) + a_int_data = torch.ops.quant.quantize_affine.default( + a, a_block_size, a_scale, a_zero_point, a_target_dtype, a_quant_min, a_quant_max, + ) + dq_a = torch.ops.quant.dequantize_affine.default( + a_int_data, a_block_size, a_scale, a_zero_point, a_target_dtype, a_quant_min, a_quant_max + ) + dq_w = torch.ops.quant.dequantize_affine.default( + w_int_data, + w_block_size, + w_scale, + w_zero_point, + w_target_dtype, + w_quant_min, + w_quant_max, + {"'INT'" if has_weight_zeros else "'NONE'"} + ) + return torch.ops.aten.linear.default(dq_a, dq_w, bias) +""" + exec(pattern_str, glbs, lcls) + pattern = lcls["pattern"] + + replacement_str = f""" +def replacement( + a, a_block_size, a_target_dtype, a_zero_point_dtype, + w_int_data, w_block_size, w_scale, w_zero_point, w_target_dtype, + bias,): + n = w_int_data.size(0) + k = a_block_size[-1] + group_size = w_block_size[-1] + out_shape = a.shape[:-1] + (n,) + packed_weight = getattr( + torch.ops.torchao, + f"_pack_8bit_act_{weight_bit_width}bit_weight", + )( + w_int_data.to(torch.int8), + w_scale.reshape(-1), + {"w_zero_point.reshape(-1).to(torch.int8)" if has_weight_zeros else "None"}, + group_size, + bias, + target, + ) + return getattr( + torch.ops.torchao, f"_linear_8bit_act_{weight_bit_width}bit_weight" + )(a.reshape(-1, k), packed_weight, group_size, n, k).reshape(out_shape) +""" + + exec(replacement_str, glbs, lcls) + replacement = lcls["replacement"] + + def match_filter(match, x, y): + def get_val(name): + node = [n for n in match.nodes_map if n.name == name][0] + return match.nodes_map[node] + + int_types = [torch.int8, torch.int16, torch.int32, torch.int64] + + a_target_dtype = get_val("a_target_dtype") + if a_target_dtype not in int_types: + return False + + a_zero_point_dtype = get_val("a_zero_point_dtype") + if a_zero_point_dtype not in int_types: + return False + + # We only want a_block_size with shape [1, ..., 1, k] + a_block_size = get_val("a_block_size") + for d in a_block_size[0:-1]: + if d != 1: + print("a_block_size not [1, ..., 1, k]") + return False + + # We only want w_block_size with shape [1, group_size] + w_block_size = get_val("w_block_size") + if len(w_block_size) != 2 or w_block_size[0] != 1: + return False + + return True + + return pattern, replacement, match_filter + + +def replace_q_dq_patterns_with_quantized_linear_ops_pass( + ep: torch.export.ExportedProgram, + target=None, +) -> torch.export.ExportedProgram: + """ + This replaces Q/DQ patterns with torchao quantized linear ops. + It is intended for converting Q/DQ nodes exported with QDQLayout to using + the lowbit quantized linear ops. + """ + # TODO: figure out how to do this with dynamic_shapes (not saved on EP for easy re-export) + # See https://fb.workplace.com/groups/1028545332188949/permalink/1185289956514485/ + assert ( + len(ep.range_constraints) == 0 + ), "ExportedProgram with range constraints are not supported" + + # ep.module() unlifts the weight inputs, which we need for constant folding + gm = ep.module() + for weight_bit_width, has_weight_zeros in itertools.product( + range(1, 9), [True, False] + ): + pattern, replacement, match_filter = ( + _get_q_dq_linear_patterns_replacements_and_filters( + weight_bit_width, has_weight_zeros, target + ) + ) + subgraph_rewriter.replace_pattern_with_filters( + gm, pattern, replacement, match_filters=[match_filter] + ) + + # Constant fold evaluates and removes the packing ops + constant_fold(gm) + + # Re-export + return torch.export.export(gm, *ep.example_inputs) diff --git a/torchao/experimental/tests/test_quant_passes.py b/torchao/experimental/tests/test_quant_passes.py new file mode 100644 index 0000000000..3262e2bf7b --- /dev/null +++ b/torchao/experimental/tests/test_quant_passes.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from torch.testing import FileCheck + +from torchao.experimental.q_dq_layout import QDQLayout +from torchao.experimental.quant_api import ( + Int8DynamicActivationIntxWeightConfig, +) +from torchao.experimental.quant_passes import ( + replace_q_dq_patterns_with_quantized_linear_ops_pass, +) +from torchao.quantization.granularity import PerGroup, PerRow +from torchao.quantization.quant_api import quantize_ + + +class TestQuantPasses(unittest.TestCase): + def test_replace_q_dq_patterns_with_quantized_linear_ops_pass(self): + layers = [] + layer_to_weight_dtype = {} + layer_to_has_weight_zeros = {} + for weight_dtype in [getattr(torch, f"int{i}") for i in range(1, 9)]: + for has_weight_zeros in [True, False]: + for has_bias in [True, False]: + idx = len(layers) + layer_to_weight_dtype[idx] = weight_dtype + layer_to_has_weight_zeros[idx] = has_weight_zeros + layers.append(torch.nn.Linear(64, 64, bias=has_bias)) + activations = torch.randn(2, 1, 64, dtype=torch.float32) + + model = torch.nn.Sequential(*layers) + for idx in range(len(layers)): + quantize_( + model, + Int8DynamicActivationIntxWeightConfig( + weight_dtype=layer_to_weight_dtype[idx], + # Test out different granularities + granularity=PerGroup(32) if idx % 2 == 0 else PerRow(), + has_weight_zeros=layer_to_has_weight_zeros[idx], + layout=QDQLayout(), + ), + lambda m, fqn: fqn == str(idx), + ) + + eager_results = model(activations) + exported = torch.export.export(model, (activations,), strict=True) + exported = replace_q_dq_patterns_with_quantized_linear_ops_pass(exported) + + # We should not find pack op because it gets constant folded + FileCheck().check_not("torch.ops.torchao._pack_8bit_act").run( + exported.graph_module.code + ) + + # We should find len(layers) torchao linear ops + FileCheck().check_count( + "torch.ops.torchao._linear_8bit_act_", count=len(layers), exactly=True + ).run(exported.graph_module.code) + + # We should not find Q/DQ ops + FileCheck().check_not("torch.ops.quant.quantize_affine.default").run( + exported.graph_module.code + ) + FileCheck().check_not("torch.ops.quant.dequantize_affine.default").run( + exported.graph_module.code + ) + FileCheck().check_not("torch.ops.quant.choose_qparams_affine.default").run( + exported.graph_module.code + ) + + # Numerics should match + exported_results = exported.module()(activations) + self.assertTrue(torch.allclose(exported_results, eager_results)) + + +if __name__ == "__main__": + unittest.main() From e6f52ff2cc324bc3ea84d4f4074a9495d0ffd4ca Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Sat, 5 Apr 2025 17:05:06 -0700 Subject: [PATCH 26/30] Fix slice and padding for TensorCoreTiledLayout (#2015) * Fix slice and padding for TensorCoreTiledLayout for int4 weight only quantization Summary: Previously some of the code paths are not exercised, so the bug was not discovered but there are some bug related to slice operation and padding, basically scale and zero_point are not padded before, this results in errors when it is required. Test Plan: python test/dtypes/test_affine_quantized.py -k test_slice Reviewers: Subscribers: Tasks: Tags: * skip if no cuda * update callsites for post_process * add back missing post process * adding missing arg for floatx --- test/dtypes/test_affine_quantized.py | 16 +++++- torchao/dtypes/affine_quantized_tensor.py | 10 ++-- torchao/dtypes/uintx/gemlite_layout.py | 25 ++++++++-- torchao/dtypes/uintx/int4_cpu_layout.py | 31 +++++++----- torchao/dtypes/uintx/marlin_qqq_tensor.py | 1 - ...8_dynamic_activation_intx_weight_layout.py | 1 - .../dtypes/uintx/tensor_core_tiled_layout.py | 49 +++++++++++++------ torchao/dtypes/uintx/uintx_layout.py | 10 +++- torchao/dtypes/utils.py | 10 +++- torchao/testing/utils.py | 14 ++++++ 10 files changed, 126 insertions(+), 41 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 40a46b7e1b..4064bff535 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -17,6 +17,7 @@ from torchao.core.config import AOBaseConfig from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout from torchao.quantization import ( + Int4WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig, float8_weight_only, int4_dynamic_activation_int4_weight, @@ -27,7 +28,7 @@ quantize_, ) from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain -from torchao.testing.utils import skip_if_rocm +from torchao.testing.utils import skip_if_no_cuda, skip_if_rocm from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, @@ -307,6 +308,19 @@ def test_alias(self, device, dtype): quantize_(dummy, Int8DynamicActivationInt8WeightConfig()) _ = dummy.weight[...] + @common_utils.parametrize("device", ["cuda"]) + @common_utils.parametrize("dtype", [torch.bfloat16]) + @skip_if_no_cuda() + def test_slice(self, device, dtype): + # in_feature not divisible by 1024 + # out_feature not divisible by 8 + # to test slice + padding for int4 weight only quantization + dummy = nn.Linear(256, 321, dtype=dtype, device=device) + quantize_(dummy, Int4WeightOnlyConfig()) + # make sure these run without error + _ = dummy.weight.narrow(0, 0, 64) + _ = dummy.weight.narrow(1, 0, 128) + common_utils.instantiate_parametrized_tests(TestAffineQuantized) common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 758032e4b0..0ae95ec50b 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -284,7 +284,9 @@ def from_hp_to_intx( ) # Note: output will be uint8 tensor for sub byte tensors for now - data = _layout.post_process(data) + data, scale, zero_point = _layout.post_process( + data, scale, zero_point, block_size + ) tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout) return cls( @@ -335,7 +337,7 @@ def from_hp_to_intx_static( zero_point_domain, ) - int_data = _layout.post_process(int_data) + int_data, scale, zero_point = _layout.post_process(int_data, scale, zero_point) tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) tensor_impl = tensor_impl_ctr(int_data, scale, zero_point, _layout) @@ -429,7 +431,9 @@ def from_hp_to_fpx( # Note: these ops are hardcoded to have per axis quantization (axis=1) right now scale = choose_qparams_affine_floatx(input_float, ebits, mbits) floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits) - floatx_packed = _layout.post_process(floatx_unpacked) + floatx_packed, scale, _ = _layout.post_process( + floatx_unpacked, scale, None, block_size + ) tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout) diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py index 0370006bca..b91ac8ee4f 100644 --- a/torchao/dtypes/uintx/gemlite_layout.py +++ b/torchao/dtypes/uintx/gemlite_layout.py @@ -297,14 +297,31 @@ def __torch_dispatch__(cls, func, types, args, kwargs): if func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim == 0: + assert step == 1, "Only step == 1 is supported in slicing right now" int_data, scale, zero_point = self.get_plain() + data_len = int_data.shape[dim] + param_dim = 1 - dim + scale_len = scale.shape[param_dim] + ratio = data_len / scale_len + start_scale = int(start / ratio) + end_scale = int(end / ratio) + int_data = aten.slice.Tensor(int_data, dim, start, end, step) - int_data = self._layout.post_process(int_data) + scale = aten.slice.Tensor( + scale, param_dim, start_scale, end_scale, step + ) + if zero_point is not None and zero_point.numel() > 0: + zero_point = aten.slice.Tensor( + zero_point, param_dim, start_scale, end_scale, step + ) + else: + zero_point = None + sliced = self.from_plain(int_data, scale, zero_point, self._layout) return return_and_correct_aliasing(func, args, kwargs, sliced) elif dim == 1: - int_data, scale, zero_point = self.get_plain() assert step == 1, "Only step == 1 is supported in slicing right now" + int_data, scale, zero_point = self.get_plain() data_len = int_data.shape[dim] # scale and zero_point are transposed compared to int_data param_dim = 1 - dim @@ -314,7 +331,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs): end_scale = int(end / ratio) int_data = aten.slice.Tensor(int_data, dim, start, end, step) - # this is to handle padding scale = aten.slice.Tensor( scale, param_dim, start_scale, end_scale, step ) @@ -324,9 +340,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) else: zero_point = None - # import fbvscode; fbvscode.set_trace() sliced = self.from_plain(int_data, scale, zero_point, self._layout) - return sliced + return return_and_correct_aliasing(func, args, kwargs, sliced) else: raise NotImplementedError( f"GemliteAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index f37e5a0684..4ccfb11d23 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -192,16 +192,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs): if func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - if dim == 0: - int_data, scale, zero_point = self.get_plain() - int_data = aten.slice.Tensor(int_data, dim, start, end, step) - # this is to handle padding - int_data = self._layout.post_process(int_data) - sliced = self.from_plain(int_data, scale, zero_point, self._layout) - return return_and_correct_aliasing(func, args, kwargs, sliced) - elif dim == 1: - int_data, scale, zero_point = self.get_plain() + if dim in [0, 1]: assert step == 1, "Only step == 1 is supported in slicing right now" + int_data, scale, zero_point = self.get_plain() data_len = int_data.shape[dim] scale_len = scale.shape[dim] ratio = data_len / scale_len @@ -209,14 +202,16 @@ def __torch_dispatch__(cls, func, types, args, kwargs): end_scale = int(end / ratio) int_data = aten.slice.Tensor(int_data, dim, start, end, step) - # this is to handle padding - int_data = self._layout.post_process(int_data) scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) zero_point = aten.slice.Tensor( zero_point, dim, start_scale, end_scale, step ) + # this is to handle padding + int_data, scale, zero_point = self._layout.post_process( + int_data, scale, zero_point, self.block_size + ) sliced = self.from_plain(int_data, scale, zero_point, self._layout) - return sliced + return return_and_correct_aliasing(func, args, kwargs, sliced) else: raise NotImplementedError( f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" @@ -228,6 +223,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs): __torch_function__ = torch._C._disabled_torch_function_impl + @property + def block_size(self): + from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros + + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) + cur_shape = self.shape + assert len(cur_shape) == 4 + inner_k_tiles = cur_shape[-1] * 2 + original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) + groupsize = int(original_shape[1] / scale.shape[-2]) + return (1, groupsize) + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: from torchao.quantization.quant_primitives import ( ZeroPointDomain, diff --git a/torchao/dtypes/uintx/marlin_qqq_tensor.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py index 3c87714d86..0c4301ce93 100644 --- a/torchao/dtypes/uintx/marlin_qqq_tensor.py +++ b/torchao/dtypes/uintx/marlin_qqq_tensor.py @@ -72,7 +72,6 @@ def from_hp_to_intx( data, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq( input_float, nbits, group_size ) - data = _layout.post_process(data) tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) tensor_impl = tensor_impl_ctr(data, s_group, s_channel, _layout) return cls( diff --git a/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py index 6dcf9e3a10..663f47786a 100644 --- a/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py +++ b/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py @@ -429,7 +429,6 @@ def from_hp_to_intx( ) # Note: output will be uint8 tensor for sub byte tensors for now - data = _layout.post_process(data) tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) tensor_impl = tensor_impl_ctr( data, scale, zero_point, _layout, **(tensor_impl_ctr_kwargs or {}) diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index e08d3fa2bb..901c4c4640 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -153,7 +153,13 @@ def pre_process_static( zero_point = torch.nn.functional.pad(zero_point, padding_changes) return input, scale, zero_point - def post_process(self, input: torch.Tensor) -> torch.Tensor: + def post_process( + self, + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + block_size: Tuple[int, ...], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: orig_out_features, orig_in_features = input.shape in_features = find_multiple(orig_in_features, 1024) out_features = find_multiple(orig_out_features, 8) @@ -161,7 +167,16 @@ def post_process(self, input: torch.Tensor) -> torch.Tensor: input, (0, in_features - orig_in_features, 0, out_features - orig_out_features), ) - return input + assert ( + len(block_size) == 2 + ), f"TensorCoreTiledLayout only supports len(block_size) == 2, got: {block_size}" + scale_pad_dim_0 = (out_features - orig_out_features) // block_size[0] + scale_pad_dim_1 = (in_features - orig_in_features) // block_size[1] + scale = torch.nn.functional.pad(scale, (0, scale_pad_dim_1, 0, scale_pad_dim_0)) + zero_point = torch.nn.functional.pad( + zero_point, (0, scale_pad_dim_1, 0, scale_pad_dim_0) + ) + return input, scale, zero_point def extra_repr(self): return f"inner_k_tiles={self.inner_k_tiles}" @@ -335,16 +350,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs): if func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - if dim == 0: - int_data, scale, zero_point = self.get_plain() - int_data = aten.slice.Tensor(int_data, dim, start, end, step) - # this is to handle padding - int_data = self._layout.post_process(int_data) - sliced = self.from_plain(int_data, scale, zero_point, self._layout) - return return_and_correct_aliasing(func, args, kwargs, sliced) - elif dim == 1: + if dim in [0, 1]: int_data, scale, zero_point = self.get_plain() - assert step == 1, "Only step == 1 is supported in slicing right now" data_len = int_data.shape[dim] scale_len = scale.shape[dim] ratio = data_len / scale_len @@ -352,14 +359,16 @@ def __torch_dispatch__(cls, func, types, args, kwargs): end_scale = int(end / ratio) int_data = aten.slice.Tensor(int_data, dim, start, end, step) - # this is to handle padding - int_data = self._layout.post_process(int_data) scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) zero_point = aten.slice.Tensor( zero_point, dim, start_scale, end_scale, step ) + # this is to handle padding + int_data, scale, zero_point = self._layout.post_process( + int_data, scale, zero_point, self.block_size + ) sliced = self.from_plain(int_data, scale, zero_point, self._layout) - return sliced + return return_and_correct_aliasing(func, args, kwargs, sliced) else: raise NotImplementedError( f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" @@ -371,6 +380,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs): __torch_function__ = torch._C._disabled_torch_function_impl + @property + def block_size(self): + from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros + + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) + cur_shape = self.shape + assert len(cur_shape) == 4 + inner_k_tiles = cur_shape[-1] * 2 + original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) + groupsize = int(original_shape[1] / scale.shape[-2]) + return (1, groupsize) + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: from torchao.quantization.quant_primitives import ( ZeroPointDomain, diff --git a/torchao/dtypes/uintx/uintx_layout.py b/torchao/dtypes/uintx/uintx_layout.py index a783e62a44..f41164ca08 100644 --- a/torchao/dtypes/uintx/uintx_layout.py +++ b/torchao/dtypes/uintx/uintx_layout.py @@ -228,8 +228,14 @@ class UintxLayout(Layout): dtype: torch.dtype pack_dim: int = -1 - def post_process(self, input: torch.Tensor) -> torch.Tensor: - return to_uintx(input, self.dtype, self.pack_dim) + def post_process( + self, + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + block_size: Tuple[int, ...], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return to_uintx(input, self.dtype, self.pack_dim), scale, zero_point @register_layout(UintxLayout) diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index beec9f0bb7..a07188a18d 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -44,8 +44,14 @@ class Layout: def pre_process(self, input: torch.Tensor) -> torch.Tensor: return input - def post_process(self, input: torch.Tensor) -> torch.Tensor: - return input + def post_process( + self, + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + block_size: Tuple[int, ...], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return input, scale, zero_point def pre_process_static( self, diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index d778f5b950..da6512468a 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -90,6 +90,20 @@ def wrapper(*args, **kwargs): return decorator +def skip_if_no_cuda(): + import unittest + + def decorator(test_func): + def wrapper(*args, **kwargs): + if not torch.cuda.is_available(): + raise unittest.SkipTest("No cuda available") + return test_func(*args, **kwargs) + + return wrapper + + return decorator + + # copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389 def copy_tests(my_cls, other_cls, suffix, test_failures=None, xfail_prop=None): # noqa: B902 for name, value in my_cls.__dict__.items(): From ae5fa0e972d1ee53cd271846e785f9e60c6fda00 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Mon, 7 Apr 2025 13:38:06 -0400 Subject: [PATCH 27/30] Fix Int4WeightEmbeddingQATQuantizer.convert path (#2024) **Summary:** Fixes the issue where `Int4WeightEmbeddingQATQuantizer`'s convert path assigned the scales and zero points to the wrong attributes ("scales" and "zeros" instead of "scale" and "zero point"), and also ensures the precisions are correctly set. **Test Plan:** python test/quantization/test_qat.py -k test_qat_4w_embedding --- test/quantization/test_qat.py | 28 ++++++++++++++++++++++++++- torchao/quantization/qat/embedding.py | 4 ++-- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index fcd4969bbf..4e267b7124 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -788,17 +788,43 @@ def test_composable_qat_quantizer(self): not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" ) def test_qat_4w_embedding(self): + from torchao._executorch_ops import ( + _quantized_decomposed_quantize_per_channel_group_wrapper, + ) from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer + group_size = 256 model = M2() x = model.example_inputs() model(*x) - quantizer = Int4WeightOnlyEmbeddingQATQuantizer() + quantizer = Int4WeightOnlyEmbeddingQATQuantizer(group_size) prepared = quantizer.prepare(model) + prepared_embedding_weight = copy.deepcopy(prepared.embedding.weight) prepared(*x) converted = quantizer.convert(model) converted(*x) + # Assert the scales, zero points, and weights are correct after convert + qmin, qmax = -8, 7 + (s, zp) = get_group_qparams_symmetric( + prepared_embedding_weight, + 4, + group_size, + ) + zp = zp.to(torch.int32) + q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( + prepared_embedding_weight, + s, + zp, + qmin, + qmax, + torch.int8, + group_size, + ) + torch.testing.assert_close(converted.embedding.weight, q_weight) + torch.testing.assert_close(converted.embedding.scale, s) + torch.testing.assert_close(converted.embedding.zero_point, zp) + def test_fake_quantize_config_granularity(self): """ Test initialization and property setting of `FakeQuantizeConfig`'s granularity. diff --git a/torchao/quantization/qat/embedding.py b/torchao/quantization/qat/embedding.py index cc63c5181d..42e9b08eed 100644 --- a/torchao/quantization/qat/embedding.py +++ b/torchao/quantization/qat/embedding.py @@ -245,8 +245,8 @@ def _convert_helper(self, module: torch.nn.Module): group_size, ) quantized_embedding.weight = q_weight - quantized_embedding.scales = s - quantized_embedding.zeros = zp + quantized_embedding.scale = s.to(scale_precision) + quantized_embedding.zero_point = zp.to(zero_point_precision) else: self._convert_helper(child) From 061fae49687efff23e4844d89952b2c3c9ae5bcf Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 8 Apr 2025 11:06:32 -0700 Subject: [PATCH 28/30] Add gguf q4_k quantization (#2001) * Add gguf q4_k_s quantization Summary: Didn't implement the algorithm to choose_qparams from gguf, since it's complicated, e.g. https://github.com/ggml-org/llama.cpp/blob/f423981ac806bf031d83784bcb47d2721bc70f97/ggml/src/ggml-quants.c#L744 and https://github.com/ggml-org/llama.cpp/blob/f423981ac806bf031d83784bcb47d2721bc70f97/ggml/src/ggml-quants.c#L827C14-L827C28 but implemented a simple choose_qparams that can fit the gguf format: Q4_K: w = q * block_scale(6-bit) + block_min(6-bit) Test Plan: python test/prototype/test_gguf_quant.py Reviewers: Subscribers: Tasks: Tags: * fix * test with phi4 * pre-commit run * update * run precommit * format --- test/prototype/test_gguf_quant.py | 59 ++++ torchao/core/config.py | 6 +- torchao/prototype/quantization/__init__.py | 5 + .../prototype/quantization/gguf/__init__.py | 9 + torchao/prototype/quantization/gguf/api.py | 52 ++++ .../gguf/gguf_quantized_tensor.py | 272 ++++++++++++++++++ torchao/quantization/quant_primitives.py | 213 ++++++++++++++ 7 files changed, 615 insertions(+), 1 deletion(-) create mode 100644 test/prototype/test_gguf_quant.py create mode 100644 torchao/prototype/quantization/gguf/__init__.py create mode 100644 torchao/prototype/quantization/gguf/api.py create mode 100644 torchao/prototype/quantization/gguf/gguf_quantized_tensor.py diff --git a/test/prototype/test_gguf_quant.py b/test/prototype/test_gguf_quant.py new file mode 100644 index 0000000000..b68d84b101 --- /dev/null +++ b/test/prototype/test_gguf_quant.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch + +from torchao.prototype.quantization.gguf import ( + GGUFQuantizedTensor, + GGUFWeightOnlyConfig, +) +from torchao.quantization import quantize_ +from torchao.quantization.quant_primitives import choose_qparams_gguf +from torchao.quantization.utils import compute_error + + +class TestGGUFQuantization(unittest.TestCase): + def setUp(self): + torch.manual_seed(123) + self.input = torch.randn(2, 256, dtype=torch.float32) + self.n_blocks_per_superblock = 8 + self.block_size = (1, 32) + self.dtype = torch.uint4 + + def test_choose_qparams_gguf(self): + ( + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + ) = choose_qparams_gguf(self.input, self.block_size, self.dtype) + + assert super_block_scale_scale.shape, (2, 8) + assert super_block_min_scale.shape, (2, 8) + assert quantized_block_scale.shape, (2, 32) + + def test_gguf_quantized_tensor_from_float(self): + gqt = GGUFQuantizedTensor.from_float( + self.input, + self.n_blocks_per_superblock, + self.dtype, + ) + + dequant = gqt.dequantize() + + sqnr = compute_error(dequant, self.input) + self.assertGreater(sqnr, 30) + + def test_quantize_api(self): + m = torch.nn.Sequential(torch.nn.Linear(256, 64)) + quantize_(m, GGUFWeightOnlyConfig()) + assert type(m[0].weight) == GGUFQuantizedTensor + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/core/config.py b/torchao/core/config.py index 4a5a4c5720..fe03ac225b 100644 --- a/torchao/core/config.py +++ b/torchao/core/config.py @@ -171,7 +171,11 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]: return json.loads(json.dumps(config, cls=ConfigJSONEncoder)) -ALLOWED_AO_MODULES = {"torchao.quantization", "torchao.sparsity.sparse_api"} +ALLOWED_AO_MODULES = { + "torchao.quantization", + "torchao.sparsity.sparse_api", + "torchao.prototype.quantization", +} def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig: diff --git a/torchao/prototype/quantization/__init__.py b/torchao/prototype/quantization/__init__.py index e69de29bb2..bf49e2717b 100644 --- a/torchao/prototype/quantization/__init__.py +++ b/torchao/prototype/quantization/__init__.py @@ -0,0 +1,5 @@ +from .gguf import GGUFWeightOnlyConfig + +__all__ = [ + "GGUFWeightOnlyConfig", +] diff --git a/torchao/prototype/quantization/gguf/__init__.py b/torchao/prototype/quantization/gguf/__init__.py new file mode 100644 index 0000000000..3e43e1f3dc --- /dev/null +++ b/torchao/prototype/quantization/gguf/__init__.py @@ -0,0 +1,9 @@ +from .api import GGUFWeightOnlyConfig +from .gguf_quantized_tensor import ( + GGUFQuantizedTensor, +) + +__all__ = [ + "GGUFQuantizedTensor", + "GGUFWeightOnlyConfig", +] diff --git a/torchao/prototype/quantization/gguf/api.py b/torchao/prototype/quantization/gguf/api.py new file mode 100644 index 0000000000..bc4b46992a --- /dev/null +++ b/torchao/prototype/quantization/gguf/api.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + +import torch + +from torchao.core.config import AOBaseConfig +from torchao.quantization.transform_module import register_quantize_module_handler + +from .gguf_quantized_tensor import GGUFQuantizedTensor + +__all__ = [ + "GGUFWeightOnlyConfig", +] + + +@dataclass +class GGUFWeightOnlyConfig(AOBaseConfig): + dtype: torch.dtype = torch.uint4 + n_blocks_per_superblock: int = 8 + + +@register_quantize_module_handler(GGUFWeightOnlyConfig) +def _gguf_weight_only_transform( + module: torch.nn.Module, + config: GGUFWeightOnlyConfig, +): + """ + Applies gguf weight-only quantization to linear layers. + + Args: + dtype: torch.uint1 to torch.uint8, torch.int32 supported. + n_blocks_per_superblock: the number of super blocks in a 256 element block for gguf, e.g. when it is 8 + it means we have blocks of 32 and 8 blocks in a superblock of 256 elements. + Returns: + Callable for quantization transformation. + """ + weight = module.weight + if (weight.ndim != 2) or (weight.shape[-1] % 256 != 0): + return module + + quantized_weight = GGUFQuantizedTensor.from_float( + weight, + n_blocks_per_superblock=config.n_blocks_per_superblock, + target_dtype=config.dtype, + ) + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + return module diff --git a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py new file mode 100644 index 0000000000..0bb7b9a623 --- /dev/null +++ b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.quantization.quant_primitives import ( + choose_qparams_gguf, + dequantize_gguf, + quantize_gguf, +) +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + TorchAOBaseTensor, +) + +_QK_K = 256 +aten = torch.ops.aten + +__all__ = [ + "GGUFQuantizedTensor", +] + + +class GGUFQuantizedTensor(TorchAOBaseTensor): + """ + A Tensor subclass that when applied to a weight used in a linear op/module, + changes that linear op to a weight-only int4 quantized linear op with groupwise + affine quantization on the weight. + """ + + @staticmethod + def __new__( + cls, + n_blocks_per_superblock, + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + int_data, + shape, + **kwargs, + ): + kwargs["device"] = kwargs.get("device", super_block_scale_scale.device) + kwargs["dtype"] = kwargs.get("dtype", super_block_scale_scale.dtype) + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + n_blocks_per_superblock, + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + int_data, + shape, + **kwargs, + ): + self.n_blocks_per_superblock = n_blocks_per_superblock + self.super_block_scale_scale = super_block_scale_scale + self.super_block_min_scale = super_block_min_scale + self.quantized_block_scale = quantized_block_scale + self.quantized_block_min = quantized_block_min + self.int_data = int_data + + def _apply_fn_to_data(self, fn): + return self.__class__( + self.n_blocks_per_superblock, + fn(self.super_block_scale_scale), + fn(self.super_block_min_sclae), + fn(self.quantized_block_scale), + fn(self.quantized_block_min), + fn(self.int_data), + self.shape, + dtype=self.dtype, + ) + + def __tensor_flatten__(self): + return [ + "super_block_scale_scale", + "super_block_min_scale", + "quantized_block_scale", + "quantized_block_min", + "int_data", + ], ( + self.n_blocks_per_superblock, + self.dtype, + self.shape, + ) + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, attributes, outer_size=None, outer_stride=None + ): + ( + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + int_data, + ) = ( + tensor_data_dict["super_block_scale_scale"], + tensor_data_dict["super_block_min_scale"], + tensor_data_dict["quantized_block_scale"], + tensor_data_dict["quantized_block_min"], + tensor_data_dict["int_data"], + ) + n_blocks_per_superblock, dtype, shape = attributes + return cls( + n_blocks_per_superblock, + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + int_data, + shape if outer_size is None else outer_size, + dtype=dtype, + ) + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if output_dtype is None: + output_dtype = self.dtype + + block_size = tuple( + [1] * (self.int_data.ndim - 1) + [_QK_K // self.n_blocks_per_superblock] + ) + return dequantize_gguf( + self.int_data, + block_size, + self.dtype, + self.super_block_scale_scale, + self.super_block_min_scale, + self.quantized_block_scale, + self.quantized_block_min, + output_dtype=output_dtype, + ) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs.pop("device") + return self.__class__( + self.n_blocks_per_superblock, + self.super_block_scale_scale.to(device), + self.super_block_min_scale.to(device), + self.quantized_block_scale.to(device), + self.quantized_block_min.to(device), + self.int_data.to(device), + self.shape, + **kwargs, + ) + + def _apply_fn_to_data(self, fn): + """ + Returns a new `CodebookQuantizedTensor`. + """ + return self.__class__( + self.n_blocks_per_superblock, + fn(self.super_block_scale_scale), + fn(self.super_block_min_scale), + fn(self.quantized_block_scale), + fn(self.quantized_block_min), + fn(self.int_data), + self.shape, + dtype=self.dtype, + ) + + def requires_grad_(self, requires_grad=False): + """ + Modifies the tensor's `requires_grad` status in-place. + """ + assert not requires_grad, "Only requires_grad == False is supported" + return self + + @classmethod + def from_float(cls, input_float, n_blocks_per_superblock, target_dtype): + """ + Method used to convert a linear weight tensor to an instance of the + GGMLInt4LinearWeight subclass. + + Example usage:: + + model.lin_mod.weight = ( + GGMLInt4LinearWeight.from_float(model.lin_mod.weight) + ) + """ + assert ( + target_dtype == torch.uint4 + ), "only uint4 quantization is supported right now" + block_size = (1, _QK_K // n_blocks_per_superblock) + ( + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + ) = choose_qparams_gguf(input_float, block_size, target_dtype) + + int_data = quantize_gguf( + input_float, + block_size, + target_dtype, + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + ) + return cls( + n_blocks_per_superblock, + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + int_data, + input_float.shape, + ) + + +implements = GGUFQuantizedTensor.implements + + +@implements([aten.detach.default, aten.alias.default]) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + +@implements(aten._to_copy.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + ) + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + if not input_tensor.is_floating_point(): + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) + + dtype = input_tensor.dtype + + if hasattr(weight_tensor, "dequantize"): + weight_tensor = weight_tensor.dequantize(output_dtype=dtype) + + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + + +if TORCH_VERSION_AT_LEAST_2_5: + # Allow a model with GGUFQuantizedTensor weights to be loaded with `weights_only=True` + torch.serialization.add_safe_globals([GGUFQuantizedTensor]) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 05be8c5c30..bc176c9d17 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -42,6 +42,9 @@ "choose_qparams_affine_float8", "quantize_affine_float8", "dequantize_affine_float8", + "choose_qparams_gguf", + "quantize_gguf", + "dequantize_gguf", ] @@ -195,6 +198,8 @@ class TorchAODType(Enum): _DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_INT_BOUNDS) assert _DTYPE_TO_BIT_WIDTH.keys() == _DTYPE_TO_QVALUE_BOUNDS.keys() +_GGUF_QK_K = 256 + _ONES_TABLE = [_n_ones(i) for i in range(8)] quant_lib = torch.library.Library("quant", "FRAGMENT") @@ -1039,6 +1044,214 @@ def reshape_w(w): return q_w, s_group, s_channel, w_ref +def choose_qparams_gguf( + input: Optional[torch.Tensor], + block_size: List[int], + target_dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + There are two sets of qparams: quantized_block_scale, quantized_block_min and super_block_scale_scale and super_block_min_scale + the relationship is the following: + block_scale = quantized_block_scale * super_block_sclae + block_min = quantized_block_min * super_block_min + quantized_val = (float_val - block_min) / block_scale + quant_min + first we calculate block_scale and block_min + then we calculate super_block_scale_scale and super_block_min_scale + after that we can calculate quantized_block_scale and quantized_min_scale + the returned values are: super_block_scale_scale, super_block_min_scale, quantized_block_scale + and quantized_min_scale + """ + dtype = input.dtype + + # 1. get block_scale block_min + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + input = input.view(shape_for_reduction) + min_val = torch.amin(input, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input, dim=reduction_dims, keepdim=False) + quant_max = 15 + quant_min = 0 + # asymmetric quant to fully utilize the range + block_scale = max_val / (float(quant_max - quant_min) / 2) + block_scale = (max_val - min_val) / float(quant_max - quant_min) + block_min = min_val + + # 2. get super_block_scale_scale and super_block_min_scale + assert _GGUF_QK_K % block_size[-1] == 0 + super_block_size = (1, _GGUF_QK_K // block_size[-1]) + shape_for_reduction, reduction_dims = _get_reduction_params( + super_block_size, block_scale.size() + ) + block_scale = block_scale.view(shape_for_reduction) + block_min = block_min.view(shape_for_reduction) + + shape_after_reduction = shape_for_reduction.copy() + for i in reduction_dims: + shape_after_reduction[i] = 1 + + block_scale_absmax = torch.amax( + torch.abs(block_scale), dim=reduction_dims, keepdim=False + ) + block_min_absmax = torch.amax( + torch.abs(block_min), dim=reduction_dims, keepdim=False + ) + + # 2. get super_block_scale_scale and super_block_min_scale + # TODO: make this configurable + # we also quantize the quantization parameters (scale and min) for each block to 6 bit + # for Q4_K + qparam_quant_max = 2**6 - 1 + qparam_quant_min = 0 + super_block_scale_scale = block_scale_absmax / float( + qparam_quant_max - qparam_quant_min + ) + super_block_min_scale = block_min_absmax / float( + qparam_quant_max - qparam_quant_min + ) + super_block_scale_scale_view = super_block_scale_scale.view(shape_after_reduction) + super_block_min_scale_view = super_block_min_scale.view(shape_after_reduction) + + # 3. quantize block scale and min are stored in 6 bits using super_block_scale_scale and super_block_min_scale + quantized_block_scale = torch.clamp( + block_scale / super_block_scale_scale_view, qparam_quant_min, qparam_quant_max + ) + quantized_block_min = torch.clamp( + block_min / super_block_min_scale_view, qparam_quant_min, qparam_quant_max + ) + return ( + super_block_scale_scale.to(dtype), + super_block_min_scale.to(dtype), + quantized_block_scale.to(dtype), + quantized_block_min.to(dtype), + ) + + +def quantize_gguf( + input: torch.Tensor, + block_size: List[int], + target_dtype: torch.dtype, + super_block_scale_scale: torch.Tensor, + super_block_min_scale: torch.Tensor, + quantized_block_scale: torch.Tensor, + quantized_block_min: torch.Tensor, +) -> torch.Tensor: + assert target_dtype == torch.uint4 + + # step 1: first order quantization + # just going through shape calculation for block_scale and block_min to get the correct shape + input_shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + block_qparam_shape_after_reduction = input_shape_for_reduction.copy() + for i in reduction_dims: + block_qparam_shape_after_reduction[i] = 1 + original_shape = input.shape + input = input.view(input_shape_for_reduction) + quantized_block_scale = quantized_block_scale.view( + block_qparam_shape_after_reduction + ) + quantized_block_min = quantized_block_min.view(block_qparam_shape_after_reduction) + + # step 2: second order quantization, recover unquantized block_scale and block_min + super_block_size = (1, _GGUF_QK_K // block_size[-1], 1) + super_block_input_shape_for_reduction, reduction_dims = _get_reduction_params( + super_block_size, quantized_block_scale.size() + ) + super_block_qparam_shape_after_reduction = ( + super_block_input_shape_for_reduction.copy() + ) + for i in reduction_dims: + super_block_qparam_shape_after_reduction[i] = 1 + + quantized_block_scale = quantized_block_scale.view( + super_block_input_shape_for_reduction + ) + quantized_block_min = quantized_block_min.view( + super_block_input_shape_for_reduction + ) + super_block_scale_scale = super_block_scale_scale.view( + super_block_qparam_shape_after_reduction + ) + super_block_min_scale = super_block_min_scale.view( + super_block_qparam_shape_after_reduction + ) + + block_scale = super_block_scale_scale * quantized_block_scale + block_min = super_block_min_scale * quantized_block_min + + # step 3: quantization with the unquantized block_scale and block_min + block_scale = block_scale.view(block_qparam_shape_after_reduction) + block_min = block_min.view(block_qparam_shape_after_reduction) + int_data = (input - block_min) / block_scale + int_data = int_data.view(original_shape) + + return int_data + + +def dequantize_gguf( + input: torch.Tensor, + block_size: List[int], + target_dtype: torch.dtype, + super_block_scale_scale: torch.Tensor, + super_block_min_scale: torch.Tensor, + quantized_block_scale: torch.Tensor, + quantized_block_min: torch.Tensor, + output_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + # step 1. reshape input and quantized block scale and min to the shape + # after first quantization + input_shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + block_qparam_shape_after_reduction = input_shape_for_reduction.copy() + for i in reduction_dims: + block_qparam_shape_after_reduction[i] = 1 + + original_shape = input.shape + input = input.view(input_shape_for_reduction) + quantized_block_scale = quantized_block_scale.view( + block_qparam_shape_after_reduction + ) + quantized_block_min = quantized_block_min.view(block_qparam_shape_after_reduction) + + # step 2. calculate and reshape block_qparams for second quantization step + super_block_size = (1, _GGUF_QK_K // block_size[-1], 1) + super_block_input_shape_for_reduction, reduction_dims = _get_reduction_params( + super_block_size, quantized_block_scale.size() + ) + super_block_qparam_shape_after_reduction = ( + super_block_input_shape_for_reduction.copy() + ) + for i in reduction_dims: + super_block_qparam_shape_after_reduction[i] = 1 + quantized_block_scale = quantized_block_scale.view( + super_block_input_shape_for_reduction + ) + quantized_block_min = quantized_block_min.view( + super_block_input_shape_for_reduction + ) + super_block_scale_scale = super_block_scale_scale.view( + super_block_qparam_shape_after_reduction + ) + super_block_min_scale = super_block_min_scale.view( + super_block_qparam_shape_after_reduction + ) + + block_scale = super_block_scale_scale * quantized_block_scale + block_min = super_block_min_scale * quantized_block_min + + # step 3. dequantize with block_scale and block_min + block_scale = block_scale.view(block_qparam_shape_after_reduction) + block_min = block_min.view(block_qparam_shape_after_reduction) + dequant = input * block_scale + block_min + dequant = dequant.view(original_shape) + if output_dtype is not None: + dequant = dequant.to(output_dtype) + + return dequant + + def dequantize_affine_qqq( w: torch.Tensor, s_group: torch.Tensor, From 4b8a0d8c3d351a156d99b90c5aec86758bd01a8c Mon Sep 17 00:00:00 2001 From: gmagogsfm Date: Tue, 8 Apr 2025 13:03:59 -0700 Subject: [PATCH 29/30] torch/ao Differential Revision: D72435827 Pull Request resolved: https://github.com/pytorch/ao/pull/2029 --- test/integration/test_integration.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 38f8f0341c..9bbf625fc4 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -2029,7 +2029,9 @@ def forward(self, x): # we can re-enable this after non-functional IR is enabled in export # model = torch.export.export(model, example_inputs).module() if TORCH_VERSION_AT_LEAST_2_5: - model = torch.export.export_for_training(model, example_inputs).module() + model = torch.export.export_for_training( + model, example_inputs, strict=True + ).module() else: model = torch._export.capture_pre_autograd_graph(model, example_inputs) after_export = model(x) From da111e41de106ca84d3868365946b26910b7f124 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 8 Apr 2025 13:41:02 -0700 Subject: [PATCH 30/30] Adds Q/DQ layout support for embedding quantization with IntxWeightOnlyConfig (#1972) * up * up * up * up * up * up * up * up --- torchao/dtypes/affine_quantized_tensor_ops.py | 9 ++ torchao/dtypes/uintx/q_dq_layout.py | 13 +++ torchao/experimental/quant_api.py | 87 ++++++++------- torchao/experimental/quant_passes.py | 98 +++++++++++++++++ .../tests/test_embedding_xbit_quantizer.py | 11 +- .../experimental/tests/test_quant_passes.py | 70 +++++++++++- torchao/quantization/quant_api.py | 101 ++++++++++++++++++ 7 files changed, 337 insertions(+), 52 deletions(-) diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index f5c2b0bf63..b4912523bb 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -64,6 +64,12 @@ _linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl, ) +from torchao.dtypes.uintx.q_dq_layout import ( + _embedding_check as _embedding_q_dq_check, +) +from torchao.dtypes.uintx.q_dq_layout import ( + _embedding_impl as _embedding_q_dq_impl, +) from torchao.dtypes.uintx.q_dq_layout import ( _linear_check as _linear_q_dq_check, ) @@ -263,6 +269,9 @@ def _(func, types, args, kwargs): @implements(torch.nn.functional.embedding) def _(func, types, args, kwargs): + if _embedding_q_dq_check(args, kwargs): + return _embedding_q_dq_impl(args, kwargs) + # new_arg1 = args[1].dequantize() # return torch.nn.embedding(args[0], new_arg1, *args[2:], **kwargs) assert isinstance( diff --git a/torchao/dtypes/uintx/q_dq_layout.py b/torchao/dtypes/uintx/q_dq_layout.py index d0a58c2e18..1d5b2048b0 100644 --- a/torchao/dtypes/uintx/q_dq_layout.py +++ b/torchao/dtypes/uintx/q_dq_layout.py @@ -50,3 +50,16 @@ def _linear_impl(input_tensor, weight_tensor, bias): if isinstance(weight_tensor, AffineQuantizedTensor): weight_tensor = weight_tensor.dequantize() return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + + +def _embedding_check(args, kwargs): + _, weight_tensor = args + layout = weight_tensor.tensor_impl.get_layout() + return isinstance(layout, QDQLayout) + + +def _embedding_impl(args, kwargs): + input_tensor, weight_tensor = args + if isinstance(weight_tensor, AffineQuantizedTensor): + weight_tensor = weight_tensor.dequantize() + return torch.nn.functional.embedding(input_tensor, weight_tensor, **kwargs) diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index 8091042738..e45a8d2bef 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -15,7 +15,7 @@ quantize_per_channel_group, ) -from torchao.quantization.granularity import PerGroup, PerRow +from torchao.quantization.granularity import Granularity, PerAxis, PerGroup, PerRow from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 logger = logging.getLogger(__name__) @@ -366,32 +366,44 @@ def __init__( ): super().__init__() self.bit_width = bit_width - self.pack_weights_op = getattr( - torch.ops.torchao, f"_pack_embedding_{bit_width}bit" - ) - self.embedding_op = getattr(torch.ops.torchao, f"_embedding_{bit_width}bit") def quantize_and_pack_weights(self, weights, group_size, has_weight_zeros): assert has_weight_zeros, "has_weight_zeros must be True for QuantizedEmbedding" num_embeddings, embedding_dim = weights.shape - if group_size == -1: - group_size = embedding_dim - self.group_size = group_size - weight_qvals, weight_scales, weight_zeros = _quantize( - weights, self.group_size, self.bit_width, has_weight_zeros=True + embedding = torch.nn.Embedding(num_embeddings, embedding_dim) + embedding.weight = weights + quantize_( + embedding, + IntxWeightOnlyConfig( + weight_dtype=getattr(torch, f"int{self.bit_width}"), + granularity=PerGroup(group_size) if group_size > 0 else PerAxis(0), + zero_point_domain=ZeroPointDomain.INT + if has_weight_zeros + else ZeroPointDomain.NONE, + mapping_type=MappingType.ASYMMETRIC, + ), + lambda m, fqn: isinstance(m, torch.nn.Embedding), + ) + weight_qvals, weight_scales, weight_zeros = ( + embedding.weight.tensor_impl.get_plain() ) + weight_scales = weight_scales.reshape(num_embeddings, -1) + weight_zeros = weight_zeros.reshape(num_embeddings, -1).to(torch.int8) self.register_buffer( - "packed_weight_qvals", self.pack_weights_op(weight_qvals.to(torch.int8)) + "packed_weight_qvals", + getattr(torch.ops.torchao, f"_pack_embedding_{self.bit_width}bit")( + weight_qvals.to(torch.int8) + ), ) self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim self.register_buffer("weight_scales", weight_scales) - self.register_buffer("weight_zeros", weight_zeros.to(torch.int8)) + self.register_buffer("weight_zeros", weight_zeros) def forward(self, x): shape = x.shape - return self.embedding_op( + return getattr(torch.ops.torchao, f"_embedding_{self.bit_width}bit")( self.packed_weight_qvals, self.num_embeddings, self.embedding_dim, @@ -410,38 +422,23 @@ def __init__( self.bit_width = bit_width def quantize_and_pack_weights(self, weights, group_size, has_weight_zeros): - assert ( - has_weight_zeros - ), "has_weight_zeros must be True for QuantizedEmbeddingFallback" - num_embeddings, embedding_dim = weights.shape - if group_size == -1: - group_size = embedding_dim - self.group_size = group_size - - weight_qvals, weight_scales, weight_zeros = _quantize( - weights, self.group_size, self.bit_width, has_weight_zeros=True + self.embedding = torch.nn.Embedding(*weights.shape) + self.embedding.weight = weights + quantize_( + self.embedding, + IntxWeightOnlyConfig( + weight_dtype=getattr(torch, f"int{self.bit_width}"), + granularity=PerGroup(group_size) if group_size > 0 else PerAxis(0), + zero_point_domain=ZeroPointDomain.INT + if has_weight_zeros + else ZeroPointDomain.NONE, + mapping_type=MappingType.ASYMMETRIC, + ), + lambda m, fqn: isinstance(m, torch.nn.Embedding), ) - self.weight_qvals = weight_qvals.to(torch.int32) - self.weight_scales = weight_scales - self.weight_zeros = weight_zeros.to(torch.int32) def forward(self, x): - shape = x.shape - res = [] - for i in x: - res.append( - dequantize_per_channel_group( - w_int8=self.weight_qvals[i, :].reshape(1, -1), - scales=self.weight_scales[i, :].reshape(1, -1), - zero_points=self.weight_zeros[i, :].reshape(1, -1), - quant_min=None, # TODO: why is this an arg for this function - quant_max=None, # TODO: why is this an arg for this function - dtype=None, # TODO: why is this an arg for this function - group_size=self.group_size, - output_dtype=torch.float32, - ).reshape(-1) - ) - return torch.stack(res).reshape(*shape, -1) + return self.embedding(x) class QuantizedSharedEmbedding(nn.Module): @@ -586,7 +583,7 @@ class EmbeddingQuantizer: def __init__( self, weight_dtype: torch.dtype = torch.int4, - granularity: Union[PerRow, PerGroup] = PerRow(), + granularity: Granularity = PerAxis(0), has_weight_zeros: bool = True, use_fallback: bool = False, ): @@ -594,7 +591,8 @@ def __init__( if isinstance(granularity, PerGroup): group_size = granularity.group_size - elif isinstance(granularity, PerRow): + elif isinstance(granularity, PerAxis): + assert granularity.axis == 0 group_size = -1 else: raise ValueError(f"Unsupported granularity: {granularity}") @@ -630,6 +628,7 @@ def quantize(self, model: nn.Module) -> nn.Module: to_linear_activation_quantized, ) from torchao.quantization.quant_api import ( + IntxWeightOnlyConfig, MappingType, ZeroPointDomain, to_affine_quantized_intx, diff --git a/torchao/experimental/quant_passes.py b/torchao/experimental/quant_passes.py index 9a744643c8..1b25dc1371 100644 --- a/torchao/experimental/quant_passes.py +++ b/torchao/experimental/quant_passes.py @@ -215,3 +215,101 @@ def replace_q_dq_patterns_with_quantized_linear_ops_pass( # Re-export return torch.export.export(gm, *ep.example_inputs) + + +def _get_q_dq_embedding_patterns_replacements_and_filters( + weight_bit_width, +): + w_quant_min = -(1 << (weight_bit_width - 1)) + w_quant_max = (1 << (weight_bit_width - 1)) - 1 + w_target_dtype = torch.int8 + + def pattern( + indices, + w_int_data, + w_block_size, + w_scale, + w_zero_point, + ): + dq_w = torch.ops.quant.dequantize_affine.default( + w_int_data, + w_block_size, + w_scale, + w_zero_point, + w_target_dtype, + w_quant_min, + w_quant_max, + ) + return torch.ops.aten.embedding.default(dq_w, indices) + + def replacement( + indices, + w_int_data, + w_block_size, + w_scale, + w_zero_point, + ): + num_embeddings, embedding_dim = w_int_data.size() + packed_weight_qvals = getattr( + torch.ops.torchao, f"_pack_embedding_{weight_bit_width}bit" + )(w_int_data) + out_shape = indices.shape + (embedding_dim,) + group_size = w_block_size[-1] + n_groups = embedding_dim // group_size + w_scale = w_scale.reshape(-1, n_groups) + w_zero_point = w_zero_point.reshape(-1, n_groups) + return getattr(torch.ops.torchao, f"_embedding_{weight_bit_width}bit")( + packed_weight_qvals, + num_embeddings, + embedding_dim, + w_scale, + w_zero_point, + indices.reshape(-1), + ).reshape(out_shape) + + def match_filter(match, x, y): + def get_val(name): + node = [n for n in match.nodes_map if n.name == name][0] + return match.nodes_map[node] + + # We only want w_block_size with shape [1, group_size] + w_block_size = get_val("w_block_size") + if len(w_block_size) != 2 or w_block_size[0] != 1: + return False + + return True + + return pattern, replacement, match_filter + + +def replace_q_dq_patterns_with_quantized_embedding_ops_pass( + ep: torch.export.ExportedProgram, +) -> torch.export.ExportedProgram: + """ + This replaces Q/DQ patterns with torchao quantized embedding ops. + It is intended for converting Q/DQ nodes exported with QDQLayout to using + the lowbit quantized embedding ops. + """ + # TODO: figure out how to do this with dynamic_shapes (not saved on EP for easy re-export) + # See https://fb.workplace.com/groups/1028545332188949/permalink/1185289956514485/ + assert ( + len(ep.range_constraints) == 0 + ), "ExportedProgram with range constraints are not supported" + + # ep.module() unlifts the weight inputs, which we need for constant folding + gm = ep.module() + for weight_bit_width in range(1, 9): + pattern, replacement, match_filter = ( + _get_q_dq_embedding_patterns_replacements_and_filters( + weight_bit_width, + ) + ) + subgraph_rewriter.replace_pattern_with_filters( + gm, pattern, replacement, match_filters=[match_filter] + ) + + # Constant fold evaluates and removes the packing ops + constant_fold(gm) + + # Re-export + return torch.export.export(gm, *ep.example_inputs) diff --git a/torchao/experimental/tests/test_embedding_xbit_quantizer.py b/torchao/experimental/tests/test_embedding_xbit_quantizer.py index 844c96760f..8f4afcda04 100644 --- a/torchao/experimental/tests/test_embedding_xbit_quantizer.py +++ b/torchao/experimental/tests/test_embedding_xbit_quantizer.py @@ -19,7 +19,7 @@ Int8DynamicActivationIntxWeightConfig, SharedEmbeddingQuantizer, ) -from torchao.quantization.granularity import PerGroup, PerRow +from torchao.quantization.granularity import PerAxis, PerGroup, PerRow from torchao.quantization.quant_api import quantize_ @@ -68,7 +68,7 @@ def test_accuracy(self): def test_export_compile_aoti(self): weight_dtype = torch.int4 - granularity = PerRow() + granularity = PerAxis(0) embedding_dim = 4096 num_embeddings = 131 model = torch.nn.Sequential( @@ -113,7 +113,6 @@ def test_export_compile_aoti(self): def test_shared_embedding(self): weight_dtype = torch.int4 - granularity = PerRow() has_weight_zeros = True embedding_dim = 4096 num_embeddings = 131 @@ -134,14 +133,14 @@ def test_shared_embedding(self): quantized_model_reference = copy.deepcopy(model) EmbeddingQuantizer( weight_dtype=weight_dtype, - granularity=granularity, + granularity=PerAxis(0), has_weight_zeros=has_weight_zeros, ).quantize(quantized_model_reference) quantize_( quantized_model_reference, Int8DynamicActivationIntxWeightConfig( weight_dtype=weight_dtype, - granularity=granularity, + granularity=PerRow(), has_weight_zeros=has_weight_zeros, round_weight_scale_to_bf16=False, layout=PackedLinearInt8DynamicActivationIntxWeightLayout( @@ -155,7 +154,7 @@ def test_shared_embedding(self): quantized_model = copy.deepcopy(model) SharedEmbeddingQuantizer( weight_dtype=weight_dtype, - granularity=granularity, + granularity=PerRow(), has_weight_zeros=has_weight_zeros, ).quantize(quantized_model) diff --git a/torchao/experimental/tests/test_quant_passes.py b/torchao/experimental/tests/test_quant_passes.py index 3262e2bf7b..35282f331f 100644 --- a/torchao/experimental/tests/test_quant_passes.py +++ b/torchao/experimental/tests/test_quant_passes.py @@ -7,6 +7,7 @@ import unittest import torch +from parameterized import param, parameterized from torch.testing import FileCheck from torchao.experimental.q_dq_layout import QDQLayout @@ -14,10 +15,16 @@ Int8DynamicActivationIntxWeightConfig, ) from torchao.experimental.quant_passes import ( + replace_q_dq_patterns_with_quantized_embedding_ops_pass, replace_q_dq_patterns_with_quantized_linear_ops_pass, ) -from torchao.quantization.granularity import PerGroup, PerRow -from torchao.quantization.quant_api import quantize_ +from torchao.quantization.granularity import PerAxis, PerGroup, PerRow +from torchao.quantization.quant_api import ( + IntxWeightOnlyConfig, + MappingType, + ZeroPointDomain, + quantize_, +) class TestQuantPasses(unittest.TestCase): @@ -77,6 +84,65 @@ def test_replace_q_dq_patterns_with_quantized_linear_ops_pass(self): exported_results = exported.module()(activations) self.assertTrue(torch.allclose(exported_results, eager_results)) + @parameterized.expand( + [ + param(weight_dtype=weight_dtype, granularity=granularity) + for weight_dtype in [getattr(torch, f"int{i}") for i in range(1, 9)] + for granularity in [PerAxis(0), PerGroup(32)] + ], + name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}", + ) + def test_replace_q_dq_patterns_with_quantized_embedding_ops_pass( + self, weight_dtype, granularity + ): + # Calling torch.export many times in a parametrized test causes + # torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached error + # Setting cache_size_limit to a large number to avoid this error + torch._dynamo.config.cache_size_limit = 10000 + + mapping_type = MappingType.ASYMMETRIC + zero_point_domain = ZeroPointDomain.INT + + model = torch.nn.Sequential( + *[torch.nn.Embedding(5000, 512), torch.nn.Linear(512, 512)] + ) + indices = torch.randint(0, 5000, (4, 5, 17), dtype=torch.int32) + + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=weight_dtype, + granularity=granularity, + zero_point_domain=zero_point_domain, + mapping_type=mapping_type, + layout=QDQLayout(), + ), + lambda m, fqn: isinstance(m, torch.nn.Embedding), + ) + eager_results = model(indices) + + exported = torch.export.export(model, (indices,), strict=True) + exported = replace_q_dq_patterns_with_quantized_embedding_ops_pass(exported) + + # We should not find pack op because it gets constant folded + FileCheck().check_not("torch.ops.torchao._pack_embedding").run( + exported.graph_module.code + ) + + # We should find + FileCheck().check_count( + "torch.ops.torchao._embedding", count=1, exactly=True + ).run(exported.graph_module.code) + + # We should not find Q/DQ ops + FileCheck().check_not("torch.ops.quant.dequantize_affine.default").run( + exported.graph_module.code + ) + + # Numerics should match + exported_results = exported.module()(indices) + self.assertTrue(torch.allclose(exported_results, eager_results)) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 9bbdd3dfbf..f325a587b4 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -36,6 +36,7 @@ MarlinQQQLayout, MarlinSparseLayout, PlainLayout, + QDQLayout, SemiSparseLayout, TensorCoreTiledLayout, UintxLayout, @@ -75,6 +76,9 @@ Int8DynActInt4WeightQuantizer, ) from .granularity import ( + Granularity, + PerAxis, + PerGroup, PerRow, PerTensor, ) @@ -86,6 +90,7 @@ intx_quantization_aware_training, ) from .quant_primitives import ( + _DTYPE_TO_QVALUE_BOUNDS, MappingType, ZeroPointDomain, ) @@ -1569,6 +1574,102 @@ def _uintx_weight_only_transform( return module +@dataclass +class IntxWeightOnlyConfig(AOBaseConfig): + """ + Configuration for quantizing weights to torch.intx, with 1 <= x <= 8. + Weights are quantized with scales and optionally zeros (controlled by zero_point_domain) in a groupwise or channelwise + manner using the number of bits specified by weight_dtype. + args: + weight_dtype: The dtype to use for weight quantization. Must be torch.intx, where 1 <= x <= 8. + torch.intx with x < 8 requires TORCH_VERSION_AT_LEAST_2_6 + granularity: The granularity to use for weight quantization. Must be PerGroup or PerAxis(0). + zero_point_domain: The zero point domain to use for weight quantization. + Must be ZeroPointDomain.INT (if quantized weights have zeros) or ZeroPointDomain.NONE (if quantized weights do not have zeros). + mapping_type: The type of mapping to use for the weight quantization. + Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC. + scale_dtype: The dtype to use for the weight scale. + layout: The layout to use for the packed weight tensor: + - QDQLayout: this layout is designed for export to ExecuTorch.this layout represents the quantization with Q/DQ quant primitives, + and is intended for export applications like ExecuTorch. + """ + + weight_dtype: torch.dtype = torch.int8 + granularity: Granularity = PerAxis(0) + zero_point_domain: ZeroPointDomain = ZeroPointDomain.NONE + mapping_type: MappingType = MappingType.SYMMETRIC + scale_dtype: Optional[torch.dtype] = None + layout: Layout = QDQLayout() + + def __post_init__(self): + assert TORCH_VERSION_AT_LEAST_2_6, "IntxWeightOnlyConfig requires torch 2.6+" + assert ( + self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)] + ), f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}" + assert isinstance( + self.granularity, (PerAxis, PerGroup) + ), f"granularity must be PerAxis or PerGroup, but got {self.granularity}" + if isinstance(self.granularity, PerAxis): + assert ( + self.granularity.axis == 0 + ), f"axis must be 0 with PerAxis, but got {self.granularity.axis}" + assert ( + self.zero_point_domain in [ZeroPointDomain.INT, ZeroPointDomain.NONE] + ), f"zero_point_domain must be ZeroPointDomain.INT or ZeroPointDomain.NONE, but got {self.zero_point_domain}" + assert ( + self.mapping_type in [MappingType.ASYMMETRIC, MappingType.SYMMETRIC] + ), f"mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.mapping_type}" + if self.mapping_type == MappingType.SYMMETRIC: + assert ( + self.zero_point_domain == ZeroPointDomain.NONE + ), f"zero_point_domain must be ZeroPointDomain.NONE when mapping_type is MappingType.SYMMETRIC, but got {self.zero_point_domain}" + + +@register_quantize_module_handler(IntxWeightOnlyConfig) +def _intx_weight_only_transform( + module: torch.nn.Module, config: IntxWeightOnlyConfig +) -> torch.nn.Module: + weight = module.weight + weight_dtype = config.weight_dtype + granularity = config.granularity + zero_point_domain = config.zero_point_domain + mapping_type = config.mapping_type + scale_dtype = config.scale_dtype + layout = config.layout + + assert ( + weight.dim() == 2 + ), f"IntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}" + if isinstance(granularity, PerGroup): + group_size = granularity.group_size + elif isinstance(granularity, PerAxis): + assert ( + granularity.axis == 0 + ), f"axis must be 0 with PerAxis, but got {granularity.axis}" + group_size = weight.shape[-1] + else: + raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}") + + quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype] + has_weight_zeros = zero_point_domain == ZeroPointDomain.INT + weight = to_affine_quantized_intx( + input_float=weight, + mapping_type=mapping_type, + block_size=(1, group_size), + target_dtype=torch.int8, + quant_min=quant_min, + quant_max=quant_max, + eps=torch.finfo(torch.float32).eps, + scale_dtype=scale_dtype, + zero_point_dtype=torch.int8 if has_weight_zeros else None, + preserve_zero=has_weight_zeros or (mapping_type == MappingType.SYMMETRIC), + zero_point_domain=zero_point_domain, + _layout=layout, + ) + module.weight = torch.nn.Parameter(weight, requires_grad=False) + return module + + @dataclass class FPXWeightOnlyConfig(AOBaseConfig): """Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits