gaurav-arya / StochasticAD.jl

Research package for automatic differentiation of programs containing discrete randomness.
MIT License
199 stars 16 forks source link

Make propagate more like a monadic bind by supporting stochastic triple creating functions #132

Open gaurav-arya opened 1 month ago

gaurav-arya commented 1 month ago

x-ref #128. @GuusAvis example:

using StochasticAD
using Distributions

function f(value_1, value_2, rand_var)
    # Bump whichever value is strictly smaller by one sample; on a tie `value_2` is bumped.
    value_1 < value_2 && return (value_1 + rand(rand_var), value_2)
    return (value_1, value_2 + rand(rand_var))
end

# Apply `StochasticAD.propagate` to `f` with `rand_var` held fixed, over both value arguments.
function propagate_f(value_1, value_2, rand_var)
    f_fixed_rand = (v1, v2) -> f(v1, v2, rand_var)
    return StochasticAD.propagate(f_fixed_rand, value_1, value_2)
end

# Any call where at least one value argument is a `StochasticTriple` is routed through
# `propagate_f`. The method with both arguments typed resolves the dispatch ambiguity
# between the two single-triple methods.
function f(value_1::StochasticTriple, value_2, rand_var)
    return propagate_f(value_1, value_2, rand_var)
end
function f(value_1, value_2::StochasticTriple, rand_var)
    return propagate_f(value_1, value_2, rand_var)
end
function f(value_1::StochasticTriple, value_2::StochasticTriple, rand_var)
    return propagate_f(value_1, value_2, rand_var)
end

"""
    g(p)

Start from the pair `(0, 2)` and apply `f` ten times with `Bernoulli(p)` increments.
Return the final pair.
"""
function g(p)
    bernoulli = Bernoulli(p)
    a, b = 0, 2
    for _ in 1:10
        a, b = f(a, b, bernoulli)
    end
    return a, b
end

# One sample of the final `(value_1, value_2)` pair at p = 0.5.
@show g(0.5)
# Forward finite-difference estimate of d/dp E[sum(g(p))] at p = 0.5 with step 0.1,
# averaged over 1000 pairs of runs; the trailing comment records an observed result.
# NOTE(review): a forward difference with step 0.1 has O(step) bias, which may explain part
# of the gap to the estimate on the next line. Confirm before comparing the two closely.
@show mean((sum(g(0.6)) - sum(g(0.5))) / 0.1 for i in 1:1000) # 9.59
# StochasticAD's `derivative_estimate` of the same quantity, averaged over 100 samples.
@show mean(derivative_estimate(p -> sum(g(p)), 0.5) for i in 1:100) # 8.84

@GuusAvis let me know if you have any issues, and if things work out adding the above as a test to triples.jl would be most welcome:)

GuusAvis commented 1 month ago

Many thanks for this PR @gaurav-arya, I think the code is working great now (see also my comment in the issue where I shared a test).

The code you pushed broke an existing test. I managed to solve one problem, but there appear to be more. In one of the tests you are passing `keep_deltas = Val{test_deltas}` (a type) instead of `keep_deltas = Val(test_deltas)` (an instance). I think this problem is closely related to #126. I opened a PR into this branch to fix the issue, see #133. There are also other issues that I'm happy to help with but am unsure about (I'm not sure I entirely understand the interplay of `keep_deltas` and `keep_triples`; was `keep_triples` required because of the other changes in this PR?).

Moreover, I have added another PR (#134) into this branch to add more tests. I have added a statistical test based on your code here, and another simpler test of adding two numbers together (more or less testing the original example I gave in #128). Let me know if you think this makes sense, and feel free to suggest or make any changes.

gaurav-arya commented 1 month ago

Hey! Thank you so much for testing it:) I am going to have to wait until after my thesis deadline Friday midnight, but I'll circle back then! If you have an MWE of the 600x slowdown (which is definitely too slow), that would be great.

Edit: making a @profview of the MWE can also be very informative

GuusAvis commented 1 month ago

No worries!

The MWE of the 600x slowdown is the step game, as also posted in #128. Let me reproduce it here in full for clarity and consistency.

The code that I use for the step game:

using StochasticAD, Distributions

"""
    update_values(value_1, rand_var_1, value_2, rand_var_2, tolerance)

Update the values by sampling the random variables.

If the values are already within `tolerance` of each other, they are returned unchanged.
Otherwise one value is incremented by a sample from its corresponding random variable:
`value_1` if it is strictly smaller, `value_2` otherwise. The result is the pair
`(value_1, value_2)`.

Note: this function may return stochastic triples even if `value_1` and `value_2` are normal
numbers, as the samples taken from `rand_var_1` and `rand_var_2` may be stochastic triples.
Only the incremented entry is promoted in that case, so the return type depends on which
branch runs. This is currently not correctly handled by `StochasticAD.propagate`.
"""
function update_values(value_1, rand_var_1, value_2, rand_var_2, tolerance)
    # Already close enough: leave both values untouched.
    within_tolerance(value_1, value_2, tolerance) && return (value_1, value_2)
    if value_1 < value_2
        value_1 += rand(rand_var_1)
    else
        value_2 += rand(rand_var_2)
    end
    return (value_1, value_2)
end
# Stochastic-triple methods. Each one passes the generic `update_values` method through
# `StochasticAD.propagate` over the triple-valued argument(s) and keeps every other argument
# fixed. The two-triple method also resolves the ambiguity between the one-triple methods.
function update_values(value_1::StochasticAD.StochasticTriple,
        rand_var_1, value_2, rand_var_2, tolerance)
    return StochasticAD.propagate(value_1) do v1
        update_values(v1, rand_var_1, value_2, rand_var_2, tolerance)
    end
end
function update_values(value_1, rand_var_1, value_2::StochasticAD.StochasticTriple,
        rand_var_2, tolerance)
    return StochasticAD.propagate(value_2) do v2
        update_values(value_1, rand_var_1, v2, rand_var_2, tolerance)
    end
end
function update_values(value_1::StochasticAD.StochasticTriple,
        rand_var_1, value_2::StochasticAD.StochasticTriple, rand_var_2, tolerance)
    return StochasticAD.propagate(value_1, value_2) do v1, v2
        update_values(v1, rand_var_1, v2, rand_var_2, tolerance)
    end
end

"""
    within_tolerance(value_1, value_2, tolerance)

Return whether `value_1` and `value_2` differ by strictly less than `tolerance`.
"""
function within_tolerance(value_1, value_2, tolerance)
    gap = abs(value_1 - value_2)
    return gap < tolerance
end
# Stochastic-triple methods: evaluate the generic `within_tolerance` through
# `StochasticAD.propagate` over whichever argument(s) are triples. `tolerance` is held fixed.
function within_tolerance(value_1::StochasticAD.StochasticTriple, value_2, tolerance)
    return StochasticAD.propagate(value_1) do x
        within_tolerance(x, value_2, tolerance)
    end
end
function within_tolerance(value_1, value_2::StochasticAD.StochasticTriple, tolerance)
    return StochasticAD.propagate(value_2) do y
        within_tolerance(value_1, y, tolerance)
    end
end
function within_tolerance(value_1::StochasticAD.StochasticTriple,
        value_2::StochasticAD.StochasticTriple, tolerance)
    return StochasticAD.propagate(value_1, value_2) do x, y
        within_tolerance(x, y, tolerance)
    end
end

"""
    istrue(x)

Check whether `x` is unambiguously `true`, in a way that also works for stochastic triples.

A plain `Bool` is returned as is. For a stochastic triple, the result is `true` only when
the primal value is one and every perturbation in `x.Δs` is zero (checked with
`StochasticAD.alltrue(iszero, ...)`). If any branch of the triple is `false`, the result is
`false`.
"""
istrue(x::Bool) = x
function istrue(x::StochasticAD.StochasticTriple)
    primal_is_true = isone(StochasticAD.value(x))
    no_nonzero_delta = StochasticAD.alltrue(iszero, x.Δs)
    return primal_is_true && no_nonzero_delta
end

"""
    step_game(rand_var_1, rand_var_2, tolerance)

Draw starting values from `rand_var_1` and `rand_var_2`, then repeatedly call
`update_values` until `within_tolerance` is unambiguously true according to `istrue`.
Return the final `(value_1, value_2)` pair. At least one update is always performed.
"""
function step_game(rand_var_1, rand_var_2, tolerance)
    value_1 = rand(rand_var_1)
    value_2 = rand(rand_var_2)
    while true
        value_1, value_2 = update_values(value_1, rand_var_1, value_2, rand_var_2,
            tolerance)
        istrue(within_tolerance(value_1, value_2, tolerance)) && break
    end
    return value_1, value_2
end

It looks like a bit much code for an MWE but a lot of it is just calling propagate.

Now, let's do some benchmarking! First without derivatives:

# Baseline: plain (non-triple) random variables, no derivatives.
rand_var_1 = Geometric(0.1)
rand_var_2 = Geometric(0.1)
tolerance = 5.

# Trailing comments record observed results. If this is the first `step_game` call in the
# session, the first measurement also includes compilation time.
@elapsed step_game(rand_var_1, rand_var_2, tolerance)  # ~ 5E-6 seconds
@allocated step_game(rand_var_1, rand_var_2, tolerance)  # always 32
# `1:1E7` is a Float64 range; the loop variable is unused, so only its length matters.
@elapsed [step_game(rand_var_1, rand_var_2, tolerance) for _ in 1:1E7]  # ~ 1.25 seconds

We see that we can run 1E7 samples in about a second, not too bad. Allocations also looking good I think. Now if we make one of the random variables produce stochastic triples:

# Same game, but `rand_var_1_triple` is parameterized by a stochastic triple, so its samples
# (and hence the game state) may be stochastic triples.
# NOTE(review): the Geometric parameter here is 0.01, not the 0.1 used in the baseline
# above, so the two timings are not directly comparable.
rand_var_1_triple = Geometric(stochastic_triple(0.01))
rand_var_2 = Geometric(0.01)
tolerance = 5.

@elapsed step_game(rand_var_1_triple, rand_var_2, tolerance)  # ~ 2E-3 seconds
@allocated step_game(rand_var_1_triple, rand_var_2, tolerance)  # random, ~ 4E5
@elapsed [step_game(rand_var_1_triple, rand_var_2, tolerance) for _ in 1:1E3]  # ~ 1.15 seconds
# `@profview` is not provided by the packages loaded here. Presumably this runs in the
# VS Code Julia REPL or with ProfileView.jl loaded; confirm.
@profview [step_game(rand_var_1_triple, rand_var_2, tolerance) for _ in 1:1E3]  # lots of propagate

Now we can only run 1E3 samples in a second, suggesting an even bigger slowdown than I had initially reported. The number of allocations is also all over the place: there are a lot of them (tens of thousands for a single run), and the count appears to be random, whereas it was constant for the primal evaluation above.

I had indeed also used @profview, it suggests (unsurprisingly) that most time is spent inside of propagates. Underneath that are mostly mapping functions (structural_map etc) and it was not very easy for me to interpret it directly. For your reference, I will share the flamegraph I produced here (had to zip it because github doesn't like html, also attached screenshot for convenience). step_game_profile.zip Screenshot from 2024-08-09 20-20-30

GuusAvis commented 3 weeks ago

@gaurav-arya Did you have a chance to look at the slowdown yet?

gaurav-arya commented 3 weeks ago

Hi @GuusAvis -- not yet. Thank you for the ping -- I'll take a look this weekend!

gaurav-arya commented 2 weeks ago

Partial debug: I made JET.jl happy with the code by tweaking a few things, and set the Geometric parameter to 0.1 in both cases for consistency. But there is still a large performance gap, and `@profview` still points to runtime dispatches that I somehow cannot see with JET.jl or Cthulhu.jl...

using StochasticAD, Distributions

"""
    update_values(value_1, rand_var_1, value_2, rand_var_2, tolerance)

Update the values by sampling the random variables.

If the values are already within `tolerance` of each other, they are returned unchanged
(apart from type promotion, see below). Otherwise one value is incremented by a sample from
its corresponding random variable: `value_1` if it is strictly smaller, `value_2` otherwise.
The result is the pair `(value_1, value_2)`.

Note: this function may return stochastic triples even if `value_1` and `value_2` are normal
numbers, as the samples taken from `rand_var_1` and `rand_var_2` may be stochastic triples.
To keep the return type independent of which branch runs, both outputs are first promoted
by adding `zero` of a sample from the corresponding random variable.
"""
function update_values(value_1, rand_var_1, value_2, rand_var_2, tolerance)
    # Draw one sample per random variable. Each sample fixes the type its output is promoted
    # to (via `zero(sample)`) and is reused as the increment below. Drawing a second,
    # independent sample just for the increment would be redundant, and for triple-valued
    # distributions each `rand` call is expensive.
    sample_1 = rand(rand_var_1)
    sample_2 = rand(rand_var_2)
    value_1_ret = value_1 + zero(sample_1)
    value_2_ret = value_2 + zero(sample_2)
    within_tolerance(value_1, value_2, tolerance) && return (value_1_ret, value_2_ret)
    if value_1 < value_2
        value_1_ret += sample_1
    else
        value_2_ret += sample_2
    end
    return (value_1_ret, value_2_ret)
end
# Stochastic-triple methods. Each one passes the generic `update_values` method through
# `StochasticAD.propagate` over the triple-valued argument(s) and keeps every other argument
# fixed. The two-triple method also resolves the ambiguity between the one-triple methods.
function update_values(value_1::StochasticAD.StochasticTriple,
        rand_var_1, value_2, rand_var_2, tolerance)
    return StochasticAD.propagate(value_1) do v1
        update_values(v1, rand_var_1, value_2, rand_var_2, tolerance)
    end
end
function update_values(value_1, rand_var_1, value_2::StochasticAD.StochasticTriple,
        rand_var_2, tolerance)
    return StochasticAD.propagate(value_2) do v2
        update_values(value_1, rand_var_1, v2, rand_var_2, tolerance)
    end
end
function update_values(value_1::StochasticAD.StochasticTriple,
        rand_var_1, value_2::StochasticAD.StochasticTriple, rand_var_2, tolerance)
    return StochasticAD.propagate(value_1, value_2) do v1, v2
        update_values(v1, rand_var_1, v2, rand_var_2, tolerance)
    end
end

"""
    within_tolerance(value_1, value_2, tolerance)

Return whether `value_1` and `value_2` differ by strictly less than `tolerance`.
"""
function within_tolerance(value_1, value_2, tolerance)
    gap = abs(value_1 - value_2)
    return gap < tolerance
end
# Stochastic-triple methods: evaluate the generic `within_tolerance` through
# `StochasticAD.propagate` over whichever argument(s) are triples. `tolerance` is held fixed.
function within_tolerance(value_1::StochasticAD.StochasticTriple, value_2, tolerance)
    return StochasticAD.propagate(value_1) do x
        within_tolerance(x, value_2, tolerance)
    end
end
function within_tolerance(value_1, value_2::StochasticAD.StochasticTriple, tolerance)
    return StochasticAD.propagate(value_2) do y
        within_tolerance(value_1, y, tolerance)
    end
end
function within_tolerance(value_1::StochasticAD.StochasticTriple,
        value_2::StochasticAD.StochasticTriple, tolerance)
    return StochasticAD.propagate(value_1, value_2) do x, y
        within_tolerance(x, y, tolerance)
    end
end

"""
    istrue(x)

Check whether `x` is unambiguously `true`, in a way that also works for stochastic triples.

A plain `Bool` is returned as is. For a stochastic triple, the result is `true` only when
the primal value is one and every perturbation in `x.Δs` is zero (checked with
`StochasticAD.alltrue(iszero, ...)`). If any branch of the triple is `false`, the result is
`false`.
"""
istrue(x::Bool) = x
function istrue(x::StochasticAD.StochasticTriple)
    primal_is_true = isone(StochasticAD.value(x))
    no_nonzero_delta = StochasticAD.alltrue(iszero, x.Δs)
    return primal_is_true && no_nonzero_delta
end

"""
    step_game(rand_var_1, rand_var_2, tolerance)

Draw starting values from `rand_var_1` and `rand_var_2`, then repeatedly call
`update_values` until `within_tolerance` is unambiguously true according to `istrue`.
Return the final `(value_1, value_2)` pair.
"""
function step_game(rand_var_1, rand_var_2, tolerance)
    value_1 = rand(rand_var_1)
    value_2 = rand(rand_var_2)
    # When either starting value is a stochastic triple, start `finished` as a triple-valued
    # `false` via `propagate`. The intent is to match the kind of value `within_tolerance`
    # produces inside the loop, so `finished` does not change type across iterations.
    starts_stochastic = value_1 isa StochasticAD.StochasticTriple ||
                        value_2 isa StochasticAD.StochasticTriple
    finished = if starts_stochastic
        StochasticAD.propagate(_ -> false, value_1 + value_2)
    else
        false
    end
    while !istrue(finished)
        # Hot loop: keep it free of per-iteration timing/printing so that outer
        # `@elapsed` / `@allocated` / `@profview` measurements reflect only the game itself.
        value_1, value_2 = update_values(value_1, rand_var_1, value_2, rand_var_2,
            tolerance)
        finished = within_tolerance(value_1, value_2, tolerance)
    end
    return value_1, value_2
end

##

# Baseline: plain (non-triple) random variables, no derivatives.
rand_var_1 = Geometric(0.1)
rand_var_2 = Geometric(0.1)
tolerance = 5.

# Warm-up call so the measurements below do not include compilation of `step_game`.
step_game(rand_var_1, rand_var_2, tolerance) 
# Trailing comments record observed results.
@elapsed step_game(rand_var_1, rand_var_2, tolerance)  # ~ 5E-6 seconds
@allocated step_game(rand_var_1, rand_var_2, tolerance)  # always 32
# `1:1E6` is a Float64 range; the loop variable is unused, so only its length matters.
@elapsed [step_game(rand_var_1, rand_var_2, tolerance) for _ in 1:1E6]  # ~ 0.4 seconds

##

# Same game with `rand_var_1_triple` parameterized by a stochastic triple (using
# `PrunedFIsBackend`). The Geometric parameter is 0.1, as in the baseline above, so the
# timings are comparable.
rand_var_1_triple = Geometric(stochastic_triple(0.1; backend = PrunedFIsBackend()))
rand_var_2 = Geometric(0.1)
tolerance = 5.

# Warm-up call so the measurements below do not include compilation.
step_game(rand_var_1_triple, rand_var_2, tolerance) 
@elapsed step_game(rand_var_1_triple, rand_var_2, tolerance)  # ~ 2E-4 seconds
@allocated step_game(rand_var_1_triple, rand_var_2, tolerance)  # random, ~ 4E5
@elapsed [step_game(rand_var_1_triple, rand_var_2, tolerance) for _ in 1:1E4]  # ~ 1.3 seconds
# `@profview` is not provided by the packages loaded here. Presumably this runs in the
# VS Code Julia REPL or with ProfileView.jl loaded; confirm.
@profview [step_game(rand_var_1_triple, rand_var_2, tolerance) for _ in 1:1E4]

##

# Microbenchmark: call `update_values` repeatedly on the same fixed starting values (results
# are not fed back), using plain random variables.
let value_1 = rand(rand_var_1), value_2 = rand(rand_var_2)
    @time [update_values(value_1, rand_var_1, value_2, rand_var_2, tolerance) for i in 1:10000000]
    # Return `nothing` so the REPL does not display the result vector.
    nothing
end

# Same microbenchmark with `value_1` drawn from the triple-parameterized distribution.
# Presumably `value_1` is then a `StochasticTriple`, so this exercises the
# `StochasticAD.propagate` methods of `update_values`; confirm.
let value_1 = rand(rand_var_1_triple), value_2 = rand(rand_var_2)
    @time [update_values(value_1, rand_var_1_triple, value_2, rand_var_2, tolerance) for i in 1:100000]
    nothing
end

# Profile of the triple-valued microbenchmark above (`@profview` must be available; see
# the benchmark cell above).
let value_1 = rand(rand_var_1_triple), value_2 = rand(rand_var_2)
    @profview [update_values(value_1, rand_var_1_triple, value_2, rand_var_2, tolerance) for i in 1:100000]
    nothing
end