@@ -1774,6 +1774,29 @@ def reference_sigmoid(x):
         return (1 / (1 + np.exp(-x)))
     return scipy.special.expit(x)
 
+def reference_lgamma(x):
+    # scipy.special.gammaln returns `-inf` when the input is `-inf`,
+    # while PyTorch, C, and C++ all return `inf` when the input is `-inf`.
+    # Reference:
+    # https://en.cppreference.com/w/cpp/numeric/math/lgamma
+    # https://en.cppreference.com/w/c/numeric/math/lgamma
+
+    # To handle the above discrepancy,
+    # we replace -inf with inf so that values
+    # which were originally -inf map to inf as expected.
+    if x.dtype.kind == 'f':
+        x = np.where(x == float('-inf'), np.array(float('inf'), dtype=x.dtype), x)
+
+    out = scipy.special.gammaln(x)
+
+    if x.dtype == np.float16:
+        # `scipy.special.gammaln` returns a float32 output for a float16 input,
+        # while `torch.lgamma` preserves `float16`. Due to the smaller range of float16,
+        # the PyTorch version outputs `inf` where SciPy returns finite values.
+        out = out.astype(np.float16)
+
+    return out
+
 op_db_scipy_reference: List[OpInfo] = [
     UnaryUfuncInfo('sigmoid',
                    ref=reference_sigmoid,
@@ -1851,6 +1874,27 @@ def reference_sigmoid(x):
                                 dtypes=[torch.bfloat16]),
                    )
                    ),
+    UnaryUfuncInfo('lgamma',
+                   ref=reference_lgamma,
+                   decorators=(precisionOverride({torch.float16: 7e-1}),),
+                   dtypes=all_types_and(torch.bool),
+                   dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and(torch.bool, torch.half),
+                   skips=(
+                       # Reference: https://github.com/pytorch/pytorch/pull/50140#discussion_r552615345
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                dtypes=[torch.bfloat16]),
+                       # Reference: https://github.com/pytorch/pytorch/pull/50140#issuecomment-756150214
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS),
+                       # Backward of `lgamma` uses `digamma`, but `digamma`
+                       # is not implemented for `BFloat16`.
+                       # Error raised:
+                       # RuntimeError: "digamma" not implemented for 'BFloat16'
+                       SkipInfo('TestCommon', 'test_variant_consistency_jit',
+                                dtypes=[torch.bfloat16]),
+                   ),
+                   safe_casts_outputs=True),
     OpInfo('xlogy',
            dtypes=all_types_and(torch.bool),
            dtypesIfCPU=all_types_and(torch.bool, torch.half, torch.bfloat16),
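For context, here is a minimal sketch (not part of the diff; it assumes NumPy, SciPy, and PyTorch are importable) reproducing the two discrepancies that `reference_lgamma` works around:

    import numpy as np
    import scipy.special
    import torch

    # -inf handling: SciPy propagates -inf, while torch.lgamma follows the
    # C/C++ lgamma convention and returns +inf (per the comments in the diff).
    print(scipy.special.gammaln(float('-inf')))       # -inf
    print(torch.lgamma(torch.tensor(float('-inf'))))  # tensor(inf)

    # dtype handling: SciPy upcasts a float16 input to float32, so a large
    # argument stays finite; casting the result back to float16 overflows to
    # inf, matching what torch.lgamma produces when it computes in float16.
    out = scipy.special.gammaln(np.float16(10000.0))
    print(out.dtype)               # float32
    print(out.astype(np.float16))  # inf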