Hi there, first of all, thanks for this great package. I'm trying to understand how StochasticAD integrates with Zygote and I ran into a problem when trying to differentiate a "dot" operation. Here is the code:
using StochasticAD, Distributions, Zygote

function sample_n_bernoullis(p, n)
    return sum([rand(Bernoulli(p)) for i in 1:n])
end

function sample_n_bernoullis_vectorized(p, n)
    probs = p * ones(n)
    return sum(rand.(Bernoulli.(probs)))
end

n = 100
p = 0.5
derivative_estimate(p -> sample_n_bernoullis(p, n), p) # this works
derivative_estimate(p -> sample_n_bernoullis_vectorized(p, n), p) # this works
Zygote.gradient(p -> sample_n_bernoullis(p, n), p) # this works
Zygote.gradient(p -> sample_n_bernoullis_vectorized(p, n), p) # this doesn't work (returns nothing)
Are vectorized operations not supported with Zygote?
Thanks!
Hey, apologies for the late reply. I've been looking into this, and it seems to be an interaction between Zygote's broadcast differentiation machinery and the fact that rand ∘ Bernoulli produces a Bool. Things work fine for a Geometric distribution, whose samples are integers:
function sample_n_geometrics_vectorized(p, n)
    probs = p * ones(n)
    return sum(rand.(Geometric.(probs)))
end

Zygote.gradient(p -> sample_n_geometrics_vectorized(p, n), p) # works fine
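To see the return-type difference directly, the short check below (it only needs Distributions) shows that a Bernoulli draw is a Bool, while a Geometric draw and the `* 1` workaround are Int. The connection to Zygote is my reading rather than something confirmed here: as far as I can tell, Zygote's broadcast rule treats a broadcast whose inferred element type is Bool as non-differentiable and returns a `nothing` pullback, which would explain the `nothing` gradient.

```julia
using Distributions

# A single Bernoulli draw is a Bool, so broadcasting rand over
# Bernoulli distributions collects into a Vector{Bool}.
b = rand(Bernoulli(0.5))
println(typeof(b))        # Bool

# A Geometric draw is an Int, which is why the geometric version behaves differently.
g = rand(Geometric(0.5))
println(typeof(g))        # Int64 on a 64-bit machine

# Multiplying by 1 promotes the Bool to Int, as in the `* 1` workaround.
println(typeof(b * 1))    # Int64 on a 64-bit machine

# Element types of the two broadcast results.
probs = 0.5 * ones(4)
println(eltype(rand.(Bernoulli.(probs))))   # Bool
println(eltype(rand.(Geometric.(probs))))   # Int64 on a 64-bit machine
```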
Until this is fixed, two alternatives are to make sure the broadcasted function doesn't return a Bool, or to rewrite the broadcast as a map:
function sample_n_bernoullis_vectorized_1(p, n)
    probs = p * ones(n)
    return (x -> rand(Bernoulli(x)) * 1).(probs) |> sum
end

function sample_n_bernoullis_vectorized_2(p, n)
    probs = p * ones(n)
    return map(rand ∘ Bernoulli, probs) |> sum
end

Zygote.gradient(p -> sample_n_bernoullis_vectorized_1(p, n), p) # works fine
Zygote.gradient(p -> sample_n_bernoullis_vectorized_2(p, n), p) # works fine