Skip to content

Conversation

@MasonProtter
Copy link
Contributor

Attempt to fix the problems raised in SciML/NeuralPDE.jl#791 that was caused by #770

I think this is the right way to generalize the scalar rule to any input type, and always return a ZeroTangent() instead of a NoTangent() like the old @non_differentiable rule did which was quite mathematically suspect.

@MasonProtter
Copy link
Contributor Author

MasonProtter commented Jan 24, 2024

Appears to have the same failure profile as #774. @oxinabox any thoughts on what we should do? If there's problems in NeuralPDE, there's probably also issues in other parts of the ecosystem.

Here's a demo of these methods fixing the problem that NeuralPDE had: SciML/NeuralPDE.jl#792

@oxinabox
Copy link
Member

oxinabox commented Jan 25, 2024

Why is this not:

function frule((_, _), ::typeof(zero), x)
    return (zero(x), ZeroTangent())
end

function rrule(::typeof(zero), x)
    zero_pullback(_) = (NoTangent(), ZeroTangent())
    return (zero(x), zero_pullback)
end

# `one`

function frule((_, _), ::typeof(one), x)
    return (one(x), ZeroTangent())
end

function rrule(::typeof(one), x)
    one_pullback(_) = (NoTangent(), ZeroTangent())
    return (one(x), one_pullback)
end

This extra stuff with projection and multiplication all should just cancel away for if the tangent is ZeroTangent().

I agree the NoTangent before was mathematically wrong, and correct is ZeroTangent(), though because of how they act similarly, and under Zygote identically it worked out fine.

We should probably add a test for this since it has caused bugs now.
Should be very simple taking the zero([1,2,3])

Thank you for playing wack-a-mole with this.

@MasonProtter
Copy link
Contributor Author

MasonProtter commented Jan 25, 2024

Why is this not:

Mostly because I was copy-pasting and then tweaking the output from @scalar_rule, I can simplify it though.

We should probably add a test for this since it has caused bugs now.
Should be very simple taking the zero([1,2,3])

can do

@oxinabox
Copy link
Member

oxinabox commented Jan 25, 2024

test failures on 1.x are unrelated.
1.6 is passing which is enough

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants