Allow shared parameters, take III (#106)
* allow shared parameters, take III

Co-authored-by: Brian Chen <[email protected]>

* one more dict to allow artificial ties

* a tidier idea, just replace _default_walk

* add a LeafCache type, to make fmap ignore () singleton

* remove leaf.frozen field

* eager accumulation

* give up on customising fmap & write the recursion, add evil tests

* add ismutable check

* docs etc

* fix doctests

* group the tests

Co-authored-by: Brian Chen <[email protected]>
mcabbott and ToucheSir authored Oct 13, 2022
1 parent bf54f76 commit 9c12e5d
Showing 6 changed files with 352 additions and 93 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "Optimisers"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
authors = ["Mike J Innes <[email protected]>"]
version = "0.2.9"
version = "0.2.10"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "1"
Functors = "0.2.8, 0.3"
Functors = "0.3"
Zygote = "0.6.40"
julia = "1.6"

66 changes: 59 additions & 7 deletions docs/src/index.md
@@ -1,6 +1,6 @@
# Optimisers.jl

## Defining an optimisation rule
## An optimisation rule

A new optimiser must overload two functions, [`apply!`](@ref) and [`init`](@ref).
These act on one array of parameters:
@@ -60,18 +60,18 @@ Notice that a completely new instance of the model is returned. Internally, this
is handled by [Functors.jl](https://fluxml.ai/Functors.jl), where we do a walk over the
tree formed by the model and update the parameters using the gradients.

There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
but is free to mutate arrays within the old one for efficiency.
The method of `apply!` for each rule is likewise free to mutate arrays within its state;
they are defensively copied when this rule is used with `update`.

Optimisers.jl does not depend on any one automatic differentiation package,
but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
Note that `update` always wants the gradient from Zygote's "explicit" mode, as shown above.
This `∇model` is another tree structure, rather than the dictionary-like object from
Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
[Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
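For concreteness, here is a hedged sketch of that workflow (the toy model, loss, and data below are invented for illustration and are not part of this commit):

```julia
using Optimisers, Zygote

# Toy model as a NamedTuple; any Functors-compatible struct works the same way.
model = (; weight = rand(Float32, 2, 2), bias = zeros(Float32, 2))
data, target = randn(Float32, 2), randn(Float32, 2)
loss(m) = sum(abs2, m.weight * data .+ m.bias .- target)

state = Optimisers.setup(Descent(0.1f0), model)
∇model = Zygote.gradient(loss, model)[1]   # explicit mode: a NamedTuple mirroring `model`
state, model = Optimisers.update(state, model, ∇model)
```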

There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
but is free to mutate arrays within the old one for efficiency.
The method of `apply!` you write is likewise free to mutate arrays within its state;
they are defensively copied when this rule is used with `update`.
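A minimal hedged sketch of the two call styles (toy values, not from this commit):

```julia
using Optimisers

m  = (; w = Float32[1, 2, 3])
g  = (; w = Float32[1, 1, 1])           # a fake gradient with the same structure
st = Optimisers.setup(Descent(0.1f0), m)

st2, m2 = Optimisers.update(st, m, g)   # defensively copies, so `m` and `st` are untouched
st3, m3 = Optimisers.update!(st, m, g)  # may mutate `m` and `st`; always use the returned values
```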

## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)

The main design difference of Lux is that the tree of parameters is separate from
@@ -110,6 +110,57 @@ Besides the parameters stored in `params` and gradually optimised, any other model
is stored in `lux_state`. For simplicity this example does not show how to propagate the
updated `lux_state` to the next iteration; see Lux's documentation.

## Non-`trainable` Parameters

Optimisers.jl uses [Functors.jl](https://fluxml.ai/Functors.jl) to walk the `struct`s
making up the model, for which they must be annotated `@functor Type`.
By default optimisation will alter all [`isnumeric`](@ref) arrays.

If some arrays of a particular layer should not be treated this way,
you can define a method for [`trainable`](@ref):

```julia
struct Layer{T}
  alpha::T
  beta::T
  length::Int
end
Layer(n::Int) = Layer(randn(n), zeros(n), n)

Functors.@functor Layer

# Both array fields will be, for example, moved to the GPU:
Functors.children(Layer(3)) # (alpha = [...], beta = [...], length)

Optimisers.trainable(x::Layer) = (; alpha = x.alpha) # must be a subset of children

# Only the first field will be optimised:
st = Optimisers.setup(DecayDescent(0.1), Layer(3))
```
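Continuing that example, the state tree returned by `setup` should then contain a `Leaf` only for `alpha` (a hedged sketch, not doctest-checked output):

```julia
st.alpha isa Optimisers.Leaf   # true: only this field carries optimiser state
st.beta                        # (), skipped, as is the non-array `length` field
```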

## Tied Parameters

If the same array appears twice (or more) in the model, [Functors.jl](https://fluxml.ai/Functors.jl) should recognise this.
Within Optimisers.jl, `setup` will initialise once, and use the same `Leaf` for both parameters.
Then `update` will accumulate the gradient from both, and the updated model returned will have the tie maintained.

```julia
using Flux, Optimisers

enc = Chain(Dense(40 => 20, tanh), Dense(20 => 10));
dec = Chain(Dense(enc[1].weight', true, tanh), Dense(enc[2].weight', true, tanh));
model = Chain(; enc, dec)

st = Optimisers.setup(Optimisers.Adam(), model);

st.layers.enc.layers[1].weight === st.layers.dec.layers[1].weight.parent # true
```

This identification relies on `===`, and will work for ordinary `Array`s and `CuArray`s.
It will not at present work for `reshape`d arrays, nor for immutable arrays such as those
from StaticArrays.jl.
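A smaller hedged sketch of the same mechanism without Flux, using an array shared directly between two fields (all names here are illustrative):

```julia
using Optimisers

shared = randn(Float32, 3, 3)
tied = (; a = shared, b = shared)          # the very same array in two places

st = Optimisers.setup(Adam(), tied)
st.a === st.b                              # true: one Leaf is created and reused

g = (; a = ones(Float32, 3, 3), b = ones(Float32, 3, 3))
st, tied = Optimisers.update(st, tied, g)  # gradients for `a` and `b` are accumulated
tied.a === tied.b                          # still true: the tie is preserved
```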


## Obtaining a flat parameter vector

Instead of a nested tree-like structure, sometimes it is convenient to have all the
@@ -143,10 +194,11 @@ st, flat = Optimisers.update(st, flat, ∇flat)
```

Here `flat` contains only the 283 trainable parameters, while the non-trainable
ones are preserved inside `re`.
ones are preserved inside `re`, an object of type `Restructure`.
When defining new layers, these can be specified if necessary by overloading [`trainable`](@ref).
By default, all numeric arrays visible to [Functors.jl](https://github.com/FluxML/Functors.jl)
are assumed to contain trainable parameters.
Tied parameters (arrays appearing in different layers) are included only once in `flat`.
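A hedged illustration of that last point, using an illustrative tied NamedTuple rather than a real model:

```julia
using Optimisers

shared = randn(Float32, 3, 3)
flat, re = Optimisers.destructure((; a = shared, b = shared))
length(flat)   # 9, not 18: the shared array's entries should appear only once
```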

Lux stores only the trainable parameters in `params`.
This can also be flattened to a plain `Vector` in the same way:
24 changes: 16 additions & 8 deletions src/Optimisers.jl
@@ -1,6 +1,6 @@
module Optimisers

using Functors: functor, fmap, isleaf
using Functors: functor, fmap, isleaf, @functor, fmapstructure, children
using LinearAlgebra

include("interface.jl")
@@ -16,6 +16,10 @@ export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
WeightDecay, ClipGrad, ClipNorm, OptimiserChain

###
### one-array functions
###

"""
Optimisers.apply!(rule::RuleType, state, parameters, gradient) -> (state, gradient)
@@ -57,6 +61,10 @@ julia> Optimisers.init(Momentum(), [1.0, 2.0])
"""
init

###
### whole-model functions
###

"""
Optimisers.setup(rule, model) -> tree
@@ -69,7 +77,7 @@ or [`update!`](@ref).
julia> m = (x = rand(3), y = (true, false), z = tanh);
julia> Optimisers.setup(Momentum(), m) # same field names as m
(x = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0, 0.0]), y = (nothing, nothing), z = nothing)
(x = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0, 0.0]), y = ((), ()), z = ())
```
The recursion into structures uses Functors.jl, and any new `struct`s containing parameters
@@ -82,15 +90,15 @@ julia> struct Layer; mat; fun; end
julia> model = (lay = Layer([1 2; 3 4f0], sin), vec = [5, 6f0]);
julia> Optimisers.setup(Momentum(), model) # new struct is by default ignored
(lay = nothing, vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
(lay = (), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
julia> destructure(model)
(Float32[5.0, 6.0], Restructure(NamedTuple, ..., 2))
julia> using Functors; @functor Layer # annotate this type as containing parameters
julia> Optimisers.setup(Momentum(), model)
(lay = (mat = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = nothing), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
(lay = (mat = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = ()), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
julia> destructure(model)
(Float32[1.0, 3.0, 2.0, 4.0, 5.0, 6.0], Restructure(NamedTuple, ..., 6))
@@ -112,12 +120,12 @@ See also [`update!`](@ref), which will be faster for models of ordinary `Array`s
julia> m = (x = Float32[1,2,3], y = tanh);
julia> t = Optimisers.setup(Descent(0.1f0), m)
(x = Leaf(Descent{Float32}(0.1), nothing), y = nothing)
(x = Leaf(Descent{Float32}(0.1), nothing), y = ())
julia> g = (x = [1,1,1], y = nothing); # fake gradient
julia> Optimisers.update(t, m, g)
((x = Leaf(Descent{Float32}(0.1), nothing), y = nothing), (x = Float32[0.9, 1.9, 2.9], y = tanh))
((x = Leaf(Descent{Float32}(0.1), nothing), y = ()), (x = Float32[0.9, 1.9, 2.9], y = tanh))
```
"""
update
@@ -157,8 +165,8 @@ true
julia> m # original should be discarded, may be mutated but no guarantee
(x = Float32[0.6666666, 1.5333333], y = Float32[4.0, 5.0])
julia> t # original state should likewise be discarded
(x = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.333333, 0.466667]), y = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]))
julia> t == t2 # original state is in fact guaranteed to be mutated
true
```
"""
update!
6 changes: 3 additions & 3 deletions src/adjust.jl
@@ -13,15 +13,15 @@ To change just the learning rate, provide a number `η::Real`.
julia> m = (vec = rand(Float32, 2), fun = sin);
julia> st = Optimisers.setup(Nesterov(), m) # stored momentum is initialised to zero
(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[0.0, 0.0]), fun = nothing)
(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[0.0, 0.0]), fun = ())
julia> st, m = Optimisers.update(st, m, (vec = [16, 88], fun = nothing)); # with fake gradient
julia> st
(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = nothing)
(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = ())
julia> st = Optimisers.adjust(st, 0.123) # change learning rate, stored momentum untouched
(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = nothing)
(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
```
To change other parameters, `adjust` also accepts keyword arguments matching the field
131 changes: 95 additions & 36 deletions src/interface.jl
@@ -1,57 +1,120 @@

using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero
using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero, ZeroTangent
base(dx::Tangent) = backing(canonicalize(dx))
base(dx) = dx
const Zero = Union{Nothing, AbstractZero} # Union{Zygote, Diffractor}

abstract type AbstractRule end

struct Leaf{R,S}
###
### setup
###

mutable struct Leaf{R,S} # mutable so that its identity encodes parameter sharing
rule::R
state::S
end

function setup(rule, x; seen = Base.IdSet())
rule isa AbstractRule || Base.depwarn("In future, all optimisation rules should be <: AbstractRule", :setup)
@functor Leaf

Base.:(==)(a::Leaf, b::Leaf) = children(a) == children(b)

function setup(rule::AbstractRule, model)
cache = IdDict()
tree = _setup(rule, model; cache)
isempty(cache) && @warn "setup found no trainable parameters in this model"
tree
end

# _setup is almost fmapstructure, but needs a _trainable_walk, and a cache which ignores numbers etc.
function _setup(rule, x; cache)
haskey(cache, x) && return cache[x]
if isnumeric(x)
x in seen && throw(ArgumentError("Optimisers.jl does not at present handle tied weights, sorry."))
isbits(x) || push!(seen, x)
return Leaf(rule, init(rule, x))
elseif isleaf(x)
return nothing
ℓ = Leaf(rule, init(rule, x))
if isbits(x)
cache[nothing] = nothing # just to disable the warning
ℓ
else
cache[x] = ℓ
end
else
return map(xᵢ -> setup(rule, xᵢ; seen), _trainable(x))
map(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x))
end
end

subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)
function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long type!
ioc = IOContext(io, :compact => true)
print(ioc, "Leaf(", ℓ.rule, ", ")
show(ioc, ℓ.state)
print(ioc, ")")
end

update!(::Nothing, x, ::Zero, ::Zero...) = nothing, x
update!(::Nothing, x, x̄s...) = nothing, x
###
### update
###

update!(ℓ::Leaf, x, ::Zero, ::Zero...) = ℓ, x
function update!(ℓ::Leaf, x, x̄s...)
s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, base.(x̄s)...)
Leaf(ℓ.rule, s′), subtract!(x, x̄′)
function update(tree, model, grad, higher...)
t′ = fmap(copy, tree; exclude = maywrite) # walks inside Leaf
x′ = fmap(copy, model; exclude = maywrite)
update!(t′, x′, grad, higher...)
end

update!(tree, x, ::Zero, ::Zero...) = tree, x
function update!(tree, x, x̄s...)
x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
x′, re = functor(typeof(x), x)
xtree = map((stᵢ, xᵢ, x̄sᵢ...) -> update!(stᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
map(first, xtree), re(map(last, xtree))
function update!(tree, model, grad, higher...)
# First walk is to accumulate the gradient. This recursion visits every copy of
# shared leaves, but stops when branches are absent from the gradient:
grads = IdDict{Leaf, Any}()
_grads!(grads, tree, model, grad, higher...)
# Second walk is to update the model. The params cache indexed by (tree,x),
# so that identified Leafs can tie isbits parameters, but setup won't do that for you:
newmodel = _update!(tree, model; grads, params = IdDict())
tree, newmodel # note that tree is guaranteed to be updated. Also that it's not necc a tree.
end

function _update!(tree, x; grads, params)
haskey(params, (tree,x)) && return params[(tree,x)]
isbits(tree) && return x # means () is not cached, and also (((),),)
x′, re = functor(x)
x′′ = re(map((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′))
if ismutable(x′′)
params[(tree,x)] = x′′
else # no ties to preserve between immutable structs, right?
x′′
end
end
function _update!(ℓ::Leaf, x; grads, params)
haskey(params, (ℓ,x)) && return params[(ℓ,x)]
params[(ℓ,x)] = if haskey(grads, ℓ)
ℓ.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, grads[ℓ]...)
subtract!(x, x̄′)
else
x # no gradient seen
end
end

subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)

function update(tree, x, x̄s...)
t′ = fmap(copy, tree; exclude = maywrite)
x′ = fmap(copy, x; exclude = maywrite)
update!(t′, x′, x̄s...)
_grads!(dict::IdDict, ℓ::Leaf, x, ::Zero...) = nothing
function _grads!(dict::IdDict, ℓ::Leaf, x, x̄s...)
x̄s₀ = get(dict, ℓ, map(_ -> ZeroTangent(), x̄s))
dict[ℓ] = map(+, x̄s, x̄s₀) # adding Zero should be free. Lazy accumulation broadcasted(+, x̄, x̄₀) also possible.
nothing
end
_grads!(dict::IdDict, t, x, ::Zero...) = nothing
function _grads!(dict::IdDict, tree, x, x̄s...)
# The only reason _grads! takes model is that functor(typeof(x), base(x̄)) may differ from
# functor(typeof(tree), base(x̄)), for things like Transpose
x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
x′, _ = functor(typeof(x), x)
foreach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
end

# default all rules to first order calls
apply!(o, state, x, dx, dx2, dxs...) = apply!(o, state, x, dx)

###
### sources of truth
###

"""
isnumeric(x) -> Bool
@@ -98,8 +161,12 @@ function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tuple
map(c -> c in tr ? c : nothing, ch)
end

###
### rule definition helpers
###

"""
@.. x = x + y
@.. x = y + z
Sometimes in-place broadcasting macro, for use in `apply!` rules.
If `maywrite(x)` then it is just `@. x = rhs`, but if not, it becomes `x = @. rhs`.
@@ -135,11 +202,3 @@ Broadcast.materialize(x::Lazy) = Broadcast.instantiate(x.bc)

onevalue(λ::T, x::AbstractArray{T}) where T = map(_ -> λ, x)
onevalue(λ, x::AbstractArray{T}) where T = onevalue(convert(float(T), λ), x)

function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long type!
ioc = IOContext(io, :compact => true)
print(ioc, "Leaf(", ℓ.rule, ", ")
show(ioc, ℓ.state)
print(io, ")")
end


2 comments on commit 9c12e5d

@mcabbott (Member Author)


@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/70138

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.10 -m "<description of version>" 9c12e5d1214e80d90de253565c89869f93ae6eed
git push origin v0.2.10
