-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_differentiation.chpl
81 lines (59 loc) · 1.95 KB
/
test_differentiation.chpl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
use UnitTest;
use ForwardModeAD;
// Two-element array of multidual values indexed 0..1 (0..#2 is the
// count-based range of length 2 starting at 0). Used below as the argument
// type of the anonymous procs handed to gradient and jacobian.
type D = [0..#2] multidual;
// Same shape with dual elements. Used below as the argument type of the
// anonymous procs handed to directionalDerivative and jvp.
type D2 = [0..#2] dual;
/*
  Differentiates the scalar polynomial x**2 + 2*x + 1 at x = 1 in two ways:
  through the function-taking form of derivative, and by evaluating the
  polynomial on the result of initdual(1) and reading the value/derivative
  parts back out. Both the value and the derivative are expected to be 4.
*/
proc testUnivariateFunctions(test: borrowed Test) throws {
  // Left generic so it can be evaluated on dual arguments.
  proc quadratic(x) {
    return x ** 2 + 2 * x + 1;
  }

  // Function-taking form: pass a dual-typed wrapper and the evaluation point.
  const slopeAtOne = derivative(proc(x: dual) { return quadratic(x); }, 1);
  test.assertEqual(slopeAtOne, 4.0);

  // Direct form: evaluate on initdual(1), then split the result.
  const dualResult = quadratic(initdual(1));
  test.assertEqual(value(dualResult), 4.0);
  test.assertEqual(derivative(dualResult), 4.0);
}
/*
  Exercises gradient in two settings: a constant function, whose gradient
  at [1.0, 2.0] must be all zeros, and x[0]**2 + 3*x[0]*x[1] evaluated on
  initdual([1, 2]), where the value is 7 and the gradient is [8, 3].
*/
proc testGradient(test: borrowed Test) throws {
  // Ignores its argument entirely, so every partial derivative is zero.
  proc constantTwo(x) {
    return 2.0;
  }
  const zeroGrad = gradient(proc(x: D) { return constantTwo(x); },
                            [1.0, 2.0]);
  test.assertEqual(zeroGrad, [0.0, 0.0]);

  // Two-variable scalar function; partials are 2*x0 + 3*x1 and 3*x0.
  proc scalarField(x) {
    return x[0] ** 2 + 3 * x[0] * x[1];
  }
  const fieldAtPoint = scalarField(initdual([1, 2]));
  test.assertEqual(value(fieldAtPoint), 7);
  test.assertEqual(gradient(fieldAtPoint), [8.0, 3.0]);
}
/*
  Checks jacobian on a two-component vector function evaluated on
  initdual([1.0, 2.0]) (expected value [4, 7], expected Jacobian
  [[2, 1], [3, 5]]) and on a constant three-component function, whose
  3x2 Jacobian must be all zeros.
*/
proc testJacobian(test: borrowed Test) throws {
  proc vectorField(x) {
    return [x[0] ** 2 + x[1] + 1, x[0] + x[1] ** 2 + x[0] * x[1]];
  }

  const fieldAtPoint = vectorField(initdual([1.0, 2.0]));
  // Rows follow the output components, columns the input variables.
  var expectedJac: [0..1, 0..1] real = ((2.0, 1.0), (3.0, 5.0));
  test.assertEqual(value(fieldAtPoint), [4.0, 7.0]);
  test.assertEqual(jacobian(fieldAtPoint), expectedJac);

  // Output does not depend on x, so the Jacobian is identically zero.
  proc constantField(x) {
    return [1, 2, 3];
  }
  const constantJac = jacobian(proc(x: D) { return constantField(x); },
                               [1.0, 2.0]);
  // Deliberately left default-initialized: a 3x2 array of 0.0.
  var zeroJac: [0..2, 0..1] real;
  test.assertEqual(constantJac, zeroJac);
}
/*
  Checks directionalDerivative and jvp at the point [1, 2] along the
  direction [0.5, 2.0]. Each is tested twice: once on the result of
  evaluating the function on initdual(point, direction), and once through
  the function-taking form.
*/
proc testDirectionalAndJvp(test: borrowed Test) throws {
  // Scalar function; its gradient at [1, 2] is [8, 3], so the derivative
  // along [0.5, 2.0] is 8*0.5 + 3*2.0 = 10.
  proc scalarField(x) {
    return x[0] ** 2 + 3 * x[0] * x[1];
  }

  const scalarAlongDir = scalarField(initdual([1, 2], [0.5, 2.0]));
  test.assertEqual(value(scalarAlongDir), 7);
  test.assertEqual(directionalDerivative(scalarAlongDir), 10);

  const dirDeriv = directionalDerivative(
    proc(x: D2) { return scalarField(x); }, [1, 2], [0.5, 2.0]);
  test.assertEqual(dirDeriv, 10);

  // Vector function; its Jacobian at [1, 2] is [[2, 1], [3, 5]], so the
  // product with [0.5, 2.0] is [3, 11.5].
  proc vectorField(x) {
    return [x[0] ** 2 + x[1] + 1, x[0] + x[1] ** 2 + x[0] * x[1]];
  }

  const vectorAlongDir = vectorField(initdual([1, 2], [0.5, 2.0]));
  test.assertEqual(value(vectorAlongDir), [4.0, 7.0]);
  test.assertEqual(jvp(vectorAlongDir), [3.0, 11.5]);

  const jacTimesDir = jvp(
    proc(x: D2) { return vectorField(x); }, [1, 2], [0.5, 2.0]);
  test.assertEqual(jacTimesDir, [3.0, 11.5]);
}
// Hand control to the UnitTest module's driver. Presumably this is what
// runs the test procedures defined above (confirm against the UnitTest docs).
UnitTest.main();