
Commit 6d09809

kshitij12345 authored and facebook-github-bot committed
[numpy] torch.lgamma: promote integer inputs to float (pytorch#50140)
Summary: Reference: pytorch#42515

Pull Request resolved: pytorch#50140
Reviewed By: mrshenli
Differential Revision: D25951094
Pulled By: mruberry
fbshipit-source-id: e53f1dbddff889710f05d43dbc9587382d3decb0
1 parent dd1a97b commit 6d09809
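
The user-visible effect of this change: torch.lgamma now accepts integral (and bool) tensors by promoting them to the default floating-point dtype, instead of hitting a "not implemented" dtype error, in line with the NumPy-compatibility work tracked in pytorch#42515. A minimal sketch of the new behavior (printed values are approximate):

import torch

x = torch.arange(1, 5)           # int64 input: tensor([1, 2, 3, 4])
y = torch.lgamma(x)              # promoted to the default float dtype
print(y.dtype)                   # torch.float32
print(y)                         # approximately [0.0000, 0.0000, 0.6931, 1.7918]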

File tree

7 files changed (+57, -46 lines)


aten/src/ATen/native/UnaryOps.cpp

+3 -32

@@ -650,38 +650,9 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) {
   return self.copy_(args.lgamma_().sum(-1).add_(p * (p - 1) * std::log(c10::pi<double>) / 4.));
 }
 
-// NB: If you use this macro, you may also need to add a CUDA forwarding
-// stub in CUDAUnaryOps
-
-#define IMPLEMENT_UNARY_OP_CORE(op)                                  \
-  Tensor op(const Tensor& self) {                                    \
-    Tensor result = at::empty({0}, self.options());                  \
-    at::op##_out(result, self);                                      \
-    return result;                                                   \
-  }
-
-#define IMPLEMENT_UNARY_OP_OUT_INPLACE(op, prefix, device)           \
-  Tensor& _##op##__##prefix(Tensor& self) {                          \
-    return at::op##_out(self, self);                                 \
-  }                                                                  \
-  Tensor& _##op##_out_##prefix(Tensor& result, const Tensor& self) { \
-    checkDeviceType(#op, result, DeviceType::device);                \
-    checkLayout(#op, result, Layout::Strided);                       \
-    auto iter = TensorIterator::unary_op(result, self);              \
-    op##_stub(iter.device_type(), iter);                             \
-    return result;                                                   \
-  }
-
-#define IMPLEMENT_UNARY_OP_VEC(op) \
-  IMPLEMENT_UNARY_OP_CORE(op)      \
-  IMPLEMENT_UNARY_OP_OUT_INPLACE(op, cpu, CPU)
-
-#define IMPLEMENT_UNARY_OP_VEC_CUDA(op)          \
-  IMPLEMENT_UNARY_OP_CORE(op)                    \
-  IMPLEMENT_UNARY_OP_OUT_INPLACE(op, cpu, CPU)   \
-  IMPLEMENT_UNARY_OP_OUT_INPLACE(op, cuda, CUDA)
-
-IMPLEMENT_UNARY_OP_VEC_CUDA(lgamma)
+Tensor& lgamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, lgamma_stub); }
+Tensor lgamma(const Tensor& self) { return unary_op_impl_float(self, lgamma_stub); }
+Tensor& lgamma_(Tensor& self) { return unary_op_impl_(self, at::lgamma_out); }
 
 DEFINE_DISPATCH(abs_stub);
 DEFINE_DISPATCH(angle_stub);
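
The macro-generated wrappers are replaced with the shared unary_op_impl_float helpers, which build a TensorIterator that promotes integral inputs to a floating compute dtype and safe-casts the result into an explicit out= tensor. A rough sketch of the resulting out= behavior (the exact error wording in the comment is an assumption):

import torch

x = torch.arange(1, 4)                     # int64 input
out = torch.empty(3, dtype=torch.float64)
torch.lgamma(x, out=out)                   # computed in a float dtype, written into the float64 out
print(out.dtype)                           # torch.float64

# An integral out= tensor cannot hold the floating result; this is expected to
# raise a RuntimeError along the lines of "result type ... can't be cast to the
# desired output type".
# torch.lgamma(x, out=torch.empty(3, dtype=torch.int64))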

aten/src/ATen/native/cuda/UnaryGammaKernels.cu

+1 -1

@@ -41,7 +41,7 @@ void polygamma_kernel_cuda(TensorIterator& iter, int64_t n) {
 }
 
 void lgamma_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "lgamma_cuda", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "lgamma_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::lgamma(a);
     });
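
Switching the dispatch key from iter.dtype() to iter.common_dtype() makes the CUDA kernel instantiate for the dtype the loop actually computes in, which after promotion can differ from the output dtype. A small illustration of the promoted CUDA path (requires a CUDA build; skipped otherwise):

import torch

if torch.cuda.is_available():
    x = torch.arange(1, 5, device='cuda')               # int64 input
    print(torch.lgamma(x).dtype)                        # torch.float32: runs in the promoted dtype
    h = torch.ones(3, dtype=torch.half, device='cuda')
    print(torch.lgamma(h).dtype)                        # torch.float16: half is dispatched directly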

aten/src/ATen/native/native_functions.yaml

+6 -8

@@ -5148,12 +5148,6 @@
   dispatch:
     CPU, CUDA: __irshift__
 
-- func: lgamma_(Tensor(a!) self) -> Tensor(a!)
-  variants: method
-  dispatch:
-    CPU: _lgamma__cpu
-    CUDA: _lgamma__cuda
-
 - func: atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)
   variants: method
   dispatch:
@@ -5979,8 +5973,12 @@
 - func: lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   dispatch:
-    CPU: _lgamma_out_cpu
-    CUDA: _lgamma_out_cuda
+    CPU, CUDA: lgamma_out
+
+- func: lgamma_(Tensor(a!) self) -> Tensor(a!)
+  variants: method
+  dispatch:
+    CPU, CUDA: lgamma_
 
 - func: lgamma(Tensor self) -> Tensor
   variants: method, function
test/test_torch.py

-1

@@ -6848,7 +6848,6 @@ def inner(self, device, dtype):
     ('round', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
     ('trunc', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
     ('ceil', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
-    ('lgamma', '', _small_3d, lambda t, d: [], 1e-2, 1e-1, 1e-5, _float_types_no_half, [torch.bfloat16]),
 ]
 
 # Creates and decorates a generic test and adds it to the class.

test/test_unary_ufuncs.py

-2

@@ -1684,8 +1684,6 @@ def _medium_2d(dtype, device):
     _TorchMathTestMeta('frac', reffn='fmod', refargs=lambda x: (x.numpy(), 1)),
     _TorchMathTestMeta('trunc'),
     _TorchMathTestMeta('round'),
-    # FIXME lgamma produces different result compared to scipy at -inf
-    _TorchMathTestMeta('lgamma', reffn='gammaln', ref_backend='scipy', replace_inf_with_nan=True),
     _TorchMathTestMeta('polygamma', args=[0], substr='_0', reffn='polygamma',
                        refargs=lambda x: (0, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
                        ref_backend='scipy'),

torch/csrc/jit/tensorexpr/kernel.cpp

+3 -2

@@ -1350,8 +1350,9 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
     } break;
 
     case aten::lgamma: {
-      return computeOneOperand(
-          "aten_lgamma", v, [](const ExprHandle& a) { return lgamma(a); });
+      return computeOneOperand("aten_lgamma", v, [](const ExprHandle& a) {
+        return lgamma(promoteIntegerToDefaultType(a));
+      });
     } break;
 
     case prim::ConstantChunk: {
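
On the TensorExpr side, wrapping the operand in promoteIntegerToDefaultType keeps compiled graphs consistent with the eager promotion. A minimal sketch (whether NNC actually fuses this tiny graph depends on the active fuser; the output dtype should match eager mode either way):

import torch

@torch.jit.script
def f(x):
    return torch.lgamma(x * 2)

x = torch.arange(1, 5)          # int64
print(f(x).dtype)               # torch.float32, same as eager torch.lgamma(x * 2)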

torch/testing/_internal/common_methods_invocations.py

+44

@@ -1774,6 +1774,29 @@ def reference_sigmoid(x):
         return (1 / (1 + np.exp(-x)))
     return scipy.special.expit(x)
 
+def reference_lgamma(x):
+    # scipy.special.gammaln returns `-inf` when input is `-inf`.
+    # While Pytorch, C and C++, all return `inf` when input is `-inf`.
+    # Reference:
+    # https://en.cppreference.com/w/cpp/numeric/math/lgamma
+    # https://en.cppreference.com/w/c/numeric/math/lgamma
+
+    # To handle the above discrepancy,
+    # we replace -inf with inf so values
+    # that were originally -inf map to inf as expected
+    if x.dtype.kind == 'f':
+        x = np.where(x == float('-inf'), np.array(float('inf'), dtype=x.dtype), x)
+
+    out = scipy.special.gammaln(x)
+
+    if x.dtype == np.float16:
+        # `scipy.special.gammaln` returns output of float32 when input is float16,
+        # while `torch.lgamma` preserves `float16`. But due to smaller range of float16,
+        # Pytorch version outputs `inf` while SciPy returns finite values.
+        out = out.astype(np.float16)
+
+    return out
+
 op_db_scipy_reference: List[OpInfo] = [
     UnaryUfuncInfo('sigmoid',
                    ref=reference_sigmoid,
@@ -1851,6 +1874,27 @@ def reference_sigmoid(x):
                              dtypes=[torch.bfloat16]),
                  )
                  ),
+    UnaryUfuncInfo('lgamma',
+                   ref=reference_lgamma,
+                   decorators=(precisionOverride({torch.float16: 7e-1}),),
+                   dtypes=all_types_and(torch.bool),
+                   dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and(torch.bool, torch.half),
+                   skips=(
+                       # Reference: https://github.com/pytorch/pytorch/pull/50140#discussion_r552615345
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                dtypes=[torch.bfloat16]),
+                       # Reference: https://github.com/pytorch/pytorch/pull/50140#issuecomment-756150214
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS),
+                       # Backward of `lgamma` uses `digamma` but `digamma`
+                       # is not implemented for `BFloat16`
+                       # Error Raised:
+                       # RuntimeError: "digamma" not implemented for 'BFloat16'
+                       SkipInfo('TestCommon', 'test_variant_consistency_jit',
+                                dtypes=[torch.bfloat16]),
+                   ),
+                   safe_casts_outputs=True),
     OpInfo('xlogy',
            dtypes=all_types_and(torch.bool),
            dtypesIfCPU=all_types_and(torch.bool, torch.half, torch.bfloat16),
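
reference_lgamma exists because SciPy and PyTorch disagree at -inf (and on the float16 output dtype). A quick way to see the discrepancy it normalizes, assuming SciPy is installed:

import numpy as np
import scipy.special
import torch

print(scipy.special.gammaln(-np.inf))              # -inf
print(torch.lgamma(torch.tensor(float('-inf'))))   # tensor(inf), matching C/C++ lgamma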
