9080 expose mat mul precision #9081

Merged
merged 22 commits into from May 9, 2025
100 changes: 100 additions & 0 deletions test/test_mat_mul_precision.py
@@ -0,0 +1,100 @@
"""Numeric tests for default precision of mat mul."""

import unittest

import torch
import torch_xla
import torch_xla.backends

import test_utils


class TestMatMulPrecision(unittest.TestCase):

def _make_input(self):
eye = torch.eye(1024, device='cpu', dtype=torch.float64)
rand_ = torch.testing.make_tensor((1024, 1024),
dtype=torch.float64,
device="cpu",
low=0.99,
high=1.01)
return eye * rand_

# TODO: Figure out why either PT/XLA or unittest
# is unable to successfully run this test in a parameterized way.
# https://github.com/pytorch/xla/issues/9129
@unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.')
@unittest.expectedFailure
def test_all(self):
    # The number of bits of mantissa precision expected in the result.
parameters = [
('highest', 22),
('high', 14),
('default', 8),
]
    # Although pytest has a slightly more elegant parameterized testing function,
    # all TPU tests use unittest.
for i, (precision, bits) in enumerate(parameters):
with self.subTest(precision=precision, bits=bits):
self._test_parameterized(precision, bits)

@unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.')
def test_highest(self):
self._test_parameterized('highest', 22)

@unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.')
def test_high(self):
self._test_parameterized('high', 14)

@unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.')
def test_default(self):
self._test_parameterized('default', 8)

# DO NOT add epsilons to this test. These tests must be numerically exact.
def _test_parameterized(self, precision, bits):
# Arrange
torch_xla.backends.set_mat_mul_precision(precision)

# Diagonal matrices force mat mul through MXU
# but require only one non-zero accumulation.
x = self._make_input()
y = self._make_input()
reference_float64 = torch.matmul(x, y)

    # TODO: Justify this logic. Why is it not
    # 1 - ((2**8 - 1) / 2**8)**2 (the equation stated by a TPU expert)?
widest_atol = torch.tensor(
-1 + ((2**(bits) + 1) / 2**bits)**2, dtype=torch.float64)

narrowest_atol = widest_atol / 4.0

x = x.to(torch.float32).to('xla')
y = y.to(torch.float32).to('xla')

# Act
actual = torch.matmul(x, y).to('cpu').to(torch.float64)

# Disable rtol, we know exactly the atol for default, high, and highest.
torch.testing.assert_close(
actual,
reference_float64,
rtol=0.0,
atol=widest_atol,
)

with self.assertRaises(AssertionError):
torch.testing.assert_close(
actual,
reference_float64,
rtol=0.0,
atol=narrowest_atol,
)

assert not torch.equal(actual, reference_float64), (
"Actual product and reference product should not be closer than equal, "
f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}"
)


# There is no main function. This is designed to be run from
# python -m unittest ...
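
Side note for reviewers (not part of the diff): the `widest_atol` bound above comes from squaring a factor accurate to `bits` bits of mantissa. A minimal sketch, assuming only the standard library, of how that bound behaves for each precision level:

# Sketch (illustrative only): the tolerance used in _test_parameterized.
# widest_atol = ((2**bits + 1) / 2**bits)**2 - 1, which is approximately
# 2**(1 - bits) for the bit counts exercised by the tests.
for precision, bits in [("highest", 22), ("high", 14), ("default", 8)]:
  widest_atol = ((2**bits + 1) / 2**bits)**2 - 1
  print(f"{precision}: atol = {widest_atol:.3e}, ~2**(1 - bits) = {2.0**(1 - bits):.3e}")

Running this prints bounds near 4.8e-7, 1.2e-4, and 7.8e-3, which is why the test can also assert that a quarter of the bound is too tight.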
62 changes: 62 additions & 0 deletions test/test_mat_mul_precision_get_and_set.py
@@ -0,0 +1,62 @@
"""Tests for get/set_mat_mul_precision from init_python_bindings.cpp"""

import sys
import unittest

import torch
import torch_xla
import torch_xla.backends


class TestMatMulPrecisionGetAndSet(unittest.TestCase):

def setUp(self):
self._original = torch_xla.backends.get_mat_mul_precision()
torch.set_printoptions(precision=20)
torch_xla.sync()

def tearDown(self):
torch_xla.backends.set_mat_mul_precision(self._original)
torch.set_printoptions(profile="default")
torch_xla.sync()

def test_set_mat_mul_precision_error(self):
# Assert
with self.assertRaises(ValueError):
# Act
torch_xla.backends.set_mat_mul_precision('BAD VALUE')

def test_get_and_set_mat_mul_precision_default(self):
# Arrange
torch_xla.backends.set_mat_mul_precision('default')

# Act
status = torch_xla.backends.get_mat_mul_precision()

# Assert
self.assertEqual(status, 'default')

def test_get_and_set_mat_mul_precision_high(self):
# Arrange
torch_xla.backends.set_mat_mul_precision('high')

# Act
status = torch_xla.backends.get_mat_mul_precision()

# Assert
self.assertEqual(status, 'high')

def test_get_and_set_mat_mul_precision_highest(self):
# Arrange
torch_xla.backends.set_mat_mul_precision('highest')

# Act
status = torch_xla.backends.get_mat_mul_precision()

# Assert
self.assertEqual(status, 'highest')


if __name__ == '__main__':
  test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)
6 changes: 6 additions & 0 deletions test/test_utils.py
@@ -13,6 +13,7 @@
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.utils.utils as xu
import torch_xla.runtime as xr


def _set_rng_seed(seed):
@@ -420,3 +421,8 @@ def temporary_env(**kwargs):
else:
# Restore the original value
os.environ[key] = old_value


# Taken from test_operations.py
def is_on_tpu():
return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU'
5 changes: 5 additions & 0 deletions test/tpu/run_tests.sh
@@ -6,6 +6,11 @@ TEST_CDIR="$(dirname "$CDIR")"
source "${TEST_CDIR}/utils/run_tests_utils.sh"

# TODO: merge with other run_tests
(cd $TEST_CDIR && python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_high)
(cd $TEST_CDIR && python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_default)
(cd $TEST_CDIR && python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_highest)
(cd $TEST_CDIR && python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_all)
python3 "$TEST_CDIR/test_mat_mul_precision_get_and_set.py"
python3 "$TEST_CDIR/test_operations.py" -v
python3 "$TEST_CDIR/pjrt/test_runtime_tpu.py"
python3 "$TEST_CDIR/pjrt/test_collective_ops_tpu.py"
78 changes: 78 additions & 0 deletions torch_xla/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""torch_xla.backends controls the behavior of the XLA backend.

This subpackage parallels the torch.backends.{cuda, cpu, mps, etc}
subpackages in PyTorch.
"""

# See https://github.com/pytorch/pytorch/blob/main/torch/backends/mps/__init__.py
# for an example of how backends are implemented in PyTorch
# in the __init__.py file, despite general style guidelines against this.

# Literal is available from Python 3.8,
# matching the Python versions for PyTorch and PyTorch/XLA.
from typing import Final, Literal, TypeAlias

import torch_xla

__all__ = ["set_mat_mul_precision", "get_mat_mul_precision"]

# Valid values for get_mat_mul_precision/set_mat_mul_precision
# Note: it is idiomatic in PyTorch to use strings rather than enums.
# See https://github.com/pytorch/pytorch/blob/v2.7.0/torch/backends/cpu/__init__.py#L9

_DEFAULT: Final = "default"
_HIGH: Final = "high"
_HIGHEST: Final = "highest"

# Using variables with a Final type hint instead of literals is valid here.
_PrecisionType: TypeAlias = Literal[
_DEFAULT, _HIGH, _HIGHEST] # pyright: ignore[reportInvalidTypeForm]


# Some of this description is adapted from the JAX documentation.
# TODO: Once the numerics tutorial is released, link from this docstring.
def set_mat_mul_precision(precision: _PrecisionType) -> None:
"""Control the default matmul and conv precision for 32bit inputs.

Some platforms, like TPU, offer configurable precision levels for
matrix multiplication and convolution computations,
trading off accuracy for speed.

This option controls the default precision level for
computations involved in matrix multiplication and convolution on
32bit inputs. The levels describe the precision at
which scalar products are computed.

On a TPU:
* `default` is the fastest and least precise,
downcasting an FP32 to BF16 before multiplying.

* `high` takes three passes and generates approximately 14 bits of
precision.

* `highest` is the most precise, and the slowest. It takes six
passes and generates approximately 22 bits of precision.

Args:
precision (str): The precision to set for matrix multiplication.
Must be one of 'default', 'high', or 'highest'.
"""
if precision not in [_DEFAULT, _HIGH, _HIGHEST]:
raise ValueError(f"Invalid precision: {precision}. "
f"Must be one of {_DEFAULT}, {_HIGH}, {_HIGHEST}.")

torch_xla._XLAC._xla_set_mat_mul_precision(precision)


def get_mat_mul_precision() -> _PrecisionType:
"""Get the current mat mul precision for 32bit inputs.

Returns:
str: The current precision setting for matrix multiplication,
one of 'default', 'high', or 'highest'.
"""
precision = torch_xla._XLAC._xla_get_mat_mul_precision()
assert precision in [_DEFAULT, _HIGH, _HIGHEST
], (f"Invalid precision: {precision}. "
f"Must be one of {_DEFAULT}, {_HIGH}, {_HIGHEST}.")
return precision
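
For context, a brief usage sketch of the API added in torch_xla/backends/__init__.py above. This is illustrative only, not part of the diff; it mirrors the `.to('xla')` and `torch_xla.sync()` patterns used in the tests and assumes a working PJRT device:

# Illustrative usage of the new backend controls (not part of this PR's diff).
import torch
import torch_xla
import torch_xla.backends

torch_xla.backends.set_mat_mul_precision("high")
assert torch_xla.backends.get_mat_mul_precision() == "high"

# Subsequent float32 matmuls traced on the XLA device use the chosen precision.
x = torch.randn(128, 128).to('xla')
y = torch.randn(128, 128).to('xla')
z = torch.matmul(x, y)
torch_xla.sync()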
4 changes: 4 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -2124,6 +2124,10 @@ void InitXlaModuleBindings(py::module m) {
ConsumeValue(xla::StringToPrecision(mat_mul_precision));
XlaHelpers::set_mat_mul_precision(precision);
});
m.def("_xla_get_mat_mul_precision", []() {
xla::PrecisionConfig::Precision precision = XlaHelpers::mat_mul_precision();
return xla::PrecisionToString(precision);
});

py::class_<xla::XlaBuilder, op_builder::BuilderPtr>(m, "XlaBuilder");
py::class_<op_builder::Op, op_builder::OpPtr>(m, "XlaOp");
4 changes: 4 additions & 0 deletions torch_xla/csrc/xla_op_builder.cpp
@@ -208,6 +208,10 @@ xla::PrecisionConfig DotPrecisonConfig(py::dict args) {
precision = xla::PrecisionConfig::HIGH;
} else if (*arg_precision_config == "highest") {
precision = xla::PrecisionConfig::HIGHEST;
} else {
XLA_ERROR() << "Invalid precision config in args: "
<< *arg_precision_config
<< " (valid values: default, high, highest)";
}
}
return XlaHelpers::BuildPrecisionConfig(precision);