
Commit e544d74

anjali411 authored and facebook-github-bot committed
[CPU] Add torch.trace for complex tensors (pytorch#50380)
Summary: Pull Request resolved: pytorch#50380
Test Plan: Imported from OSS
Reviewed By: ezyang
Differential Revision: D25949361
Pulled By: anjali411
fbshipit-source-id: 9910bc5b532c9bf3add530221d643b2c41c62d01
1 parent 2c3c2a4 commit e544d74

3 files changed: +29 additions, -11 deletions


aten/src/ATen/native/ReduceOps.cpp

Lines changed: 10 additions & 8 deletions
@@ -502,9 +502,12 @@ static Tensor& prod_out_impl(Tensor& result, const Tensor& self, IntArrayRef dim
 // see https://github.com/pytorch/pytorch/pull/47305,
 Tensor trace_cpu(const Tensor& self) {
   Tensor result;
+  // Returns the ScalarType of the self tensor if the tensor is non integral type.
+  // In the case self is an integer type tensor, at::kLong is returned since promote_integers
+  // is set to true.
   ScalarType dtype = get_dtype(result, self, c10::nullopt, true);
   result = at::empty({}, self.options().dtype(dtype));
-  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "trace", [&] {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(self.scalar_type(), "trace", [&] {
     using accscalar_t = at::acc_type<scalar_t, false>;
     accscalar_t sum = 0;
     const auto* t_data = self.data_ptr<scalar_t>();
@@ -521,12 +524,11 @@ Tensor trace_cpu(const Tensor& self) {
       sum += t_data[i * (t_stride_0 + t_stride_1)];
     }

-    // all integer types get promoted to kLong
-    if (result.scalar_type() == at::kLong) {
-      *result.data_ptr<int64_t>() = sum;
-    } else {
-      *result.data_ptr<scalar_t>() = sum;
-    }
+    c10::guts::if_constexpr<std::is_integral<accscalar_t>::value>(
+        // all integer types get promoted to kLong
+        [&] (auto _) { *result.data_ptr<int64_t>() = _(sum); }, // then-case, invalid for non-integral types
+        [&] (auto _) { *result.data_ptr<scalar_t>() = _(sum); } // else-case, invalid for integral types
+    );
   });

   return result;

@@ -843,7 +845,7 @@ Tensor any(const Tensor& self) {
               "any only supports CPU AND CUDA device type, got: ", self.device().type());
   TORCH_CHECK(self.layout() == Layout::Strided || self.layout() == Layout::Sparse,
               "any only supports strided AND sparse layout, got: ", self.layout());
-
+
   // Refer [all, any : uint8 compatibility]
   Tensor result;
   ScalarType out_dtype;
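
The two changes above are the core of the commit: AT_DISPATCH_ALL_TYPES_AND_COMPLEX extends the kernel to ComplexFloat and ComplexDouble inputs, and c10::guts::if_constexpr replaces the runtime dtype check so that each write-back lambda is only instantiated for accumulator types it compiles for (a complex sum cannot be assigned through an int64_t pointer). A minimal usage sketch of the resulting behavior, not part of this diff (the complex64 dtype assumes PyTorch's default inference for Python complex literals):

import torch

m = torch.tensor([[1 + 2j, 0], [0, 3 - 1j]])  # complex64 on CPU
print(torch.trace(m))          # tensor(4.+1.j) -- sum of the main diagonal

# Integer inputs are still promoted to kLong, as the new comment notes:
i = torch.arange(4, dtype=torch.int32).reshape(2, 2)
print(torch.trace(i).dtype)    # torch.int64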

test/test_torch.py

Lines changed: 0 additions & 1 deletion
@@ -6799,7 +6799,6 @@ def inner(self, device, dtype):
      1e-5, 1e-5, 1e-5, torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False),
     ('topk', 'dim_desc_sort', _small_3d_unique, lambda t, d: [2, 1, True, True],
      1e-5, 1e-5, 1e-5, torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False),
-    ('trace', '', _medium_2d, lambda t, d: [], 1e-3, 1e-5, 1e-5, _types, _cpu_types, False),
     ('tril', '', _medium_2d, lambda t, d: [],),
     ('tril', 'zero_stride', _medium_2d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
     ('tril', 'positive', _medium_2d, lambda t, d: [2], ),

torch/testing/_internal/common_methods_invocations.py

Lines changed: 19 additions & 2 deletions
@@ -14,7 +14,7 @@
 from torch.testing import \
     (make_non_contiguous, _dispatch_dtypes, floating_types, floating_types_and,
      floating_and_complex_types, floating_and_complex_types_and,
-     all_types_and_complex_and, all_types_and)
+     all_types_and_complex_and, all_types_and, all_types_and_complex)
 from torch.testing._internal.common_device_type import \
     (skipCUDAIfNoMagma, skipCPUIfNoLapack, skipCPUIfNoMkl, skipCUDAIfRocm,
      expectedAlertNondeterministic, precisionOverride)

@@ -489,6 +489,11 @@ def sample_inputs_xlogy(self, device, dtype, requires_grad):
                         low=0, high=None,
                         requires_grad=requires_grad))),)

+def sample_inputs_trace(self, device, dtype, requires_grad):
+    return (SampleInput((make_tensor((S, S), device, dtype,
+                                     low=None, high=None,
+                                     requires_grad=requires_grad))),)
+
 def sample_inputs_linalg_inv(op_info, device, dtype, requires_grad=False):
     """
     This function generates always invertible input for torch.linalg.inv using

@@ -1788,6 +1793,19 @@ def reference_sigmoid(x):
            supports_tensor_out=True,
            safe_casts_outputs=True,
            sample_inputs_func=sample_inputs_xlogy),
+    OpInfo('trace',
+           dtypes=all_types_and_complex(),
+           dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half),
+           test_inplace_grad=False,
+           supports_tensor_out=False,
+           # Reference: https://github.com/pytorch/pytorch/issues/50381
+           test_complex_grad=False,
+           sample_inputs_func=sample_inputs_trace,
+           skips=(
+               SkipInfo('TestCommon', 'test_variant_consistency_jit',
+                        dtypes=[torch.complex64, torch.complex128]),
+               SkipInfo('TestCommon', 'test_variant_consistency_eager',
+                        dtypes=[torch.complex64, torch.complex128]))),
 ]
 op_db = op_db + op_db_scipy_reference

@@ -2494,7 +2512,6 @@ def method_tests():
     ('triu', (S, M, M), NO_ARGS, 'batched'),
     ('triu', (S, M, M), (2,), 'batched_idx'),
     ('triu', (3, 3, S, S), NO_ARGS, 'more_batched'),
-    ('trace', (M, M), NO_ARGS),
     ('cross', (S, 3), ((S, 3),)),
     ('cross', (S, 3, S), ((S, 3, S), 1), 'dim'),
     ('index_add', (S, S), (0, index_variable(2, S), (2, S)), 'dim', (), [0]),
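
The OpInfo entry above, together with the removals from method_tests() and test_torch.py, moves trace's test coverage onto the OpInfo machinery: sample_inputs_trace supplies one (S, S) input per dtype, and the entry enables all_types_and_complex() on CPU. A rough, illustrative sketch of the invariant those generic tests exercise (this loop is not the actual TestCommon code):

import torch

for dtype in (torch.int32, torch.float64, torch.complex64, torch.complex128):
    t = torch.ones(3, 3, dtype=dtype)
    # trace is the sum of the main diagonal; integer inputs promote to int64
    # on both sides, so the dtypes match for torch.equal
    assert torch.equal(torch.trace(t), t.diagonal().sum())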
