Closed phipsgabler closed 5 years ago
Cassette.jl looks more promising to me. Unlike Flux.Tracker, it doesn't rely on method dispatch, and gradient collection isn't baked in as deeply. There used to be a plan to reimplement ReverseDiff on top of it in Capstan.jl. As far as I can see that has been abandoned, but the features are still there and very suitable. In particular, there's a tagging mechanism, although it isn't documented well.
I can do the following with relatively short code:
julia> fib(n) = n ≤ 1 ? 1 : fib(n-1) + fib(n-2)
fib (generic function with 1 method)
julia> DynamicComputationGraphs.track(fib, 3)
(3, (getfield getfield(DynamicComputationGraphs, Symbol("##2#3")){typeof(fib),Tuple{Int64}}(fib, (3,)) f)
(getfield getfield(DynamicComputationGraphs, Symbol("##2#3")){typeof(fib),Tuple{Int64}}(fib, (3,)) args)
(fib 3)
(<= 3 1)
(sle_int 3 1)
(- 3 1)
(sub_int 3 1)
(fib 2)
(<= 2 1)
(sle_int 2 1)
(- 2 1)
(sub_int 2 1)
(fib 1)
(<= 1 1)
(sle_int 1 1)
(- 2 2)
(sub_int 2 2)
(fib 0)
(<= 0 1)
(sle_int 0 1)
(+ 1 1)
(add_int 1 1)
(- 3 2)
(sub_int 3 2)
(fib 1)
(<= 1 1)
(sle_int 1 1)
(+ 2 1)
(add_int 2 1)
)
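For readers who want to reproduce something similar, here is a minimal sketch of recording a flat call trace with Cassette's prehook mechanism. It is not the DynamicComputationGraphs implementation: the context name TraceCtx and the calls vector are made up for illustration, and the sketch assumes a Julia version on which Cassette still works. It only collects (function, arguments) pairs in call order, without the nesting shown above.

```julia
using Cassette

# Hypothetical context type; the name is an assumption for this sketch.
Cassette.@context TraceCtx

# Before every call made while overdubbing, store the function and its
# arguments in the vector passed as context metadata.
function Cassette.prehook(ctx::TraceCtx, f, args...)
    push!(ctx.metadata, (f, args))
    return nothing
end

fib(n) = n ≤ 1 ? 1 : fib(n - 1) + fib(n - 2)

calls = Any[]
result = Cassette.overdub(TraceCtx(metadata = calls), fib, 3)
```

Afterwards, calls holds one entry per intercepted call, including the low-level intrinsics, and result is the ordinary return value of fib(3).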
I don't have a strong opinion on whether to use Flux.Tracker. I remember that Cassette.jl has been around for quite some time but never reached a "usable" state. If you feel that it is stable enough, then we can use Cassette.jl instead of Tracker.
We (Turing folks) just discussed the future of AD in Julia. It seems there is going to be a unifying interface defined by https://github.com/JuliaDiff/ChainRules.jl in the next month or so (which depends on Cassette). I think we should keep this in mind and maybe make DynamicComputationGraphs.jl use ChainRules.jl for AD-related functionality in the future.
But let's first get a prototype implementation of the discussed functionality.
Ok, I'll have a look at that. Currently I am using DiffRules (which is nice, because it returns expressions that I can directly splice into @eval), but I'm happy to switch.
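To illustrate what splicing derivative expressions into @eval can look like, here is a small forward-mode sketch. The RULES table is a hand-written, hypothetical stand-in for the expressions a rule package such as DiffRules would supply, and the Dual type is likewise an assumption for this example, not code from the repository.

```julia
# Hypothetical rule table: function name => derivative expression in `x`.
const RULES = Dict(
    :sin => :(cos(x)),
    :exp => :(exp(x)),
    :log => :(1 / x),
)

# A dual number carries a value and the derivative propagated so far.
struct Dual
    value::Float64
    deriv::Float64
end

# Generate one method per rule. The derivative expression is spliced into
# the method body, where `x` is bound to the primal value.
for (f, dfdx) in RULES
    @eval function Base.$f(d::Dual)
        x = d.value
        return Dual($f(x), $dfdx * d.deriv)
    end
end

d = sin(Dual(0.0, 1.0))   # seed the input derivative with 1
```

Because the rules are plain expressions, no closure or lookup happens at run time: each generated method contains the derivative formula directly.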
I think I now have a working version of both forward and backward mode for scalar functions using Cassette (a very inefficient one, though, I guess) in the cassette.jl file. Next I'm going to look at two things: how to use tapes and multiple variables, and how to factor out the common code of both modes into some kind of general "accumulating interpretation interface".
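As a rough picture of what "tapes and multiple variables" could mean for backward mode, here is a self-contained sketch. Note that it uses operator overloading rather than Cassette, so it is closer in spirit to Tracker or ReverseDiff than to the cassette.jl file. All names (Entry, Tape, Tracked, record!, variable, gradient) are invented for this example.

```julia
# One tape entry per intermediate result: the tape indices of its parents
# and the local partial derivatives with respect to each parent.
struct Entry
    parents::Vector{Int}
    partials::Vector{Float64}
end

struct Tape
    entries::Vector{Entry}
end
Tape() = Tape(Entry[])

# A value together with its position on a tape.
struct Tracked
    value::Float64
    index::Int
    tape::Tape
end

function record!(tape::Tape, value, parents, partials)
    push!(tape.entries, Entry(parents, partials))
    return Tracked(value, length(tape.entries), tape)
end

# Input variables have no parents.
variable(tape::Tape, x) = record!(tape, x, Int[], Float64[])

Base.:+(a::Tracked, b::Tracked) =
    record!(a.tape, a.value + b.value, [a.index, b.index], [1.0, 1.0])
Base.:*(a::Tracked, b::Tracked) =
    record!(a.tape, a.value * b.value, [a.index, b.index], [b.value, a.value])
Base.sin(a::Tracked) =
    record!(a.tape, sin(a.value), [a.index], [cos(a.value)])

# Walk the tape backwards, accumulating adjoints via the chain rule.
function gradient(output::Tracked)
    adjoints = zeros(length(output.tape.entries))
    adjoints[output.index] = 1.0
    for i in output.index:-1:1
        entry = output.tape.entries[i]
        for (p, d) in zip(entry.parents, entry.partials)
            adjoints[p] += adjoints[i] * d
        end
    end
    return adjoints
end

tape = Tape()
x = variable(tape, 2.0)
y = variable(tape, 3.0)
z = x * y + sin(x)
g = gradient(z)   # g[x.index] and g[y.index] hold the partial derivatives
```

Since every variable lives at its own tape index, a single backward sweep yields the derivatives with respect to all inputs at once.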
This stuff is now described better in the readme.
The graph building process should be able to:
Expr or something equivalent, allowing to convert back)
The backward information in the graph should be mutable, so that one can update subgraphs without full re-evaluation when changing parts of a model.
I think that essentially, this amounts to a pimped version of what Tracker.Call does: