Skip to content

Commit 6837d1d

Browse files
vfdev-5zou3519
authored andcommitted
[functorch] Added addr op decomposition (pytorch/functorch#323)
Description: - Added addr op decomposition - Updated tests Related to pytorch/functorch#240
1 parent 8478504 commit 6837d1d

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

functorch/functorch/csrc/BatchRulesBinaryOps.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,14 @@ std::tuple<Tensor,optional<int64_t>> masked_select_batch_rule(
155155
return std::make_tuple(result, 0);
156156
}
157157

158+
Tensor addr_decomposition(
159+
const Tensor& self, const Tensor& vec1, const Tensor& vec2,
160+
const Scalar& beta, const Scalar& alpha) {
161+
162+
auto outer = alpha * vec1.unsqueeze(-1) * vec2.unsqueeze(-2);
163+
return self * beta + outer;
164+
}
165+
158166
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
159167
#define BINARY_POINTWISE2(op, overload) \
160168
VMAP_SUPPORT(#op"."#overload, BINARY_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));
@@ -193,6 +201,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
193201
BINARY_SCALAR_2(add, Tensor, Scalar);
194202
POINTWISE_BOXED(addcdiv);
195203
POINTWISE_BOXED(addcmul);
204+
m.impl("addr", addr_decomposition);
196205
BINARY_POINTWISE(atan2);
197206
BINARY_SCALAR_2(bitwise_and, Tensor, Scalar);
198207
BINARY_POINTWISE2(bitwise_or, Tensor);

functorch/test/test_ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,6 @@ def test_vmapjvp(self, device, dtype, op):
572572
@skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', vmapvjp_fail.union({
573573
xfail('view_as_complex'),
574574
xfail('__getitem__'),
575-
xfail('addr'),
576575
xfail('cdist'),
577576
xfail('cholesky'),
578577
xfail('clamp'),

functorch/test/test_vmap.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3099,7 +3099,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
30993099

31003100
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
31013101
@skipOps('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', vmap_fail.union({
3102-
xfail('addr'),
31033102
xfail('cdist'),
31043103
xfail('complex'),
31053104
xfail('copysign'),

0 commit comments

Comments
 (0)