probcomp / Gen.jl

A general-purpose probabilistic programming system with programmable inference
https://gen.dev
Apache License 2.0

Named Trace causes derivatives not to be calculated? #104

Open wsmoses opened 5 years ago

wsmoses commented 5 years ago

Consider the following program, which tries to find a value of :x1 that is close to zero. After running one round of map_optimize, we should find :x1 moved closer to zero (as in the subsequent traces).

@gen function subtest()
    @trace(uniform(-3, 3), :x1)
end
@gen function test(thres)
    res = @trace(subtest())
    @trace(normal(res, thres), :z)
end;

force = Gen.choicemap(:z => 0)
t, _ = Gen.generate(test, (0.01,), force)
println(Gen.get_choices(t))

t2 = map_optimize(t, Gen.select(:x1), max_step_size=0.00001)
println(Gen.get_choices(t2))
│
├── :z : 0.0
│
└── :x1 : 0.28988469883583434

│
├── :z : 0.0
│
└── :x1 : 0.2608962289522509

Now consider the following program, which adds a scope (address) for the traced call:

@gen function subtest()
    @trace(uniform(-3, 3), :x1)
end
@gen function test(thres)
    res = @trace(subtest(), :n)
    @trace(normal(res, thres), :z)
end;

force = Gen.choicemap(:z => 0)
t, _ = Gen.generate(test, (0.01,), force)
println(Gen.get_choices(t))

t2 = map_optimize(t, Gen.select(:n => :x1), max_step_size=0.00001)
println(Gen.get_choices(t2))

However, this time the trace does not change (and, inspecting further inside Gen, the computed derivative is now zero?):

│
├── :z : 0.0
│
└── :n
    │
    └── :x1 : -0.3862602351556288

│
├── :z : 0.0
│
└── :n
    │
    └── :x1 : -0.3862602351556288
alex-lew commented 5 years ago

Hi there! I don't think this is mentioned in the current tutorial material, but in order for gradients to flow through the return value of subtest, you'll need to annotate the function definition with (grad):

@gen (grad) function subtest()
    @trace(uniform(-3,3), :x1)
end

See here for more info: https://probcomp.github.io/Gen/dev/ref/modeling/#Differentiable-programming-1

The fact that this works in your first example (without an address for the call to subtest) seems to be a fluke of the current implementation; thanks for pointing that out. What you're seeing here is that even though @trace(f()) and @trace(f(), :f) are very similar syntactically, they are somewhat different under the hood. For example, @trace(f()) does not work with arbitrary generative functions f: f must be a generative function also written in the @gen modeling language.

wsmoses commented 5 years ago

Thanks!

Unfortunately, using (grad) doesn't completely work for some use cases, since the function's return value may be some sort of object that ReverseDiff can't handle (it only accepts the following):

Closest candidates are:
  track(!Matched::Real, ::Array{ReverseDiff.AbstractInstruction,1}) at /Users/wmoses/.julia/packages/ReverseDiff/qmgw8/src/tracked.jl:381
  track(!Matched::AbstractArray, ::Array{ReverseDiff.AbstractInstruction,1}) at /Users/wmoses/.julia/packages/ReverseDiff/qmgw8/src/tracked.jl:383

Stacktrace:
 [1] track(::ChainBlock{1}, ::Array{ReverseDiff.AbstractInstruction,1}) at /Users/wmoses/.julia/packages/Gen/52u9K/src/backprop.jl:9
 [2] traceat(::Gen.GFBackpropTraceState, ::DynamicDSLFunction{Any}, ::Tuple{}, ::Tuple{Symbol,Int64}) at /Users/wmoses/.julia/packages/Gen/52u9K/src/dynamic/backprop.jl:311

For my particular use case I guess I could just trace without addresses (and rely on this fluke), but I figure there could be a way of engineering this that avoids the autodiff restriction (and perhaps still allows the (grad) annotation?).

marcoct commented 5 years ago

@wsmoses Thanks for your questions!

Can you tell us more about your use case? Do you need to differentiate through your own custom data type? If so, it won't work correctly out of the box, whether or not you use an address in the @trace expression. Does your custom data type have a constructor that accepts TrackedReal or TrackedArray values? If not, then the program should crash if there is indeed a need to backpropagate through your custom data type. If the program is not crashing, then perhaps differentiation through the custom data type (and the (grad) annotation) are not needed?
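For illustration, a minimal sketch of the kind of type marcoct is describing (BoxedValue and combine are hypothetical names, not part of Gen): a custom type whose constructor is parameterized over the element type would also accept ReverseDiff's TrackedReal values, so generic arithmetic on its fields could be tracked.

```julia
# Hypothetical custom return type, parameterized over the element type so a
# constructor call with tracked values (e.g. ReverseDiff.TrackedReal) would
# also succeed, not just plain Float64.
struct BoxedValue{T<:Real}
    x::T
    y::T
end

# Arithmetic written generically over Real works for both plain and tracked
# element types.
combine(b::BoxedValue) = b.x + 2 * b.y

combine(BoxedValue(1.0, 2.0))  # 5.0
```

A type with concretely typed fields (e.g. `x::Float64`) would instead error when handed a tracked value, which is the "program should crash" case described above.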

If you annotate a gen function with (grad) then it is indeed assumed that ReverseDiff can track its return value. If you don't annotate the function with (grad) then backpropagation will not occur through the function.

If you do need to differentiate through your own data type then options include (1) marshaling it into an array so that ReverseDiff can deal with it, or (2) extending ReverseDiff to handle this data type. There is ongoing work on improving automatic differentiation of general Julia code that works with custom struct data types (https://github.com/FluxML/Zygote.jl), but we have not yet integrated that into Gen.
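A rough sketch of option (1), using made-up names (SimState, marshal, unmarshal are not Gen APIs): flatten the custom type into a single vector before the differentiable region, so ReverseDiff only ever sees an AbstractArray, and rebuild the type afterwards.

```julia
# Hypothetical stand-in for a simulator's structured return type. The field
# types are parametric so a tracked vector can be stored when unmarshaling
# inside a differentiable region.
struct SimState{V<:AbstractVector}
    positions::V
    velocities::V
end

# Flatten to a single vector that ReverseDiff can track.
marshal(s::SimState) = vcat(s.positions, s.velocities)

# Rebuild the structured type from a flat vector of length 2n.
unmarshal(v::AbstractVector, n::Int) = SimState(v[1:n], v[n+1:2n])

s = SimState([1.0, 2.0], [3.0, 4.0])
v = marshal(s)              # [1.0, 2.0, 3.0, 4.0]
unmarshal(v, 2).velocities  # [3.0, 4.0]
```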

wsmoses commented 5 years ago

Hey @marcoct

I'm trying to integrate Gen with a black-box simulator to see if we can quickly beat some state-of-the-art results. With a bit of engineering effort (both inspecting and modifying the simulator, as well as combing through the Gen source code), I was able to get it to work with the unnamed trace (seemingly because, when the scope is the same, it effectively inlines the call from the AD standpoint? <- a guess, I haven't looked too closely).

The datatype returned is effectively a special boxed array (or rather a sequence of arrays), which is why it doesn't match a track call. However, the AD is able to go through it fine (since it contains tracked arrays/reals), so without the scope I seem to get the correct derivatives (and can successfully run the sort of test above, just slightly more complex). Thus, without (grad) and without a scope, it appears to work fine.

I'd be happy to go more in depth offline but it's research code for a project that isn't quite at a stage I'd want to make it public yet. My email is wmoses@mit.edu

marcoct commented 5 years ago

> I was able to get it to work with the unnamed trace (seemingly because, when the scope is the same, it effectively inlines the call from the AD standpoint? <- a guess, I haven't looked too closely).

Yes, it effectively inlines.

> The datatype returned is effectively a special boxed array (or rather a sequence of arrays), which is why it doesn't match a track call. However, the AD is able to go through it fine (since it contains tracked arrays/reals), so without the scope I seem to get the correct derivatives (and can successfully run the sort of test above, just slightly more complex). Thus, without (grad) and without a scope, it appears to work fine.

Ah, that makes sense.

FWIW, my guess is that if Gen continues to use tape-based reverse-mode AD, we would want to handle this by making it very easy to implement track for custom data types, so that the behavior is predictable between the named and unnamed @trace calls, and so that the semantics of (grad) are consistently enforced. But that might be obviated by switching to source-to-source AD (like Zygote.jl) instead; I'll need to think about it more.

We should leave this issue open, and I'll give it an automatic differentiation label.