
Commit 8c48af8

vkuzo authored and facebook-github-bot committed
pytorch docs: add fake_quantize functions documentation (pytorch#51748)
Summary:
Pull Request resolved: pytorch#51748

Adding docs for `fake_quantize_per_tensor_affine` and `fake_quantize_per_channel_affine` functions.

Note: not documenting `fake_quantize_per_tensor_affine_cachemask` and `fake_quantize_per_channel_affine_cachemask`, since they are implementation details of `fake_quantize_per_tensor_affine` and `fake_quantize_per_channel_affine` and do not need to be exposed to the user at the moment.

Test Plan: Built the docs locally on macOS; the rendered pages look good.

Reviewed By: supriyar

Differential Revision: D26270514

Pulled By: vkuzo

fbshipit-source-id: 8e3c9815a12a3427572cb4d34a779e9f5e4facdd
1 parent: ececbcf · commit: 8c48af8

File tree: 2 files changed (+87, -0 lines)

docs/source/torch.rst

Lines changed: 2 additions & 0 deletions
@@ -298,6 +298,8 @@ Pointwise Ops
     exp
     exp2
     expm1
+    fake_quantize_per_channel_affine
+    fake_quantize_per_tensor_affine
     fix
     float_power
     floor

torch/_torch_docs.py

Lines changed: 85 additions & 0 deletions
@@ -9232,6 +9232,91 @@ def merge_dicts(*dicts):
     tensor([ 3.,  0., -0., -0.])
 """.format(**common_args))
 
+add_docstr(torch.fake_quantize_per_tensor_affine,
+           r"""
+fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max) -> Tensor
+
+Returns a new tensor with the data in :attr:`input` fake quantized using :attr:`scale`,
+:attr:`zero_point`, :attr:`quant_min` and :attr:`quant_max`.
+
+.. math::
+    \text{output} = (min(
+        \text{quant\_max},
+        max(
+            \text{quant\_min},
+            \text{std::nearby\_int}(\text{input} / \text{scale}) + \text{zero\_point}
+        )
+    ) - \text{zero\_point}) \times \text{scale}
+
+Args:
+    input (Tensor): the input value(s), in ``torch.float32``.
+    scale (double): quantization scale
+    zero_point (int64): quantization zero_point
+    quant_min (int64): lower bound of the quantized domain
+    quant_max (int64): upper bound of the quantized domain
+
+Returns:
+    Tensor: A newly fake_quantized tensor
+
+Example::
+
+    >>> x = torch.randn(4)
+    >>> x
+    tensor([ 0.0552,  0.9730,  0.3973, -1.0780])
+    >>> torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255)
+    tensor([0.1000, 1.0000, 0.4000, 0.0000])
+""")
+
+add_docstr(torch.fake_quantize_per_channel_affine,
+           r"""
+fake_quantize_per_channel_affine(input, scale, zero_point, axis, quant_min, quant_max) -> Tensor
+
+Returns a new tensor with the data in :attr:`input` fake quantized per channel using :attr:`scale`,
+:attr:`zero_point`, :attr:`quant_min` and :attr:`quant_max`, across the channel specified by :attr:`axis`.
+
+.. math::
+    \text{output} = (min(
+        \text{quant\_max},
+        max(
+            \text{quant\_min},
+            \text{std::nearby\_int}(\text{input} / \text{scale}) + \text{zero\_point}
+        )
+    ) - \text{zero\_point}) \times \text{scale}
+
+Args:
+    input (Tensor): the input value(s), in ``torch.float32``.
+    scale (Tensor): quantization scale, per channel
+    zero_point (Tensor): quantization zero_point, per channel
+    axis (int32): channel axis
+    quant_min (int64): lower bound of the quantized domain
+    quant_max (int64): upper bound of the quantized domain
+
+Returns:
+    Tensor: A newly fake_quantized per channel tensor
+
+Example::
+
+    >>> x = torch.randn(2, 2, 2)
+    >>> x
+    tensor([[[-0.2525, -0.0466],
+             [ 0.3491, -0.2168]],
+
+            [[-0.5906,  1.6258],
+             [ 0.6444, -0.0542]]])
+    >>> scales = (torch.randn(2) + 1) * 0.05
+    >>> scales
+    tensor([0.0475, 0.0486])
+    >>> zero_points = torch.zeros(2).to(torch.long)
+    >>> zero_points
+    tensor([0, 0])
+    >>> torch.fake_quantize_per_channel_affine(x, scales, zero_points, 1, 0, 255)
+    tensor([[[0.0000, 0.0000],
+             [0.3405, 0.0000]],
+
+            [[0.0000, 1.6134],
+             [0.6323, 0.0000]]])
+""")
+
 add_docstr(torch.fix,
            r"""
 fix(input, *, out=None) -> Tensor

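For readers skimming the new documentation, here is a short, hedged usage sketch (not part of the commit above). It calls the two newly documented functions with the argument order shown in their docstring examples and cross-checks the per-tensor result against a small reference helper, `fake_quant_reference`, written from my reading of the docstrings' round/clamp/dequantize behaviour and example output. The helper name and the chosen `quant_min`/`quant_max` values (0 and 255) are illustrative assumptions, not PyTorch API.

import torch

def fake_quant_reference(x, scale, zero_point, quant_min, quant_max):
    # Illustrative reference only (not part of the PR): round onto the quantized
    # grid, clamp to [quant_min, quant_max], then map back to floating point.
    q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    return (q - zero_point) * scale

# Per-tensor: a single scale / zero_point pair for the whole tensor.
x = torch.randn(4)
out = torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255)
print(torch.allclose(out, fake_quant_reference(x, 0.1, 0, 0, 255)))  # expected: True

# Per-channel: one scale / zero_point per slice along `axis` (axis=1 here),
# with zero_points given as int64 as in the docstring example.
x3 = torch.randn(2, 2, 2)
scales = torch.full((2,), 0.05)
zero_points = torch.zeros(2, dtype=torch.long)
out3 = torch.fake_quantize_per_channel_affine(x3, scales, zero_points, 1, 0, 255)
print(out3.shape)  # torch.Size([2, 2, 2]); values snapped to each channel's grid

If the reference ever disagrees on an element, the most likely cause is an input landing exactly halfway between two grid points, where rounding-mode details matter.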