diff --git a/src/Zygote.jl b/src/Zygote.jl index 8a51b14fd..72b357fb9 100644 --- a/src/Zygote.jl +++ b/src/Zygote.jl @@ -3,11 +3,12 @@ module Zygote using LinearAlgebra, Statistics using LinearAlgebra: copytri!, AbstractTriangular +import ZygoteRules import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, - literal_getproperty, literal_getfield + literal_getproperty, literal_getfield, unthunk_tangent using ChainRulesCore -using ChainRules: ChainRules, rrule, unthunk, canonicalize +using ChainRules: ChainRules, AbstractThunk, rrule, unthunk, canonicalize using IRTools using MacroTools, Requires using MacroTools: @forward diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 7c7de8655..473c6f17a 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -1,3 +1,14 @@ +# ToDo: Move some of this to ZygoteRules, or move unthunk_tangent for Tuple and NamedTuple from +# Zygote rules here? +function unthunk_tangent end +@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) +@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x +@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x +@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x) +unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d]) +@non_differentiable unthunk_tangent(::IdDict) + + struct ZygoteRuleConfig{CTX<:AContext} <: RuleConfig{Union{HasReverseMode,NoForwardsMode}} context::CTX end @@ -102,7 +113,6 @@ is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f) Convert `x` from the differentials types ChainRules uses to the format Zygote uses internally. """ @inline wrap_chainrules_output(x) = x -@inline wrap_chainrules_output(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) # For now we are just not going to deal with thunks @inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) # Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing. @inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index ee2a69528..ee6455225 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -37,7 +37,13 @@ end _pullback(f, args...) = _pullback(Context(), f, args...) tailmemaybe(::Nothing) = nothing -tailmemaybe(x::Tuple) = Base.tail(x) +tailmemaybe(x::Tuple) = unthunk_tangent(Base.tail(x)) + +# unthunking is essentially an identity operation on a lazy value, but +# `@adjoint unthunk_tangent(x) = unthunk_tangent(x), ȳ -> (ȳ,)` is not enough to make +# nested AD work, so define +@adjoint tailmemaybe(xs::Tuple) = tailmemaybe(xs), x̄s -> ((nothing, x̄s...),) + @inline pullback(f, args...) = pullback(f, Context(), args...) function pullback(f, cx::AContext, args...) @@ -376,7 +382,7 @@ function pullback(f, ps::Params) cache(cx)[p] = nothing end back(Δ) - Grads(cx.cache, ps) # TODO make a copy + Grads(unthunk_tangent(cx.cache), ps) # TODO make a copy end end diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index 532644914..0a43e7062 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -3,6 +3,18 @@ using IRTools: IR, Variable, Pipe, xcall, var, prewalk, postwalk, insertafter!, finish, expand!, prune!, substitute!, substitute, block, block!, branch!, return!, stmt, meta + +# TODO: Temporary, to be removed when ChainRulesCore rrules are required to +# support thunks as an input and all instances of _adjoint_keepthunks in +# Zygote have been replaces by rrules: +macro _adjoint_keepthunks(ex) + ZygoteRules.gradm(ex, false, true) +end +macro _adjoint_keepthunks!(ex) + ZygoteRules.gradm(ex, true, true) +end + + @inline tuple_va(N, xs) = xs @inline tuple_va(N, x, xs...) = (x, tuple_va(N, xs...)...) @inline tuple_va(::Val{N}, ::Nothing) where N = ntuple(_ -> nothing, Val(N)) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 58f7ecf99..fdd0cf1dc 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -49,7 +49,8 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)}) end -function unbroadcast(x::AbstractArray, x̄) +function unbroadcast(x::AbstractArray, maybethunked_x̄) + x̄ = unthunk_tangent(maybethunked_x̄) N = ndims(x̄) if length(x) == length(x̄) _project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 52a734809..de192b267 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -38,15 +38,15 @@ function accum(x::RefValue, y::RefValue) end # Core functions -@adjoint deepcopy(x) = deepcopy(x), ȳ -> (ȳ,) +@_adjoint_keepthunks deepcopy(x) = deepcopy(x), ȳ -> (ȳ,) -@adjoint (::Type{V})(x...) where V<:Val = V(x...), _ -> nothing +@_adjoint_keepthunks (::Type{V})(x...) where V<:Val = V(x...), _ -> nothing -@adjoint ifelse(cond::Bool, t, f) = +@_adjoint_keepthunks ifelse(cond::Bool, t, f) = ifelse(cond, t, f), Δ -> cond ? (nothing, Δ, zero(Δ)) : (nothing, zero(Δ), Δ) -@adjoint Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing) +@_adjoint_keepthunks Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing) accum_param(::Context{false}, _, Δ) = Δ @generated function accum_param(cx::Context, x, Δ) @@ -70,11 +70,11 @@ end unwrap(x) = x -@adjoint unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),) +@_adjoint_keepthunks unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),) unwrap(ref, x) = x -@adjoint unwrap(ref, x) = unwrap(x), function (x̄) +@_adjoint_keepthunks unwrap(ref, x) = unwrap(x), function (x̄) accum_global(__context__, ref, x̄) (accum_param(__context__, x, x̄),) end @@ -88,7 +88,7 @@ function global_set(ref, val) end end -@adjoint! function global_set(ref, x) +@_adjoint_keepthunks! function global_set(ref, x) global_set(ref, x), function (x̄) gs = cache(__context__) x̄ = accum(get(gs, ref, nothing), x̄) @@ -101,9 +101,9 @@ end using Base: tail -@adjoint tuple(xs...) = xs, identity +@_adjoint_keepthunks tuple(xs...) = xs, identity -@adjoint function literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i} +@_adjoint_keepthunks function literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i} val = xs[i] function back(Δ) accum_param(__context__, val, Δ) === nothing && return @@ -112,7 +112,7 @@ using Base: tail val, back end -@adjoint function getindex(xs::NTuple{N,Any}, i::Integer) where N +@_adjoint_keepthunks function getindex(xs::NTuple{N,Any}, i::Integer) where N val = xs[i] function back(Δ) accum_param(__context__, val, Δ) === nothing && return @@ -121,10 +121,10 @@ end return val, back end -@adjoint getindex(xs::NTuple{N,Any}, r::AbstractUnitRange) where N = +@_adjoint_keepthunks getindex(xs::NTuple{N,Any}, r::AbstractUnitRange) where N = (xs[r], Δ -> (ntuple(j -> j in r ? Δ[findfirst(isequal(j), r)] : nothing, Val(N)), nothing)) -@adjoint function getindex(xs::NTuple{N,Any}, r::AbstractVector) where N +@_adjoint_keepthunks function getindex(xs::NTuple{N,Any}, r::AbstractVector) where N val = xs[r] function back(Δ) dxs = ntuple(Val(length(xs))) do x @@ -155,18 +155,18 @@ function _pullback(cx::AContext, ::typeof(literal_indexed_iterate), xs::Tuple, : end # Needed for iteration lowering -@adjoint Core.getfield(xs::NTuple{N,Any}, i::Int) where N = +@_adjoint_keepthunks Core.getfield(xs::NTuple{N,Any}, i::Int) where N = (xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing)) -@adjoint Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Int) where {K,N} = +@_adjoint_keepthunks Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Int) where {K,N} = (xs[i], Δ -> (NamedTuple{K}(ntuple(j -> i == j ? Δ : nothing, Val(N))), nothing)) -@adjoint function Base.first(xs::Tuple) +@_adjoint_keepthunks function Base.first(xs::Tuple) drest = map(_->nothing, tail(xs)) first(xs), Δ -> ((Δ, drest...),) end -@adjoint Base.tail(xs::Tuple) = tail(xs), x̄s -> ((nothing, x̄s...),) +@_adjoint_keepthunks Base.tail(xs::Tuple) = tail(xs), x̄s -> ((nothing, x̄s...),) _empty(x) = length(x) _empty(x::Union{Tuple,NamedTuple}) = map(_->nothing, x) @@ -188,7 +188,7 @@ end unapply(t, xs) = _unapply(t, xs)[1] -@adjoint! function Core._apply(f, args...) +@_adjoint_keepthunks! function Core._apply(f, args...) y, back = Core._apply(_pullback, (__context__, f), args...) st = map(_empty, args) y, function (Δ) @@ -199,7 +199,7 @@ unapply(t, xs) = _unapply(t, xs)[1] end if VERSION >= v"1.4.0-DEV.304" - @adjoint! function Core._apply_iterate(::typeof(iterate), f, args...) + @_adjoint_keepthunks! function Core._apply_iterate(::typeof(iterate), f, args...) y, back = Core._apply(_pullback, (__context__, f), args...) st = map(_empty, args) y, function (Δ) @@ -225,7 +225,7 @@ end @generated pair(::Val{k}, v, _=nothing) where k = :($k = v,) @generated pair(::Val{k}, v, ::NamedTuple{keys}) where {k,keys} = k isa Int ? :($(getfield(keys, k)) = v,) : :($k = v,) -@adjoint function literal_getfield(x, ::Val{f}) where f +@_adjoint_keepthunks function literal_getfield(x, ::Val{f}) where f val = getfield(x, f) function back(Δ) accum_param(__context__, val, Δ) === nothing && return @@ -273,7 +273,7 @@ function grad_mut(cx::Context, x) end end -@adjoint! function setfield!(x, f, val) +@_adjoint_keepthunks! function setfield!(x, f, val) y = setfield!(x, f, val) g = grad_mut(__context__, x) y, function (_) @@ -289,13 +289,13 @@ end Jnew{T}(g) where T = Jnew{T,typeof(g)}(g) -@adjoint! function __new__(T, args...) +@_adjoint_keepthunks! function __new__(T, args...) x = __new__(T, args...) g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x) x, Jnew{T,typeof(g),false}(g) end -@adjoint! function __splatnew__(T, args) +@_adjoint_keepthunks! function __splatnew__(T, args) x = __splatnew__(T, args) g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x) x, Jnew{T,typeof(g),true}(g)