stochastic triples not passed through propagate lose their perturbations #128

Open
GuusAvis opened this issue Jun 15, 2024 · 7 comments


@GuusAvis

GuusAvis commented Jun 15, 2024

Hello!

I have been running into the following problem. When I use propagate to pass a stochastic triple to a function, but in that function there is a second stochastic triple that was not passed to it through propagate, the perturbations of that second stochastic triple are not taken into account. See below a simple example of adding two numbers together.

using StochasticAD

"""Helper function to manually define a stochastic triple."""
function stoch_trip(val::Real, inf_pert::Real, fin_pert::Real, prob::Real)
    Δs = StochasticAD.similar_new(StochasticAD.create_Δs(PrunedFIsBackend(), Int),
        fin_pert, prob)
    StochasticAD.StochasticTriple{0}(val, inf_pert, Δs)
end

x = stoch_trip(1.0, 0.0, 100.0, 1.0)  # 1.0 + 0.0ε + (100.0 with probability 1.0ε)
c = stoch_trip(10.0, 0.0, 1000.0, 100.0)  # 10.0 + 0.0ε + (1000.0 with probability 100.0ε)

f(x) = StochasticAD.propagate(y -> y + c, x)

println(f(x))  # 11.0 + 0.0ε + (100.0 with probability 1.0ε)
println(x + c)  # 11.0 + 0.0ε + (1000.0 with probability 101.0ε) (with high probability)

Just to be clear, I would expect both print statements to give the same output, but they do not. It appears that f(x) completely ignores the perturbation of c.

I realize that my example can easily be fixed by making c an argument of f and then also passing that through propagate. However, I think this is not always desirable or feasible. In my particular use case, conditionally on some if statements that are handled by a call to propagate, I am sampling a distribution. The sample obtained from that distribution is a stochastic triple, the perturbations of which are important to the final result. Because the sample is only created within the function that is passed to propagate, I cannot (easily) pull it out of the function to include it in the call. (As far as I understand, it is not possible to propagate distributions in a meaningful way.)
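The output of f(x) can be reproduced with a toy model. The sketch below is hypothetical: ToyTriple, primal and toy_propagate are illustrative names, not StochasticAD API. It assumes, as the rest of this thread indicates, that propagate only tracks the perturbations of its explicit arguments and keeps just the primal part of whatever the function returns.

```julia
# Hypothetical toy model, NOT StochasticAD's implementation: a value plus one
# finite perturbation `delta` that occurs with infinitesimal probability
# `weight * ε` (the infinitesimal ε-part is omitted for brevity).
struct ToyTriple
    value::Float64
    delta::Float64
    weight::Float64
end

primal(x::ToyTriple) = x.value
primal(x::Real) = x

# Adding a plain number shifts the value and keeps the perturbation.
Base.:+(a::Real, b::ToyTriple) = ToyTriple(a + b.value, b.delta, b.weight)

# Toy `propagate`: evaluate `f` on the primal and on the perturbed argument and
# keep only the primal part of each result. Only the argument's perturbation is
# tracked, so a perturbation carried by a variable that `f` closes over is lost.
function toy_propagate(f, x::ToyTriple)
    y_primal = primal(f(x.value))
    y_perturbed = primal(f(x.value + x.delta))
    return ToyTriple(y_primal, y_perturbed - y_primal, x.weight)
end

x = ToyTriple(1.0, 100.0, 1.0)      # 1.0 + (100.0 with probability 1.0ε)
c = ToyTriple(10.0, 1000.0, 100.0)  # 10.0 + (1000.0 with probability 100.0ε)
y = toy_propagate(v -> v + c, x)    # the perturbation of c is dropped
```

Here y has value 11.0 and carries only the perturbation of x (100.0 with weight 1.0), matching the first println above.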

I would very much like to understand if I am doing something wrong here or whether this could be considered a bug. If it is a bug, could it be fixed? Having this work would greatly increase the value of StochasticAD to my project and I would be happy to help towards that end!

Many thanks!

@GuusAvis changed the title from "stochastic triples not passed through propagate" to "stochastic triples not passed through propagate lose their perturbations" on Jun 15, 2024
@gaurav-arya
Owner

Hey! It is indeed correct that StochasticAD.propagate assumes that the provided function does not close over variables which themselves have derivative contributions. It's intended mainly for discrete, deterministic functions.

If possible, I think it might be easier for me to help with a bit more representative MWE of your situation. My first point of attack would be to attempt to decompose your function into (possibly multiple) discrete deterministic functions that we can handle via propagate, as well as primitive sampling steps that StochasticAD can handle. If this isn't possible, we can see what we can do.
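Applied to the step game introduced further down, such a decomposition could look roughly like this. It is a hypothetical plain-Julia sketch (choose_branch, apply_step, sample_geometric and step_game_decomposed are illustrative names): each iteration samples an increment for both values unconditionally, and a discrete, deterministic function decides which one to use. With StochasticAD, the idea would presumably be to pass all four inputs of apply_step to propagate explicitly so that nothing is closed over; this has not been checked against the library.

```julia
using Random

# Deterministic: which value gets increased?
choose_branch(v1, v2) = v1 < v2 ? 1 : 2

# Deterministic: apply one step given increments sampled for BOTH values.
# The unused increment is discarded; since both draws are fresh and independent
# of the current values, the distribution of the step is unchanged.
function apply_step(v1, v2, g1, g2)
    return choose_branch(v1, v2) == 1 ? (v1 + g1, v2) : (v1, v2 + g2)
end

# Sampling: Geometric(p) on {0, 1, 2, ...} by inverse transform.
sample_geometric(rng::AbstractRNG, p) = floor(Int, log1p(-rand(rng)) / log1p(-p))

function step_game_decomposed(rng::AbstractRNG, p1, p2, tolerance)
    v1, v2 = sample_geometric(rng, p1), sample_geometric(rng, p2)
    while abs(v1 - v2) > tolerance
        g1, g2 = sample_geometric(rng, p1), sample_geometric(rng, p2)
        v1, v2 = apply_step(v1, v2, g1, g2)
    end
    return v1, v2
end
```

Each iteration now uses a fixed amount of randomness, although the number of iterations is still random.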

@GuusAvis
Author

Many thanks for the reply!

Yes, I have also thought about pulling out the randomness to make all functions deterministic. The problem, however, is that the amount of randomness needed is itself random, which makes it difficult to obtain all the random samples ahead of time. One option could be to sample a fixed amount of randomness ahead of time and hope it is enough, but this risks biasing the distribution I am trying to sample and/or throwing unpredictable errors. It would also be inefficient and not very elegant.

To make it a bit more concrete, consider the following MWE. We are playing a "step game", where two values are increased in discrete steps with the goal of decreasing the difference between the two values below some tolerance while keeping the values themselves small. To this end, we recursively increase the smaller of the two values by one step, and then check if the difference is within the tolerance. In Julia, we can model this using the following function:

using Distributions, StochasticAD

function step_game(rand_var_1, rand_var_2, tolerance)
    value_1 = rand(rand_var_1)
    value_2 = rand(rand_var_2)
    while abs(value_1 - value_2) > tolerance
        if value_1 < value_2
            value_1 += rand(rand_var_1)
        else
            value_2 += rand(rand_var_2)
        end
    end
    value_1, value_2
end

Then, we can for instance see what happens when the two random variables are geometrically distributed:

rand_var_1 = Geometric(0.1)
rand_var_2 = Geometric(0.1)
tolerance = 2.
step_game(rand_var_1, rand_var_2, tolerance)

Now, I want to learn how the output values change as we change the random variables. More precisely, I want to be able to evaluate

rand_var_1 = Geometric(stochastic_triple(0.1))
rand_var_2 = Geometric(0.1)
tolerance = 2.
step_game(rand_var_1, rand_var_2, tolerance)

This, of course, will not work because the conditional statements in step_game() do not play nice with the stochastic triples obtained from rand(rand_var_1). My somewhat naive solution so far has been to split the function into an initialization part and a recursive function that takes care of updating the values. This second function can then call propagate when dispatching on stochastic triples. The code looks as follows:

function step_game_2(rand_var_1, rand_var_2, tolerance)
    value_1 = rand(rand_var_1)
    value_2 = rand(rand_var_2)
    recursive_update_step_game(value_1, value_2, rand_var_1, rand_var_2, tolerance)
end

function recursive_update_step_game(value_1, value_2, rand_var_1, rand_var_2, tolerance)
    if abs(value_1 - value_2) <= tolerance
        return value_1, value_2
    end
    if value_1 < value_2
        value_1 += rand(rand_var_1)
    else
        value_2 += rand(rand_var_2)
    end
    recursive_update_step_game(value_1, value_2, rand_var_1, rand_var_2, tolerance)
end 

function recursive_update_step_game(value_1::StochasticAD.StochasticTriple, value_2,
        rand_var_1, rand_var_2, tolerance)
    f = x -> recursive_update_step_game(x, value_2, rand_var_1, rand_var_2, tolerance)
    StochasticAD.propagate(f, value_1)
end
function recursive_update_step_game(value_1, value_2::StochasticAD.StochasticTriple,
        rand_var_1, rand_var_2, tolerance)
    f = x -> recursive_update_step_game(value_1, x, rand_var_1, rand_var_2, tolerance)
    StochasticAD.propagate(f, value_2)
end
function recursive_update_step_game(value_1::StochasticAD.StochasticTriple,
        value_2::StochasticAD.StochasticTriple, rand_var_1, rand_var_2, tolerance)
    f = (x, y) -> recursive_update_step_game(x, y, rand_var_1, rand_var_2, tolerance)
    StochasticAD.propagate(f, value_1, value_2)
end

The function step_game_2() should be equivalent to step_game(), at least when there are no stochastic triples involved.
Now, this naive solution has two problems:

  1. It is very prone to stack-overflow errors. Because we have nested propagates, we keep splitting stochastic triples into their 0th order and 1st order parts, which I think is resulting in an exponential number of function calls.
  2. The returned values are often incorrect. The 1st-order parts of the samples obtained within a propagate are not returned, meaning that often we get something like (4 + 0ε + (0 with probability 9.523809523809526ε), 5 + 0ε + (0 with probability 9.523809523809526ε)) when it has to update the values a couple of times, even though we would expect nonzero perturbations. (This is basically the problem pointed out in my original post.)
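The exponential growth suspected in point 1 can be illustrated with a simplified count. This hypothetical sketch (nested_evaluations is an illustrative name) assumes that at every recursion level the body is evaluated once on the primal input and once on the perturbed input, and that both evaluations recurse further.

```julia
# Hypothetical count of body evaluations for nested propagate calls: each level
# evaluates the body once, then recurses on both the primal and the perturbed
# branch, so the number of evaluations doubles with every level.
function nested_evaluations(depth::Int)
    depth == 0 && return 1
    return 1 + 2 * nested_evaluations(depth - 1)
end

[nested_evaluations(d) for d in 0:4]  # [1, 3, 7, 15, 31], i.e. 2^(d + 1) - 1
```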

As discussed above, the problem could perhaps be addressed by creating a reservoir of randomness in step_game_2() and passing that to the recursive function (using propagate for those values as well). However, that has the downsides I mentioned, and I would love it if some other solution were possible.
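For illustration, the reservoir idea could be sketched in plain Julia as follows. All names (Reservoir, draw!, geometric_from_uniform, step_game_from_reservoir) are hypothetical, and Geometric(p) is assumed to count failures before the first success, as in Distributions.jl. The sketch makes the downside concrete: the budget n must be fixed up front, and the game errors out if it runs dry.

```julia
using Random

# Hypothetical "reservoir of randomness": all uniform draws are made up front,
# so the game becomes a deterministic function of the reservoir's contents.
mutable struct Reservoir
    draws::Vector{Float64}
    next::Int
end
Reservoir(rng::AbstractRNG, n::Int) = Reservoir(rand(rng, n), 1)

# Hand out the next pre-sampled uniform; fail once the budget is used up.
function draw!(r::Reservoir)
    r.next > length(r.draws) && error("reservoir exhausted after $(length(r.draws)) draws")
    u = r.draws[r.next]
    r.next += 1
    return u
end

# Geometric(p) on {0, 1, 2, ...} from a uniform u in [0, 1) by inverse transform.
geometric_from_uniform(u, p) = floor(Int, log1p(-u) / log1p(-p))

function step_game_from_reservoir(r::Reservoir, p1, p2, tolerance)
    v1 = geometric_from_uniform(draw!(r), p1)
    v2 = geometric_from_uniform(draw!(r), p2)
    while abs(v1 - v2) > tolerance
        if v1 < v2
            v1 += geometric_from_uniform(draw!(r), p1)
        else
            v2 += geometric_from_uniform(draw!(r), p2)
        end
    end
    return v1, v2
end
```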

Thanks!

@gaurav-arya
Owner

Just wanted to apologize for the delay in replying -- been a busy time. I'll try to circle back as soon as I can!

@GuusAvis
Author

GuusAvis commented Jul 5, 2024

Thanks for the update @gaurav-arya , no problem of course. Looking forward to when you have time.

To perhaps push the discussion further, I can report here on a failed attempt on my side to get a working but hacky solution. However, if you think this post is not helpful for finding a good solution, please disregard it.

My thinking was that if propagate could be adapted such that

  1. the perturbations returned from the primal evaluation are taken into account and
  2. it is possible to have awareness of whether a function evaluation is primal or perturbed

that would make it possible both to obtain correct results and limit the number of function evaluations by avoiding the accidental calculation of perturbations to perturbations (i.e., stick to first-order effects and avoid stack overflows).

This led to the following alternative (and pretty bad) implementation of propagate:

using StochasticAD, Distributions

function stoch_trip(val::Real, inf_pert::Real, fin_pert::Real, prob::Real)
    Δs = StochasticAD.similar_new(StochasticAD.create_Δs(PrunedFIsBackend(), Int),
        fin_pert, prob)
    StochasticAD.StochasticTriple{0}(val, inf_pert, Δs)
end

function perturbed_value(x::StochasticAD.StochasticTriple)
    StochasticAD.value(x) + StochasticAD.perturbations(x)[1][1]
end
perturbed_value(x) = x

function propagate_2(f, args...)
    # determine the primal value (note: it may be a stochastic triple)
    primal_args = StochasticAD.value.(args)
    primal_result = nothing
    try
        primal_result = f(primal_args..., perturbative=false)
    catch e
        if e isa MethodError
            primal_result = f(primal_args...)
        else
            rethrow(e)
        end
    end

    # "collapse" the arguments, such that we are in a definite branch
    sum(args)  # this is probably a really bad way of doing this

    # determine epsilon of the branch we are in (very hacky)
    perturbations = StochasticAD.perturbations.(args)
    epsilon = nothing
    for perturbation in perturbations
        if perturbation[1][1] != 0.0
            epsilon = perturbation[1][2]
            break
        end
    end
    isnothing(epsilon) && (epsilon = perturbations[1][1][2])

    # determine the perturbed value
    perturbed_args = perturbed_value.(args)
    perturbed_result = nothing
    try
        perturbed_result = f(perturbed_args..., perturbative=true)
    catch e
        if e isa MethodError
            perturbed_result = f(perturbed_args...)
        else
            rethrow(e)
        end
    end

    diff = StochasticAD.value.(perturbed_result .- primal_result)
    primal_result .+ stoch_trip.(0., 0.0, diff, epsilon)
end

A function that is used in conjunction with propagate_2 can then implement the keyword argument perturbative to indicate whether the function is calculated on a perturbative branch or not. Moreover, by "collapse" in the comment above sum(args), I mean the fact that if two independent stochastic triples are made to interact with one another, the perturbation of one of the two will be randomly removed and the infinitesimal probability associated with the other triple is increased correspondingly (i.e., only a single perturbation is tracked at a time).
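The "collapse" can be sketched in isolation. This is a hypothetical illustration (prune_pair is not a StochasticAD function) of pruning two independent finite perturbations into one: each is kept with probability proportional to its weight, and the survivor carries the combined weight. That is consistent with the output in the first post, where x + c gave 1000.0 with probability 101.0ε with high probability.

```julia
using Random

# Hypothetical pruning of two independent finite perturbations, delta1 with
# weight w1 and delta2 with weight w2 ("occurs with probability w * ε"), into a
# single tracked perturbation. One is kept with probability w_i / (w1 + w2), and
# the survivor carries weight w1 + w2, so in expectation each perturbation still
# contributes delta_i * w_i.
function prune_pair(rng::AbstractRNG, delta1, w1, delta2, w2)
    total = w1 + w2
    keep_first = rand(rng) * total < w1
    return keep_first ? (delta1, total) : (delta2, total)
end

# The numbers from the first post: 100.0 with weight 1.0, 1000.0 with weight 100.0.
rng = MersenneTwister(2024)
kept = [prune_pair(rng, 100.0, 1.0, 1000.0, 100.0) for _ in 1:10_000]
fraction_1000 = count(k -> k[1] == 1000.0, kept) / length(kept)  # close to 100/101
```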

There are a lot of problems with this function. To begin with, it doesn't have the nice integration with fmap that propagate has. I think it also messes up the typing of stochastic triples. But there are also problems with its correctness, as discussed below.

For simple cases, this seems to fix some of the issues discussed in this issue. For instance, revisiting the simple function in the first post of this thread (adding a fixed stochastic triple to the argument), we find that if we define

c = stoch_trip(10.0, 0.0, 1000.0, 100.0)  # 10.0 + 0.0ε + (1000.0 with probability 100.0ε)
f(x) = x + c
x = stoch_trip(1.0, 0.0, 100.0, 1.0)  # 1.0 + 0.0ε + (100.0 with probability 1.0ε)

then either f(x) or x + c will return 11.0 + 0.0ε + (1000.0 with probability 101.0ε) (with dominant probability). Hence this implementation kinda works with closures over stochastic triples. However, there is already a problem with this simple example: after running f(x), the value of c changes to 10.0 + 0.0ε + (1000.0 with probability 101.0ε). In fact, each time f(x) is executed, the probability increases by one ε. Because of this, f(x) only gives the expected result the first time it is executed. This weird backreaction on c is a major problem, and I think it is due to the fact that I'm building a new stochastic triple from scratch in propagate_2, effectively adding extra probability mass to the system. I am not sure how the function would have to be altered to get a similar effect without the backreaction.

Now, the idea would be to use such an altered propagation function to fix the step game as follows:

function step_game_3(rand_var_1, rand_var_2, tolerance)
    value_1 = rand(rand_var_1)
    value_2 = rand(rand_var_2)
    recursive_update_step_game_2(value_1, value_2, rand_var_1, rand_var_2, tolerance)
end

function recursive_update_step_game_2(value_1, value_2, rand_var_1, rand_var_2, tolerance;
        perturbative=false)
    if abs(value_1 - value_2) <= tolerance
        return value_1, value_2
    end
    if value_1 < value_2
        value_1 += rand(rand_var_1)
    else
        value_2 += rand(rand_var_2)
    end

    if perturbative
        value_1 = StochasticAD.value(value_1)
        value_2 = StochasticAD.value(value_2)
        recursive_update_step_game_2(value_1, value_2, rand_var_1, rand_var_2, tolerance,
            perturbative=true)
    else
        recursive_update_step_game_2(value_1, value_2, rand_var_1, rand_var_2, tolerance)
    end
end 

function recursive_update_step_game_2(value_1::StochasticAD.StochasticTriple, value_2,
        rand_var_1, rand_var_2, tolerance)
    f = (x; perturbative) -> recursive_update_step_game_2(x, value_2, rand_var_1,
        rand_var_2, tolerance, perturbative=perturbative)
    propagate_2(f, value_1)
end

By implementing the perturbative keyword, which will be passed as true by propagate_2 whenever evaluating perturbative branches, we can avoid branching out of branches that are already perturbative. When we are on a perturbative branch, we ignore the perturbations in the values we find (these are second order after all) thereby avoiding another call to propagate_2. This allows us to run code like

p = stochastic_triple(0.1)
rand_var_1 = Geometric(p)
rand_var_2 = Geometric(0.1)
tolerance = 2.
step_game_3(rand_var_1, rand_var_2, tolerance)

The good news is that it runs just fine, gives stochastic triples as output, and is quick without running into any stack overflows. However, I have almost zero faith that the results are correct. An example output is

(49.0 + 0.0ε + (38.0 with probability 544.4444444444445ε), 51.0 + 0.0ε + (13.0 with probability 444.44444444444446ε))

While the primary values are nicely within the tolerance of 2 of each other, the same does not seem true for the perturbed values, which I think indicates something is seriously wrong. Also the infinitesimal probabilities are super large, while each sample taken from the geometric distribution only introduces a single ε. I suspect that this problem is closely related to the one described above, where each call of f(x) introduces more probability mass.
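Why the perturbative keyword removes the stack overflows can be seen with a simplified evaluation count. This hypothetical sketch (evaluations is an illustrative name) assumes a primal branch spawns exactly one perturbed branch per level, while a perturbed branch never splits again. The count then grows quadratically rather than exponentially in the recursion depth. This says nothing about whether the resulting values are correct.

```julia
# Hypothetical evaluation count: a primal branch evaluates its body, continues as
# a primal branch and spawns one perturbed branch; a perturbed branch
# (perturbative = true) only recurses, without splitting further.
function evaluations(depth::Int; perturbative::Bool=false)
    depth == 0 && return 1
    perturbative && return 1 + evaluations(depth - 1; perturbative=true)
    return 1 + evaluations(depth - 1) + evaluations(depth - 1; perturbative=true)
end

[evaluations(d) for d in 0:4]  # [1, 3, 6, 10, 15], i.e. (d + 1)(d + 2) / 2
```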

Thanks again, looking forward to your reply!

@GuusAvis
Author

Thanks to some very useful suggestions from @gaurav-arya I was able to make the while loop of the step game work without any stack overflows. The trick consists of introducing a new variable, which I called finished, in the loop that is itself a stochastic triple, and that indicates per branch whether the process has finished or not (i.e., whether the two values lie within the tolerance). Using the function alltrue, which is a little bit hidden in the internals of StochasticAD, we can then use the finished variable to check whether all branches have terminated or whether some are still going, and control the while loop on that.
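The control flow of this trick can be illustrated without StochasticAD. In the hypothetical toy model below (run_until_all_finished is an illustrative name, not library API), a branched quantity is a vector whose first entry is the primal branch and whose other entries are perturbed branches. The loop keeps running until every branch is finished, which plays the role of `while !istrue(finished)` with `alltrue`.

```julia
# Toy model: `finished` holds one Bool per branch, and the loop only stops once
# all branches are finished.
function run_until_all_finished(starts::Vector{Int}, target::Int)
    values = copy(starts)
    finished = values .>= target
    steps = 0
    while !all(finished)
        # Advance only unfinished branches; finished ones stay frozen, like
        # `update_values` returning early once `within_tolerance` holds.
        for i in eachindex(values)
            finished[i] || (values[i] += 1)
        end
        finished = values .>= target
        steps += 1
    end
    return values, steps
end

run_until_all_finished([3, 1, 5], 4)  # ([4, 4, 5], 3)
```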

However, the original problem addressed in this issue still stands: propagate only works on functions that map normal numbers to normal numbers, not functions that map normal numbers to stochastic triples. Luckily, @gaurav-arya indicated that it may be quite easy to fix that and that there will be a small PR forthcoming, to which I am very much looking forward!

For completeness, let me here post my latest implementation of the step game.

using StochasticAD

"""
    update_values(value_1, rand_var_1, value_2, rand_var_2, tolerance)

Update the values by sampling the random variables.

Note: this function may return stochastic triples even if `value_1` and `value_2` are normal
numbers, as the samples taken from `rand_var_1` and `rand_var_2` may be stochastic triples.
This is currently not correctly handled by `StochasticAD.propagate`.
"""
function update_values(value_1, rand_var_1, value_2, rand_var_2, tolerance)
    within_tolerance(value_1, value_2, tolerance) && return value_1, value_2
    if value_1 < value_2
        value_1 += rand(rand_var_1)
    else
        value_2 += rand(rand_var_2)
    end
    value_1, value_2
end
function update_values(value_1::StochasticAD.StochasticTriple,
        rand_var_1, value_2, rand_var_2, tolerance)
    f = v1 -> update_values(v1, rand_var_1, value_2, rand_var_2, tolerance)
    StochasticAD.propagate(f, value_1)
end
function update_values(value_1, rand_var_1, value_2::StochasticAD.StochasticTriple,
        rand_var_2, tolerance)
    f = v2 -> update_values(value_1, rand_var_1, v2, rand_var_2, tolerance)
    StochasticAD.propagate(f, value_2)
end
function update_values(value_1::StochasticAD.StochasticTriple,
        rand_var_1, value_2::StochasticAD.StochasticTriple, rand_var_2, tolerance)
    f = (v1, v2) -> update_values(v1, rand_var_1, v2, rand_var_2, tolerance)
    StochasticAD.propagate(f, value_1, value_2)
end

"""
    within_tolerance(value_1, value_2, tolerance)

Check if the values satisfy the tolerance level.
"""
within_tolerance(value_1, value_2, tolerance) = abs(value_1 - value_2) < tolerance
function within_tolerance(value_1::StochasticAD.StochasticTriple, value_2, tolerance)
    StochasticAD.propagate(x -> within_tolerance(x, value_2, tolerance), value_1)
end
function within_tolerance(value_1, value_2::StochasticAD.StochasticTriple, tolerance)
    StochasticAD.propagate(x -> within_tolerance(value_1, x, tolerance), value_2)
end
function within_tolerance(value_1::StochasticAD.StochasticTriple,
        value_2::StochasticAD.StochasticTriple, tolerance)
    StochasticAD.propagate((x, y) -> within_tolerance(x, y, tolerance), value_1, value_2)
end

"""
    istrue(x)

Check if a number is unambiguously `true` in a way that works for stochastic triples.

If any of the branches of the stochastic triple are `false`, then the result is `false`.
"""
istrue(x::Bool) = x
function istrue(x::StochasticAD.StochasticTriple)
    primary = isone(StochasticAD.value(x))
    perts = StochasticAD.alltrue(iszero, x.Δs)
    primary && perts
end

function step_game(rand_var_1, rand_var_2, tolerance)
    value_1 = rand(rand_var_1)
    value_2 = rand(rand_var_2)
    finished = false
    while !istrue(finished)
        value_1, value_2 = update_values(value_1, rand_var_1, value_2, rand_var_2, tolerance)
        finished = within_tolerance(value_1, value_2, tolerance)
    end
    value_1, value_2
end

To clarify the above: everything works fine now, except that the function update_values() does not update the values correctly because the perturbative parts of rand(rand_var_1) and rand(rand_var_2) are disregarded by propagate.

To test the correctness of the implementation, I set both the random variables to be geometric, fix one of the two parameters and vary the other. Then I look at what happens to the expected value of the absolute value of the difference between the two output values. That is, I am looking at the expected value of the random program

using Distributions

function absolute_difference(p)
    rand_var_1 = Geometric(p)
    rand_var_2 = Geometric(0.1)
    tolerance = 5.
    v1, v2 = step_game(rand_var_1, rand_var_2, tolerance)
    abs(v1 - v2)
end

as a function of p. That looks like this:
[Figure: loop_fix_in_issue_exp, estimated expected value of absolute_difference(p) as a function of p]
That looks nice and testable as any correct derivative estimate should go through zero around p = 0.15. So let's look at the derivative estimate:
[Figure: loop_fix_in_issue_diff, derivative estimate of absolute_difference(p) as a function of p]
That does not look right.
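One cheap sanity check for such a plot, independent of StochasticAD, is a central finite difference of the Monte Carlo estimate, using common random numbers (the same seed at p - h and p + h). The sketch below is plain Julia with hypothetical names (sample_geometric, plain_step_game, mean_absdiff, fd_derivative). It assumes the Distributions.jl convention for Geometric(p), i.e. the number of failures before the first success on {0, 1, 2, ...}, and the tolerance of 5 used in absolute_difference. Because the number of draws depends on p, the shared seed only partially synchronizes the two runs.

```julia
using Random

# Geometric(p) on {0, 1, 2, ...} by inverse transform.
sample_geometric(rng::AbstractRNG, p) = floor(Int, log1p(-rand(rng)) / log1p(-p))

# Plain-number version of the step game from this thread.
function plain_step_game(rng::AbstractRNG, p1, p2, tolerance)
    v1 = sample_geometric(rng, p1)
    v2 = sample_geometric(rng, p2)
    while abs(v1 - v2) > tolerance
        if v1 < v2
            v1 += sample_geometric(rng, p1)
        else
            v2 += sample_geometric(rng, p2)
        end
    end
    return v1, v2
end

# Monte Carlo estimate of E|v1 - v2| for Geometric(p) vs Geometric(0.1), tolerance 5.
function mean_absdiff(p, n; seed=1)
    rng = MersenneTwister(seed)
    total = 0.0
    for _ in 1:n
        v1, v2 = plain_step_game(rng, p, 0.1, 5.0)
        total += abs(v1 - v2)
    end
    return total / n
end

# Central finite difference with common random numbers.
fd_derivative(p, h, n; seed=1) =
    (mean_absdiff(p + h, n; seed=seed) - mean_absdiff(p - h, n; seed=seed)) / (2h)
```

If a derivative estimate and this finite difference disagree by much more than their combined standard errors (allowing for the O(h^2) bias), that points to a problem in the derivative estimate.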

I will rerun the second plot as soon as there is a fix in propagate; if it looks correct, I will be very content!

@GuusAvis
Author

GuusAvis commented Aug 5, 2024

Many thanks for that PR @gaurav-arya !

I have used it to rerun the derivative and get the following result:

[Figure: derivative, derivative estimate of absolute_difference(p) as a function of p, rerun with the PR]

That looks correct to me, this is great :).

I noticed though that the performance of the code is not great. I played around with it a bit and it often seems that the derivative estimate is slower by a factor ~600 compared to evaluating the primal value. Also the variance is kind of big (error bars denote the standard error, both the primal evaluation and the derivative were estimated using 1E6 samples).

Do you perhaps have any ideas how the call to propagate could be made more efficient, or the variance made smaller? It may be that the many branching points in the step game make it inherently difficult to differentiate, but if there are some tips and tricks to be found I would be very interested!
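For reference, the error bars and the cost trade-off can be quantified with the standard formulas. The standard error of the mean is the sample standard deviation divided by √n. The number of samples needed for a target standard error therefore grows with the square of the standard deviation, and the total work is that number times the cost per sample. A small sketch using the Statistics stdlib (mean_and_stderr and samples_needed are hypothetical helper names):

```julia
using Statistics

# Mean and standard error of the mean (sample standard deviation / √n).
function mean_and_stderr(samples::AbstractVector{<:Real})
    return mean(samples), std(samples) / sqrt(length(samples))
end

# Samples needed to reach a target standard error, given a per-sample
# standard deviation `sigma` (the standard error scales as 1/√n).
samples_needed(sigma, target) = ceil(Int, (sigma / target)^2)

mean_and_stderr([1.0, 2.0, 3.0, 4.0])  # (2.5, ≈0.6455)
samples_needed(2.0, 0.5)                # 16
```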

@GuusAvis
Author

Happy to report that the fix also allowed me to get autodiff working for my actual use case! However, there too I observe that derivative evaluations are often O(1E3) times slower than primal evaluations, and the variance in the derivative samples is really bad.

It seems that, in this sense, the step game is a good representative MWE for my real use case, so I'm hopeful that if things can be improved for the step game, they will also be improved for my case.
