-
-
Notifications
You must be signed in to change notification settings - Fork 211
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use zygote2differential to wrap chainrules inputs #1057
base: master
Are you sure you want to change the base?
Conversation
well this breaks some tests in weird ways. |
I just checked TuringLang/DistributionsAD.jl#198 (the CR1 version) locally and it still fails with the same error messages ("adjoint for constructor ..."), even with this PR. |
Yeah won't fix that. |
There was a matching differential2zygote that@mzgubic wrote. |
Ah sorry, I misunderstood your comment. Unfortunately, the example is not fixed either. |
Here it is, in case you find it useful: differential2legacy(x) = unthunk(x) # TODO eventually remove this
differential2legacy(::AbstractZero) = nothing
differential2legacy(t::Union{Tuple, NamedTuple}) = map(differential2legacy, t)
differential2legacy(::Nothing) = (legacytype_warn(Nothing); return nothing)
differential2legacy(a::AbstractArray) = differential2legacy.(a) # TODO: what to do with arrays with nothing?
differential2legacy(a::AbstractArray{<:Number}) = a
for T_outer in (:Tuple, :NamedTuple)
# we create separate methods rather than using a `Union` + an `if` so that we avoid a
# branch that changes output type, because nested AD on that kinda thing makes Zygote less
# than happy.
@eval @inline function differential2legacy(x::Composite{P, T}) where {P, T<:$T_outer}
xp = map(differential2legacy, canonicalize(x))
convert($T_outer, xp)
end
end I do recall getting into some kind of trouble when using this instead of |
@mzgubic implemented zygote2differential as a better version of wrap_chainrules_inputs and added it to use in the code for
rrule_via_ad
.But it was not added to the normal path for when Zygote uses ChainRules.
I guess because it requires keeping the primal values in memory.
Which is probably a lot?
Anyway this would give us more consistent chainrules types.
No more
Tangent{Any}
ornothings
that are hidden with-in arrays.We probably do not want to merge this as is because of the extra memory use.
or maybe it is not too bad. Do we have a benchmark for it?
But hopefully this will fix the problems in TuringLang/DistributionsAD.jl#197
cc @devmotion .
If it does we can look at reworking
zygote2differential
to not have to store so much.We learnt a lot about doing that for
ProjectTo
same techniques can be applied here.
NB: I am putting this PR up at 9:30 at night, and I have not even run it locally.
Might have typos etc and just not work.
It also has no tests, yet.