-
Notifications
You must be signed in to change notification settings - Fork 103
/
Copy pathtest_pythonkey.py
677 lines (560 loc) · 22.8 KB
/
test_pythonkey.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import torch.nn as nn
import torch.utils._pytree as pytree
import unittest
import warnings
import itertools
from functools import partial
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from functorch import (
grad, vjp, vmap, jacrev,
make_fx
)
from functorch._src.aot_autograd import aot_module_simplified
from functorch.compile import (
nnc_jit, compiled_function, compiled_module,
min_cut_rematerialization_partition, aot_function, aot_module, decomposition_table, nop,
num_of_recompilations, default_partition, default_decompositions
)
from torch.testing._internal.common_device_type import ops
from functorch_lagging_op_db import functorch_lagging_op_db
from functorch_additional_op_db import additional_op_db
from common_utils import (
xfail,
skip,
skipOps,
)
# Optional third-party packages. Each flag records whether the import
# succeeded; a warning is emitted when it did not.
try:
    import torchvision
except ImportError:
    USE_TORCHVISION = False
    warnings.warn("Couldn't import torchvision. Some of our tests use it, try "
                  "to install it with commands from pytorch.org, post-fixed with "
                  "`--no-deps` to avoid overwriting the pytorch installation",
                  UserWarning)
else:
    USE_TORCHVISION = True

try:
    import networkx  # noqa: F401
except ImportError:
    USE_NETWORKX = False
    warnings.warn("Some tests use networkx but it was not installed",
                  UserWarning)
else:
    USE_NETWORKX = True
# NB: numpy is a testing dependency!
class TestPythonKey(TestCase):
    """Compare ``make_fx``-traced and ``nnc_jit``-compiled callables with eager execution.

    Every test takes a ``device`` argument because the class is passed to
    ``instantiate_device_type_tests`` at the bottom of the file.
    """

    def test_make_fx(self, device):
        def f(x):
            return torch.sin(x)

        inp = torch.randn(3)
        fx_f = make_fx(f)(inp)

        # Compare on an input different from the one used for tracing.
        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_make_fx_grad(self, device):
        def f(x):
            return torch.sin(x).sum()

        inp = torch.randn(3)
        f = grad(f)
        fx_f = make_fx(f)(inp)

        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_scalar_device(self, device):
        def f(a, b):
            return a + b

        # The second operand is built without ``device=``, unlike the first.
        inps = [torch.randn(3, device=device), torch.tensor(5)]
        fx_f = make_fx(f)(*inps)
        self.assertEqual(fx_f(*inps), f(*inps))

    def test_make_fx_vmap(self, device):
        def f(x):
            return torch.sin(x)

        inp = torch.randn(5, 3)
        f = vmap(f)
        fx_f = make_fx(f)(inp)

        new_inp = torch.randn(5, 3)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_make_fx_jacrev(self, device):
        def f(x):
            return x.sin().sum()

        inp = torch.randn(3)
        f = jacrev(jacrev(f))
        fx_f = make_fx(f)(inp)

        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_make_fx_vjp(self, device):
        def f(x):
            return torch.sin(x).sum()

        primals = torch.randn(3)
        _, vjp_fn = vjp(f, primals)

        cotangent = torch.randn(())
        # The two extra positional ``True`` arguments are forwarded to
        # ``vjp_fn`` during tracing and again when calling the traced graph.
        fx_f = make_fx(vjp_fn)(cotangent, True, True)

        new_cotangent = torch.randn(())
        self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))

    def test_make_fx_no_decompose(self, device):
        # FIXME: the test is skipped unconditionally; everything below the
        # return is unreachable until the recursion error is resolved.
        return self.skipTest("error: maximum recursion reached")

        def f(x):
            return torch.tanh(x).sum()

        fx_f = make_fx(grad(f))(torch.randn(5))
        # Local name chosen so it does not shadow the imported ``ops`` decorator.
        targets = {node.target for node in fx_f.graph.nodes}
        self.assertIn(torch.ops.aten.tanh_backward, targets)

        fx_f = make_fx(grad(f), decomposition_table)(torch.randn(5))
        targets = {node.target for node in fx_f.graph.nodes}
        self.assertNotIn(torch.ops.aten.tanh_backward, targets)

    def test_nnc_jit(self, device):
        def f(x):
            return torch.sin(x)

        jit_f = nnc_jit(f)

        inp = torch.randn(3)
        self.assertEqual(jit_f(inp), f(inp))

    def test_nnc_scalar(self, device):
        def f(x):
            return torch.sin(x)

        jit_f = nnc_jit(f)

        # Zero-dimensional input.
        inp = torch.randn(())
        self.assertEqual(jit_f(inp), f(inp))

    def test_nnc_pytrees(self, device):
        def f(x):
            return [torch.sin(x[0])]

        jit_f = nnc_jit(f)

        # Input and output are both lists of tensors.
        inp = [torch.randn(3)]
        self.assertEqual(jit_f(inp), f(inp))

    def test_external_calls(self, device):
        def f(a, b):
            return torch.mv(a, b)

        jit_f = nnc_jit(f)
        inp = [torch.randn(3, 3), torch.randn(3)]
        self.assertEqual(jit_f(*inp), f(*inp))

    def test_nnc_passthrough(self, device):
        # Case 1: one of the outputs is an input returned unchanged.
        def f(x, y):
            return x + y, y

        inp = (torch.randn(3), torch.randn(3))
        jit_f = nnc_jit(f)
        self.assertEqual(jit_f(*inp), f(*inp))

        # Case 2: a dict input is updated under key 'a' and returned; key 'b'
        # is left as it was.
        def f(x):
            x['a'] = x['a'] * 2
            return x

        inp = ({'a': torch.randn(3), 'b': torch.randn(3)},)
        jit_f = nnc_jit(f)
        self.assertEqual(jit_f(*inp), f(*inp))

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_resnet18_backward_trace(self, device):
        mod = torchvision.models.resnet18()

        def f(x):
            out = mod(x)
            out.sum().backward()
            # Clone the gradients: ``zero_grad()`` followed by another
            # ``backward()`` may reuse the very same ``.grad`` tensors in
            # place, which would make the comparison below compare a list
            # with itself.
            return [a.grad.clone() for a in mod.parameters()]

        inp = torch.randn(3, 3, 250, 250, requires_grad=True)
        grads = f(inp)

        mod.zero_grad()
        mod(inp).sum().backward()
        grads2 = [a.grad for a in mod.parameters()]
        self.assertEqual(grads, grads2)
# Entries handed to ``skipOps`` for
# ``TestPythonKeyOperatorsOpInfo.test_make_fx_exhaustive``.
# NOTE(review): ``xfail``/``skip`` come from ``common_utils``; presumably
# ``xfail`` marks an expected failure and ``skip`` disables the op entirely --
# confirm against that module.
make_fx_failures = {
    xfail('allclose'),
    xfail('nn.functional.dropout'),
    xfail('linalg.eigvals'),
    xfail('nn.functional.max_pool1d', device_type='cpu'),  # precision problems?
    xfail('randn_like'),  # randomness
    xfail('rand_like'),  # randomness
    xfail('randint_like'),  # randomness
    skip('new_empty'),  # nondeterministic
    skip('empty_like'),  # nondeterministic
    skip('linalg.lstsq', 'grad_oriented'),  # flaky
    xfail('normal', '', device_type='cpu'),
    xfail('normal', 'number_mean', device_type='cpu'),
    xfail('multinomial', device_type='cpu'),
    xfail('nn.functional.feature_alpha_dropout', 'with_train', device_type='cpu'),
    xfail('bernoulli', device_type='cpu'),
    xfail('nn.functional.dropout2d', device_type='cpu'),
    skip('nn.functional.max_unpool1d', '', device_type='cpu'),  # flaky
    skip('nn.functional.max_unpool2d', '', device_type='cpu'),  # flaky
    skip('nn.functional.max_unpool3d', '', device_type='cpu'),  # flaky
    skip('linalg.lstsq'),  # flaky, probably just a precision issue
    xfail('histogram'),
    xfail('scatter')
}
class TestPythonKeyOperatorsOpInfo(TestCase):
    """Run ``make_fx`` over every operator in the two OpInfo databases."""

    @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestPythonKeyOperatorsOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
    def test_make_fx_exhaustive(self, device, dtype, op):
        def run_op(args, kwargs):
            return op.op(*args, **kwargs)

        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            args = [sample.input, *sample.args]
            kwargs = sample.kwargs

            # Trace with the sample values as generated ...
            traced = make_fx(run_op)(args, kwargs)

            # ... then overwrite every float tensor argument in place, so the
            # comparison below runs on values other than the traced ones.
            for arg in args:
                if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
                    arg.uniform_(0, 1)

            try:
                expected = run_op(args, kwargs)
            except Exception:
                # The eager call failed on the refilled inputs; there is
                # nothing to compare for this sample.
                continue

            self.assertEqual(traced(args, kwargs), expected)
def _outs_and_grads(fn, inps):
outs = fn(*inps)
def get_diff_tensors(tensors):
diff_tensors = []
for tensor in pytree.tree_flatten(tensors)[0]:
if isinstance(tensor, torch.Tensor) and tensor.requires_grad:
diff_tensors.append(tensor)
return diff_tensors
def full_reduce(outs_):
res = 0
for out in outs_:
res=res+out.sum()
return res
diff_inps = get_diff_tensors(inps)
diff_outs = get_diff_tensors(outs)
assert len(diff_outs) > 0
assert len(diff_inps) > 0
grads = torch.autograd.grad(full_reduce(diff_outs), diff_inps, allow_unused=True)
return outs, grads
def _outs_and_grads_and_grad_grads(fn, inps):
outs = fn(*inps)
diff_outs = []
diff_inps = []
for out in pytree.tree_flatten(outs)[0]:
if isinstance(out, torch.Tensor) and out.requires_grad:
diff_outs.append(out)
for inp in pytree.tree_flatten(inps)[0]:
if isinstance(inp, torch.Tensor) and inp.requires_grad:
diff_inps.append(inp)
def full_reduce(outs):
res = 0
for out in outs:
res=res+out.sum()
return res
assert len(diff_outs) > 0
assert len(diff_inps) > 0
grads = torch.autograd.grad(diff_outs, diff_inps, create_graph=True)
diff_grads = []
for grad_ in grads:
if isinstance(grad_, torch.Tensor) and grad_.requires_grad:
diff_grads.append(grad_)
assert len(diff_grads) > 0
grad_grads = torch.autograd.grad(diff_grads, diff_inps)
return outs, grads, grad_grads
class TestAOTAutograd(TestCase):
    """Check ``aot_function``/``aot_module``/``compiled_module`` against eager autograd."""

    def verify_aot_autograd(self, f, inp):
        """Assert that ``f`` and its AOT-compiled version agree on outputs and gradients.

        ``f`` may be a plain callable or an ``nn.Module``; ``inp`` is the list
        of positional arguments. ``nop`` is passed as the compiler.
        """
        if isinstance(f, nn.Module):
            compiled_f = aot_module(f, nop)
        else:
            compiled_f = aot_function(f, nop)
        ref_out, ref_grad = _outs_and_grads(f, inp)
        test_out, test_grad = _outs_and_grads(compiled_f, inp)
        self.assertEqual(ref_out, test_out)
        self.assertEqual(ref_grad, test_grad)

    def verify_aot_autograd_with_double_backward(self, f, inp):
        """Like ``verify_aot_autograd`` but also compares second-order gradients.

        Plain callables are compiled with ``min_cut_rematerialization_partition``
        as the partition function; modules go through ``aot_module`` without one.
        """
        if isinstance(f, nn.Module):
            compiled_f = aot_module(f, nop)
        else:
            compiled_f = aot_function(f, nop, partition_fn=min_cut_rematerialization_partition)
        ref_out, ref_grad, ref_grad_grad = _outs_and_grads_and_grad_grads(f, inp)
        test_out, test_grad, test_grad_grad = _outs_and_grads_and_grad_grads(compiled_f, inp)
        self.assertEqual(ref_out, test_out)
        self.assertEqual(ref_grad, test_grad)
        self.assertEqual(ref_grad_grad, test_grad_grad)

    def test_single_output(self):
        def f(a, b):
            return a + b

        inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)]
        self.verify_aot_autograd(f, inp)

    def test_multi_output(self):
        def f(a, b):
            return a + b, a - b

        inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)]
        self.verify_aot_autograd(f, inp)

    def test_multi_output_list(self):
        def f(a, b):
            return [a + b, a - b]

        inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)]
        self.verify_aot_autograd(f, inp)

    def test_sin_bla(self):
        def f(a):
            return torch.sin(a)

        inp = [torch.tensor(2.3, requires_grad=True)]
        self.verify_aot_autograd_with_double_backward(f, inp)
        # self.verify_aot_autograd(f, inp)

    def test_no_grad_input_output(self):
        def f(a, b):
            return a.cos(), b.cos(), a * b

        # Thunks so that every combination gets freshly created tensors.
        inp_thunks = [lambda: torch.randn(5, requires_grad=True), lambda: torch.randn(5, requires_grad=False)]
        for inps in itertools.product(inp_thunks, repeat=2):
            inps = [i() for i in inps]
            # ignore the case when both inputs don't require grad
            if inps[0].requires_grad or inps[1].requires_grad:
                self.verify_aot_autograd(f, inps)

    # fails
    # def test_inner_grad(self):
    #     def foo(x):
    #         y = torch.exp(x)
    #         z = torch.autograd.grad(y, x, create_graph=True)
    #         return z
    #     inps = [torch.randn((), requires_grad=True)]
    #     self.verify_aot_autograd(foo, inps)

    def test_grad_context(self):
        def foo(x):
            return x * 2

        inps = [torch.randn((), requires_grad=True)]
        graph_size = None

        # NOTE(review): ``graph_size`` is recorded here but never asserted on
        # below; only the recompilation count is checked.
        def assert_graph_empty(fx_g, _):
            nonlocal graph_size
            graph_size = len(fx_g.graph.nodes)
            return fx_g

        start_recompilations = num_of_recompilations()
        f = aot_function(foo, nop, assert_graph_empty)
        with torch.set_grad_enabled(False):
            f(*inps)
        with torch.set_grad_enabled(True):
            f(*inps)
        # One compilation is expected per grad mode.
        self.assertEqual(num_of_recompilations() - start_recompilations, 2)

    def test_output_dict(self):
        # The same tensor returned under two keys.
        def f(x):
            return {'a': x, 'b': x}

        inp = [torch.randn(3, 3, requires_grad=True)]
        self.verify_aot_autograd(f, inp)

        # One passthrough value and one computed value.
        def f(x, y):
            return {'a': x, 'b': y + x}

        inp = [torch.randn(3, requires_grad=True), torch.randn(3)]
        self.verify_aot_autograd(f, inp)

        # Dict input mapped to a new dict output.
        def f(x):
            new_d = {}
            for k in x:
                new_d[k] = x[k] * 2
            return new_d

        inp = [{'a': torch.randn(3, requires_grad=True), 'b': torch.randn(3, requires_grad=True)}]
        self.verify_aot_autograd(f, inp)

    def test_module(self):
        mod = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
        compiled_mod = compiled_module(mod, nop, nop)
        inp = torch.randn(32, 32)

        ref_out = mod(inp)
        ref_out.sum().backward()
        # Both gradient lists are read from ``mod``'s parameters. Snapshot the
        # reference values and reset the gradients; otherwise the second
        # backward accumulates into the same tensors and the comparison can
        # end up comparing a list with itself.
        ref_grads = sorted([(name, p.grad.clone()) for name, p in mod.named_parameters()])
        mod.zero_grad()

        out = compiled_mod(inp)
        out.sum().backward()
        grads = sorted([(name, p.grad) for name, p in mod.named_parameters()])
        self.assertEqual((out, grads), (ref_out, ref_grads))

    def test_batchnorm(self):
        # Smoke test: forward and backward just have to run without raising.
        mod = compiled_module(nn.BatchNorm2d(4), nop, nop)
        x = torch.ones(1, 4, 2, 2)
        mod(x).sum().backward()
class TestEagerFusionOpInfo(TestCase):
    """Compare gradients of ``compiled_function`` with eager mode for every OpInfo op."""

    @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
    # Entries in here don't work and need to be fixed.
    # Each one of these is a bug (or needs to be investigated)
    @skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_exhaustive', {
        xfail('linalg.cholesky'),
        xfail('nn.functional.dropout'),
        xfail('polar'),
        xfail('to_sparse'),
        xfail('addcdiv'),
        xfail('cholesky'),
        xfail('cumulative_trapezoid'),
        xfail('diag_embed'),
        xfail('linalg.householder_product'),
        xfail('logit'),
        xfail('matrix_exp'),
        xfail('trapezoid'),
        xfail('trapz'),
        skip('linalg.svdvals'),
        skip('linalg.eigvals'),
        skip('linalg.det'),  # fails
        skip('linalg.cond'),
        skip('t'),
        skip('ldexp'),
    })
    def test_aot_autograd_exhaustive(self, device, dtype, op):
        def f(args, kwargs):
            return op.op(*args, **kwargs)

        if not op.supports_autograd:
            return

        def is_float_tensor(value):
            return isinstance(value, torch.Tensor) and value.dtype == torch.float

        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True)
        for index, sample_input in enumerate(sample_inputs_itr):
            # NOTE(review): the first sample input is never exercised; this
            # mirrors the previous counter-based loop -- confirm it is intended.
            if index == 0:
                continue

            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs

            # ``skipTest`` raises, so the first unsupported sample ends the
            # whole test as skipped; later samples are not visited.
            if not all(is_float_tensor(arg) for arg in args):
                self.skipTest("not all inputs are float tensors")
            if not all(is_float_tensor(value) for value in kwargs.values()):
                self.skipTest("not all inputs are float tensors")
            if isinstance(f(args, kwargs), tuple):
                self.skipTest("output is a tuple")

            def reset_grads():
                def clear_grad(x):
                    x.grad = None
                pytree.tree_map(clear_grad, args)

            def get_grads(tree):
                return pytree.tree_map(lambda x: x.grad, tree)

            compiled_f = compiled_function(f, nop, nop)

            # Gradients are cleared before each backward so the two runs do
            # not accumulate into each other.
            reset_grads()
            compiled_f(args, kwargs).sum().backward()
            compiled_grad = get_grads(args)

            reset_grads()
            f(args, kwargs).sum().backward()
            orig_grad = get_grads(args)
            self.assertEqual(orig_grad, compiled_grad)

            # def create_new_arg(x):
            #     return x.detach().uniform_(0, 1).requires_grad_(x.requires_grad)
            # args = pytree.tree_map(create_new_arg, args)
            # reset_grads()
            # compiled_f(args, kwargs).sum().backward()
            # compiled_grad = get_grads(args)
            # reset_grads()
            # f(args, kwargs).sum().backward()
            # orig_grad = get_grads(args)
            # self.assertEqual(orig_grad, compiled_grad)
def extract_graph(fx_g, _, graph_cell):
    """Store ``fx_g`` in slot 0 of ``graph_cell`` and return it unchanged.

    The second positional argument is ignored. ``graph_cell`` must already
    have at least one element.
    """
    graph_cell[0] = fx_g
    return fx_g
def get_ins_outs(fx_g):
    """Return ``(placeholders, outputs)`` for the graph held by ``fx_g``.

    ``placeholders`` is the list of nodes whose ``op`` is ``'placeholder'``,
    in graph order. ``outputs`` is ``tuple(node.args[0])`` of the ``'output'``
    node (the last one, if there are several), or an empty list when the graph
    has no output node.
    """
    nodes = list(fx_g.graph.nodes)
    placeholders = [node for node in nodes if node.op == 'placeholder']

    results = []
    for node in nodes:
        if node.op == 'output':
            results = tuple(node.args[0])
    return placeholders, results
def get_num_ins_outs(fx_g):
    """Return ``(number of placeholders, number of outputs)`` of ``fx_g``'s graph."""
    ins, outs = get_ins_outs(fx_g)
    return len(ins), len(outs)
def get_fw_bw_graph(f, inps, partitioner=min_cut_rematerialization_partition):
    """Compile ``f`` with ``aot_function`` and return the captured ``(forward, backward)`` graphs.

    The compiled function is run once on ``inps`` and its summed output is
    backpropagated; ``extract_graph`` records whatever is handed to the
    forward and backward compiler hooks. A slot stays ``None`` if the
    corresponding hook is never called.
    """
    fw_graph_cell = [None]
    bw_graph_cell = [None]
    compiled = aot_function(
        f,
        fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
        bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
        partition_fn=partitioner,
        decompositions=default_decompositions,
    )
    result = compiled(*inps)
    result.sum().backward()
    return fw_graph_cell[0], bw_graph_cell[0]
class TestPartitioning(TestCase):
    """Tests for the forward/backward partitioners used by AOT autograd."""

    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
    def test_recompute_partitioning(self):
        def fn(a, b):
            return torch.sin(torch.sin(a)) + b

        # Reference calculation
        ref_a = torch.rand(10, 10, requires_grad=True)
        ref_b = torch.rand(10, 10, requires_grad=True)
        ref = fn(ref_a, ref_b)
        ref.sum().backward()

        # Compiled function calculation
        res_a = ref_a.clone().detach().requires_grad_(True)
        res_b = ref_b.clone().detach().requires_grad_(True)

        def compile_fn(x, _):
            return x

        compiled_fn = compiled_function(fn, compile_fn, compile_fn, min_cut_rematerialization_partition)
        res = compiled_fn(res_a, res_b)
        res.sum().backward()

        # assertTrue rather than a bare ``assert`` so the checks still run
        # under ``python -O``.
        self.assertTrue(torch.allclose(ref, res, atol=1e-3, rtol=1e-3))
        self.assertTrue(torch.allclose(ref_a.grad, res_a.grad, atol=1e-3, rtol=1e-3))
        self.assertTrue(torch.allclose(ref_b.grad, res_b.grad, atol=1e-3, rtol=1e-3))

    def test_meta_tensor_inplace_op(self):
        # Following module results in inplace ops while tracing. The test checks
        # that the meta tensor information is stored for inplace ops.
        class MockModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.randn(3072, 768, requires_grad=True))
                self.bias = torch.nn.Parameter(torch.randn(3072, requires_grad=True))

            def forward(self, add_4):
                linear_4 = torch.nn.functional.linear(add_4, self.weight, bias=self.bias)
                gelu = torch.nn.functional.gelu(linear_4)
                return gelu

        def check_meta_tensor(fx_g, _):
            # Every node except the output node must carry 'tensor_meta'.
            for node in fx_g.graph.nodes:
                if node.op != 'output':
                    assert 'tensor_meta' in node.meta
            return fx_g

        inp0 = torch.randn(16, 128, 768, requires_grad=True)
        inputs = [inp0, ]
        mod = MockModule().to(device="cpu")
        aot_mod = aot_module(mod, fw_compiler=check_meta_tensor)
        aot_mod(*inputs)

    def test_default_partitioner_getitem(self):
        mod = nn.LayerNorm([10])

        def f(x, mod_weight, mod_bias):
            return torch.nn.functional.layer_norm(x, [10], mod_weight, mod_bias, eps=1e-6)

        fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, 10, requires_grad=True), mod.weight, mod.bias],
                                             partitioner=default_partition)
        # Expected (number of inputs, number of outputs) of each graph.
        self.assertEqual(get_num_ins_outs(fw_graph), (3, 7))
        self.assertEqual(get_num_ins_outs(bw_graph), (12, 6))

    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
    def test_min_cut_partitioner(self):
        def f(x):
            return x.cos().cos().cos()

        fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True)])
        self.assertEqual(get_num_ins_outs(fw_graph), (1, 2))
        self.assertEqual(get_num_ins_outs(bw_graph), (3, 1))

        def f(a, b, c, d):
            x = a + b + c + d
            return x.cos().cos()

        fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True) for _ in range(4)])
        self.assertEqual(get_num_ins_outs(fw_graph), (4, 2))
        self.assertEqual(get_num_ins_outs(bw_graph), (3, 4))

        def f(x):
            return torch.mm(x, torch.ones(x.shape)).tanh().tanh()

        fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(5, 5, requires_grad=True)])
        self.assertEqual(get_num_ins_outs(fw_graph), (1, 2))
        _, outs = get_ins_outs(fw_graph)
        # The second forward output is expected to be the mm node.
        self.assertEqual(outs[1].target, torch.ops.aten.mm)
class TestContiguous(TestCase):
    """Backward through an AOT-compiled view + transpose."""

    def test_contiguous(self):
        # The test simulates the condition where transpose followed by view
        # happens in the backward pass.
        # https://discuss.pytorch.org/t/error-on-transpose-and-view/434
        def view_then_transpose(x):
            return x.view(2, 3).t()

        inp = torch.randn(6, requires_grad=True)
        compiled = aot_function(view_then_transpose, nop)
        out = compiled(inp)

        # Smoke test: the backward call only has to run without raising.
        cotangent = torch.randn(3, 2)
        torch.autograd.grad(out, inp, cotangent)
class TestAOTModuleSimplified(TestCase):
    """Compare ``aot_module_simplified`` with the eager module it wraps."""

    def test_aot_module_simplified(self):
        class MockModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(20, 30)

            def forward(self, x, y):
                # Returns a one-element tuple.
                return (self.linear(x) + y, )

        mod = MockModule()
        mod.zero_grad()

        x = torch.randn(128, 20, requires_grad=True)
        y = torch.randn(128, 30, requires_grad=True)
        inputs = [x, y]
        # Independent leaf copies so each run accumulates its own input grads.
        cloned_inputs = [x.detach().clone().requires_grad_(True) for x in inputs]

        ref = mod(*inputs)
        ref[0].sum().backward()

        aot_mod = aot_module_simplified(mod, nop)
        aot_mod.zero_grad()
        res = aot_mod(*cloned_inputs)
        res[0].sum().backward()

        # assertTrue rather than a bare ``assert`` so the checks still run
        # under ``python -O``; only outputs and input gradients are compared.
        self.assertTrue(torch.allclose(ref[0], res[0]))
        self.assertTrue(torch.allclose(inputs[0].grad, cloned_inputs[0].grad))
        self.assertTrue(torch.allclose(inputs[1].grad, cloned_inputs[1].grad))
# Device types for which the device-generic test classes are instantiated.
# This must be a real tuple: ``("cpu")`` is just the string "cpu", against
# which a membership test is a substring check rather than an exact match.
only_for = ("cpu",)
instantiate_device_type_tests(
    TestPythonKey,
    globals(),
    only_for=only_for,
)
instantiate_device_type_tests(TestPythonKeyOperatorsOpInfo, globals(), only_for=only_for)
instantiate_device_type_tests(TestEagerFusionOpInfo, globals(), only_for=only_for)


if __name__ == '__main__':
    run_tests()