Skip to content

Commit 43bd90b

Browse files
authored
Pass keyword arguments to test_approx when checking thunks (#247)
* Pass keyword arguments to `test_approx` when checking chunks * Update Project.toml * Fix tests
1 parent 5f523e9 commit 43bd90b

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "1.7.0"
3+
version = "1.7.1"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/testers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ function test_rrule(
227227
end
228228

229229
if check_thunked_output_tangent
230-
test_approx(ad_cotangents, pullback(@thunk(ȳ)), "pulling back a thunk:")
230+
test_approx(ad_cotangents, pullback(@thunk(ȳ)), "pulling back a thunk:"; isapprox_kwargs...)
231231
check_inferred && _test_inferred(pullback, @thunk(ȳ))
232232
end
233233
end # top-level testset

test/testers.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,10 @@ end
683683
function ChainRulesCore.rrule(::typeof(my_id), x)
684684
my_id_pb(ȳ) = (NoTangent(), ȳ)
685685
function my_id_pb(ȳ::AbstractThunk)
686-
precision = rand() > 0.5 ? Float64 : Float32
686+
# We use a condition that always evaluates to true to avoid issues with tolerances
687+
# (see https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/247)
688+
# The function is type unstable for `Float64` inputs nevertheless
689+
precision = rand() >= 0.0 ? Float64 : Float32
687690
return (NoTangent(), precision(unthunk(ȳ)))
688691
end
689692
return x, my_id_pb

0 commit comments

Comments
 (0)