Zygote.gradient returns nothing when differentiating vectorized sampling. #71

Open
arnauqb opened this issue Feb 20, 2023 · 1 comment

arnauqb commented Feb 20, 2023

Hi there, first of all, thanks for this great package. I'm trying to understand how StochasticAD integrates with Zygote, and I ran into a problem when differentiating a broadcasted ("dot") operation. Here is the code:

using StochasticAD, Distributions, Zygote

function sample_n_bernoullis(p, n)
    return sum([rand(Bernoulli(p)) for i in 1:n])
end

function sample_n_bernoullis_vectorized(p, n)
    probs = p * ones(n)
    return sum(rand.(Bernoulli.(probs)))
end

n = 100
p = 0.5

derivative_estimate(p -> sample_n_bernoullis(p, n), p) # this works
derivative_estimate(p -> sample_n_bernoullis_vectorized(p, n), p) # this works

Zygote.gradient(p -> sample_n_bernoullis(p, n), p) # this works
Zygote.gradient(p -> sample_n_bernoullis_vectorized(p, n), p) # this doesn't work (returns nothing)

Are vectorized operations not supported with Zygote?

Thanks!

gaurav-arya (Owner) commented Mar 13, 2023

Hey, apologies for the late reply. I've been looking into this, and it seems to be an interaction between Zygote's broadcast differentiation machinery and the fact that rand ∘ Bernoulli produces a Bool. Things work fine for a Geometric distribution, whose samples are integers:

function sample_n_geometrics_vectorized(p, n)
    probs = p * ones(n)
    return sum(rand.(Geometric.(probs)))
end
Zygote.gradient(p -> sample_n_geometrics_vectorized(p, n), p)  # works fine
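A quick way to see the distinction drawn here is to check the sample types directly with Distributions.jl alone, without StochasticAD or Zygote. This is only a sketch illustrating the Bool-versus-integer difference, assuming Distributions.jl is installed; it does not reproduce or diagnose the Zygote behavior itself:

```julia
using Distributions

# A single Bernoulli draw is a Bool, while a Geometric draw is an integer count.
b = rand(Bernoulli(0.5))
g = rand(Geometric(0.5))
println(typeof(b))  # Bool
println(g isa Integer)  # true

# Multiplying a Bool by an Int promotes the result to Int, which is what
# the `* 1` in the first workaround relies on to avoid a Bool-valued function.
println(typeof(b * 1) == Int)  # true
```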

So while this is not fixed, two alternatives are to make sure the broadcasted function doesn't return a Bool, or to rewrite the broadcast using map:

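function sample_n_bernoullis_vectorized_1(p, n)
    probs = p * ones(n)
    # multiplying by 1 turns each Bool sample into an Int
    return (x -> rand(Bernoulli(x)) * 1).(probs) |> sum
end

function sample_n_bernoullis_vectorized_2(p, n)
    probs = p * ones(n)
    # map instead of broadcasting; ∘ composes rand with Bernoulli
    return map(rand ∘ Bernoulli, probs) |> sum
end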

Zygote.gradient(p -> sample_n_bernoullis_vectorized_1(p, n), p)  # works fine
Zygote.gradient(p -> sample_n_bernoullis_vectorized_2(p, n), p)  # works fine

Thanks for catching this!
