Add a logabsdet_jacobian method #3

Comments
You mean logabsdet_jacobian(f, x) = with_logabsdet_jacobian(f, x)[2] as a convenience? Hm, I'm not so sure. Is there a use case beyond convenience?
It creates overhead when broadcasting:

julia> f(x) = (2x, 3x)
f (generic function with 1 method)
julia> f1(x) = 2x
f1 (generic function with 1 method)
julia> f2(x) = 3x
f2 (generic function with 1 method)
julia> xs = randn(100_000);
julia> @benchmark $f1.($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 70.233 μs … 1.699 ms ┊ GC (min … max): 0.00% … 79.00%
Time (median): 97.174 μs ┊ GC (median): 0.00%
Time (mean ± σ): 117.961 μs ± 108.078 μs ┊ GC (mean ± σ): 10.99% ± 10.97%
▄██▆▅▃▂▂▁▁▁ ▂
██████████████▇▇▆▅▄▃▄▃▁▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▅▆▇▆▆▆▆ █
70.2 μs Histogram: log(frequency) by time 812 μs <
Memory estimate: 781.33 KiB, allocs estimate: 2.
julia> @benchmark $f2.($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 63.100 μs … 1.900 ms ┊ GC (min … max): 0.00% … 64.76%
Time (median): 96.844 μs ┊ GC (median): 0.00%
Time (mean ± σ): 117.607 μs ± 108.086 μs ┊ GC (mean ± σ): 11.08% ± 11.01%
▆█▆▅▄▃▂▁▁▁ ▂
▇████████████▇▆▆▆▅▃▃▅▄▁▁▁▁▁▁▁▁▁▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▆▆▆▆▅▆▆ █
63.1 μs Histogram: log(frequency) by time 821 μs <
Memory estimate: 781.33 KiB, allocs estimate: 2.
julia> @benchmark $(Base.Fix2(getindex, 1) ∘ f).($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 86.023 μs … 1.877 ms ┊ GC (min … max): 0.00% … 81.82%
Time (median): 128.063 μs ┊ GC (median): 0.00%
Time (mean ± σ): 155.207 μs ± 121.756 μs ┊ GC (mean ± σ): 9.08% ± 10.73%
▄█▄
▃████▆▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂ ▃
86 μs Histogram: frequency by time 889 μs <
Memory estimate: 781.33 KiB, allocs estimate: 2.
julia> @benchmark $(Base.Fix2(getindex, 2) ∘ f).($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 86.414 μs … 1.591 ms ┊ GC (min … max): 0.00% … 84.73%
Time (median): 120.063 μs ┊ GC (median): 0.00%
Time (mean ± σ): 142.219 μs ± 110.132 μs ┊ GC (mean ± σ): 9.23% ± 10.71%
▃▇█▇▅▄▃▃▂▂▁ ▂
█████████████▇▇█▆▆▅▅▅▃▃▁▄▁▄▁▄▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▃▄▆▆▇▇▆▇ █
86.4 μs Histogram: log(frequency) by time 838 μs <
Memory estimate: 781.33 KiB, allocs estimate: 2.

This might also be worse when combined with AD, etc. It came to mind because in Bijectors.jl we also want to define multivariate versions of many of the transformations, e.g. elementwise application of a scalar transformation.

EDIT: Also, in general, you might want to just compute the logabsdet-jacobian term, in which case you'll end up doing unnecessary computation in with_logabsdet_jacobian.
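A rough sketch of what I mean (the logabsdet_jacobian fallback and the specialized ladj-only method for log are hypothetical, just to illustrate the elementwise case):

using ChangesOfVariables

# Hypothetical generic fallback (not currently part of ChangesOfVariables):
logabsdet_jacobian(f, x) = with_logabsdet_jacobian(f, x)[2]

# Hypothetical specialized ladj-only method that skips the transformed value:
logabsdet_jacobian(::typeof(log), x::Real) = -log(x)

xs = rand(100)

# Even with broadcast fusion, this computes the transformed value for every
# element and then throws it away:
ladjs_fallback = last.(with_logabsdet_jacobian.(log, xs))

# The specialized method only does the work that is actually needed:
ladjs_direct = logabsdet_jacobian.(log, xs)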
I would have thought that scenario very uncommon, needing only the ladj of a trafo but not the result. But if there are use cases, it would make sense to support it directly. Would you have an example or two (we could also add that to the docs then, maybe)?
Regarding broadcasting overhead and autodiff, that's an interesting question. We do have broadcasting support in ChangesOfVariables, but we didn't really benchmark it so far. Let's use a broadcasted log as a test case, with a simple "loss function" (just dot(ys, ys) plus the ladj):

using ChangesOfVariables, LinearAlgebra, Zygote, BenchmarkTools

logabsdet_jacobian(f, x) = with_logabsdet_jacobian(f, x)[2]

function foo(xs)
    ys = log.(xs)
    ladj = sum(logabsdet_jacobian.(log, xs))
    dot(ys, ys) + ladj
end
grad_foo(xs) = Zygote.gradient(foo, xs)

function bar(xs)
    ys, ladj = with_logabsdet_jacobian(Base.Fix1(broadcast, log), xs)
    dot(ys, ys) + ladj
end
grad_bar(xs) = Zygote.gradient(bar, xs)

Benchmarking-wise, I get

julia> xs = rand(10^3);
julia> @benchmark foo($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 11.530 μs … 1.776 ms ┊ GC (min … max): 0.00% … 97.74%
Time (median): 13.678 μs ┊ GC (median): 0.00%
Time (mean ± σ): 14.012 μs ± 17.694 μs ┊ GC (mean ± σ): 1.24% ± 0.98%
▁ ▂▂ ▁▃▂▄▅▆▄██▅▄█▇▃▂▂▂▁ ▂
██▇▇▆▄▅▅▄▁▆██████████████████████▆▁▃▁▁▁▃▁▁▄▄▄▅▃▅▃▅▅▃▅▃▄▃▃▁▃ █
11.5 μs Histogram: log(frequency) by time 17.2 μs <
Memory estimate: 15.88 KiB, allocs estimate: 2.
julia> @benchmark bar($xs)
BenchmarkTools.Trial: 10000 samples with 4 evaluations.
Range (min … max): 7.894 μs … 393.329 μs ┊ GC (min … max): 0.00% … 95.37%
Time (median): 8.907 μs ┊ GC (median): 0.00%
Time (mean ± σ): 9.813 μs ± 16.138 μs ┊ GC (mean ± σ): 8.48% ± 5.03%
▁▃▆██▆▄
▁▁▁▂▂▃▄▆▇████████▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
7.89 μs Histogram: frequency by time 12.6 μs <
Memory estimate: 31.62 KiB, allocs estimate: 3.
julia> @benchmark grad_foo($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 32.637 μs … 2.148 ms ┊ GC (min … max): 0.00% … 95.02%
Time (median): 37.748 μs ┊ GC (median): 0.00%
Time (mean ± σ): 44.188 μs ± 102.325 μs ┊ GC (mean ± σ): 12.09% ± 5.11%
▂▄▆██▇▇▆▄▄▃▂▁▁▁▁ ▁ ▂
▄▄▅▇██████████████████▇▆▇▇▆▆▄▆▇████▆▆▄▃▄▄▆▄▄▅▄▃▃▃▃▄▁▁▃▄▁▄▁▁▄ █
32.6 μs Histogram: log(frequency) by time 67.3 μs <
Memory estimate: 129.48 KiB, allocs estimate: 83.
julia> @benchmark grad_bar($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 41.127 μs … 2.765 ms ┊ GC (min … max): 0.00% … 95.62%
Time (median): 47.307 μs ┊ GC (median): 0.00%
Time (mean ± σ): 61.313 μs ± 140.068 μs ┊ GC (mean ± σ): 16.32% ± 6.97%
▃▆██▆▄▄▃▃▂▂▂▂▁▁ ▂
▇█████████████████▇▇▇▇▇▆▇▆▆▆▇▆▅▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▄▅▆▆▅▄▅▄▅▃▄▄▂▅ █
41.1 μs Histogram: log(frequency) by time 126 μs <
Memory estimate: 248.20 KiB, allocs estimate: 99.

So using the built-in broadcast support (bar) is faster for the forward evaluation, but under Zygote its gradient is currently slower than the gradient of the naive broadcast in foo.
Turns out that with a custom pullback for the internal ChangesOfVariables._with_ladj_on_mapped, like

using ChainRulesCore
function _with_ladj_on_mapped_pullback(thunked_ΔΩ)
    ys, ladj = ChainRulesCore.unthunk(thunked_ΔΩ)
    NoTangent(), NoTangent(), broadcast(x -> (x, ladj), ys)
end

function ChainRulesCore.rrule(::typeof(ChangesOfVariables._with_ladj_on_mapped), map_or_bc::Function, y_with_ladj)
    return ChangesOfVariables._with_ladj_on_mapped(map_or_bc, y_with_ladj), _with_ladj_on_mapped_pullback
end
end we can make AD on with_logabsdet_jacobian(Base.Fix1(broadcast, f), x) significantly faster: julia> @benchmark foo($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 11.493 μs … 2.092 ms ┊ GC (min … max): 0.00% … 98.32%
Time (median): 13.088 μs ┊ GC (median): 0.00%
Time (mean ± σ): 14.712 μs ± 21.011 μs ┊ GC (mean ± σ): 1.40% ± 0.98%
▅▂▃▃▅▇█▃▂ ▂▄▅▆▅▅▅▄▃▂ ▂
█████████▅▄▃▄████████████▇▅▆▆▅▆▆▇▆████▇█▇▇▅▅▅▃▃▃▁▃▃▃▃▁▁▁▅▃▅ █
11.5 μs Histogram: log(frequency) by time 25.9 μs <
Memory estimate: 15.88 KiB, allocs estimate: 2.
julia> @benchmark bar($xs)
BenchmarkTools.Trial: 10000 samples with 4 evaluations.
Range (min … max): 7.452 μs … 449.738 μs ┊ GC (min … max): 0.00% … 95.81%
Time (median): 8.578 μs ┊ GC (median): 0.00%
Time (mean ± σ): 9.610 μs ± 15.170 μs ┊ GC (mean ± σ): 8.11% ± 5.05%
▂▅▆▇▇██▇▆▄▂▁ ▁ ▂
▆█████████████▇▇▅▅▄▅▅▁▄▄▆▆▇█▇▇▇▅▅▅▅▃▃▇███▇▆▆████▆▃▅▅▁▅▆▃▅▄▅ █
7.45 μs Histogram: log(frequency) by time 17.1 μs <
Memory estimate: 31.62 KiB, allocs estimate: 3.
julia> @benchmark grad_foo($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 29.964 μs … 2.371 ms ┊ GC (min … max): 0.00% … 93.51%
Time (median): 37.055 μs ┊ GC (median): 0.00%
Time (mean ± σ): 43.938 μs ± 103.986 μs ┊ GC (mean ± σ): 12.38% ± 5.15%
▂▄▇███▇▆▄▄▃▃▂▁▁ ▁ ▂
▄▁▁▆█████████████████▇▇▇████▇█▇▇██▇▆▆▇▆▇▇▆▅▃▃▅▃▆▆▅▄▅▆▃▅▆▄▆▅▆ █
30 μs Histogram: log(frequency) by time 77.2 μs <
Memory estimate: 129.48 KiB, allocs estimate: 83.
julia> @benchmark grad_bar($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 22.926 μs … 2.078 ms ┊ GC (min … max): 0.00% … 97.35%
Time (median): 27.223 μs ┊ GC (median): 0.00%
Time (mean ± σ): 33.296 μs ± 90.284 μs ┊ GC (mean ± σ): 14.57% ± 5.30%
▃▆██▇▅▄▂▂▂▁ ▂
▃▇█████████████▇▆▇▇▄▆▇██▇▆▆▆▅▄▅▄▄▃▃▄▄▄▃▂▄▄▄▄▅▅▅▃▄▄▃▄▄▄▃▃▄▃▅ █
22.9 μs Histogram: log(frequency) by time 68.8 μs <
Memory estimate: 136.44 KiB, allocs estimate: 50.

@devmotion, this would have quite an impact (almost a factor-two speedup in my simple example), but it would mean adding ChainRulesCore as a dependency. What do you think? It would make ChangesOfVariables itself less lightweight - but on the other hand, maybe pretty much every package that we'd hope would support ChangesOfVariables will depend on ChainRulesCore already anyway (directly or indirectly)?
It can be inconvenient to work with tuples, but in fact it is very common in the AD/ChainRules setting, and also e.g. in Functors or ParameterHandling. It is common, for instance, to just collect the outputs separately from the resulting array of tuples.

I am not sure the benchmark example is completely representative. Clearly, if the ladj and the output can be computed separately, one can save computations by only computing the ladj. Maybe an experimental API for this could be worth exploring.

If possible, an alternative for mapping/broadcasting could be to use a more specialized implementation.
I have to admit I'm also not convinced that there will be many use cases that don't require the result of the transformation at all. @devmotion, how would you feel about a custom pullback for the internal _with_ladj_on_mapped?
Sometimes the API is just in such a way that you don't need the transformed variable (see the rough sketch at the end of this comment). And as mentioned, sometimes you don't actually need anything from the "forward" evaluation, i.e. it's a completely separate computation.

Even if we don't encourage people to implement this, IMO we should at least have a default implementation, i.e. just extracting the ladj term from with_logabsdet_jacobian. Or let me put it like this: why shouldn't we have a logabsdet_jacobian method?
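Schematically, the kind of API I mean - a simplified, made-up stand-in for something like Bijectors' logpdf_with_trans (the helper names here are hypothetical, not actual Bijectors code):

using ChangesOfVariables

# hypothetical convenience method (not currently part of ChangesOfVariables):
logabsdet_jacobian(f, x) = with_logabsdet_jacobian(f, x)[2]

# log-density of x under a base log-density `logpdf_x`, corrected for the
# change of variables y = f(x); the transformed value f(x) is never needed,
# only the ladj at x:
logpdf_with_trans_sketch(logpdf_x, f, x) = logpdf_x(x) - logabsdet_jacobian(f, x)

exponential_logpdf(x) = -x   # Exponential(1) log-density, for x >= 0

logpdf_with_trans_sketch(exponential_logpdf, log, 0.5)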
Let's discuss the custom rrule cost/benefit (not directly related to the question of adding logabsdet_jacobian) separately.
Oh, sure! But regarding the implementation - you do have a point: there are cases where the ladj calculation does not share any code with the function and its inverse, and the ladj is easier to calculate in one direction than the other. So there, people would anyway write something like

_logabsdet_jacobian_only(::typeof(myfunc), x) = ...

with_logabsdet_jacobian(::typeof(myfunc), x) = myfunc(x), _logabsdet_jacobian_only(myfunc, x)

function with_logabsdet_jacobian(::typeof(inv_myfunc), y)
    x = inv_myfunc(y)
    return x, - _logabsdet_jacobian_only(myfunc, x)
end

I know I have code like that in some places. Might as well give that an official name, then.
I'm not sure if I understand that code correctly - is the transformation not applied at all in that use case (even before logpdf_with_trans is called)?
Hm, you do have a point there. I'm still not sure about the use case on the end-user side - but then, our AD frameworks all offer a plain gradient in addition to value-and-gradient. So I'm not against adding

logabsdet_jacobian(f, x) = last(with_logabsdet_jacobian(f, x))

If we do, though, it should be clearly documented that people always have to specialize with_logabsdet_jacobian, never logabsdet_jacobian itself.
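I.e., schematically (square is a made-up example function, just to show the mechanism):

using ChangesOfVariables
import ChangesOfVariables: with_logabsdet_jacobian

# hypothetical default, as proposed above:
logabsdet_jacobian(f, x) = last(with_logabsdet_jacobian(f, x))

square(x) = x^2

# this is what packages would be told to implement ...
with_logabsdet_jacobian(::typeof(square), x) = (x^2, log(2) + log(abs(x)))

# ... and the convenience method then works automatically:
logabsdet_jacobian(square, 3.0)   # == log(2) + log(3)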
Wouldn't it be sufficient to add implementations of the composed function, such as

function (::ComposedFunction{typeof(last),typeof(with_logabsdet_jacobian)})(f::MyFunction, x)
    ...
end

and, if desired,

(::ComposedFunction{typeof(first),typeof(with_logabsdet_jacobian)})(f, x) = f(x)

Then one could exploit optimizations of such compositions.
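A rough self-contained sketch of how that dispatch would work (MyFunction and its log(2) ladj are made up for illustration):

using ChangesOfVariables

struct MyFunction end            # hypothetical transformation
(f::MyFunction)(x) = 2x

ChangesOfVariables.with_logabsdet_jacobian(f::MyFunction, x) = (2x, log(2))

# generic fallback for first ∘ with_logabsdet_jacobian, as suggested above:
(::Base.ComposedFunction{typeof(first),typeof(with_logabsdet_jacobian)})(f, x) = f(x)

# optimized ladj-only method for MyFunction that skips the transformed value:
(::Base.ComposedFunction{typeof(last),typeof(with_logabsdet_jacobian)})(f::MyFunction, x) = log(2)

(last ∘ with_logabsdet_jacobian)(MyFunction(), 3.0)    # hits the optimized method
(first ∘ with_logabsdet_jacobian)(MyFunction(), 3.0)   # == 6.0, via the fallback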
I think even if a logabsdet_jacobian is added, it could simply be

const logabsdet_jacobian = last ∘ with_logabsdet_jacobian
I think it would work, but it's not exactly very readable or convenient. What would be the advantage of taking the ComposedFunction route?

It would not play well with use cases like

logabsdet_jacobian(::typeof(myfunc), x) = ...

with_logabsdet_jacobian(::typeof(myfunc), x) = myfunc(x), logabsdet_jacobian(myfunc, x)

function with_logabsdet_jacobian(::typeof(inv_myfunc), y)
    x = inv_myfunc(y)
    return x, - logabsdet_jacobian(myfunc, x)
end

at least not as directly (a typical use case if the ladj is better calculated in one direction and there's no advantage in calculating myfunc and its ladj together - no shared code / synergy).
That the optimized implementation would also be used if someone calls last ∘ with_logabsdet_jacobian directly.
Sure, but only in that case. Most people would probably use an anonymous function, so I don't think we'd get a lot of opportunistic optimization that way.
They are, but especially so if one wants to implement a specialized method for it.
Even if most people use the plain logabsdet_jacobian, one can still write

function ChainRulesCore.rrule(::typeof(logabsdet_jacobian), ::MyFunction, x)
    ...
end

In fact, this is what should be recommended to users anyway. That logabsdet_jacobian is defined as a composition is merely an implementation detail.
You're right, for the rrule it's fine. But for the use case I mentioned above, using the ComposedFunction overloads would be a lot less convenient than defining a logabsdet_jacobian method directly.
Though I agree that it's possible to overload just the composition, I'm not a big fan due to readability 🤷
Related to TuringLang/Bijectors.jl#212.
Should we also have a method logabsdet_jacobian which is equivalent to Base.Fix2(getindex, 2) ∘ with_logabsdet_jacobian, i.e. one that only computes the logabsdet-jacobian term?