TuringLang / IRTracker.jl

Dynamically track IR as a graph, using source transformations

Planned stuff #1

Closed phipsgabler closed 5 years ago

phipsgabler commented 5 years ago

The graph building process should be able to:

The backward information in the graph should be mutable, so that one can update subgraphs without full re-evaluation when changing parts of a model.
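
As a toy sketch of what I mean by mutable backward information (all names hypothetical), consider nodes whose inputs can be swapped out after construction; a real version would additionally mark stale subtrees, so that only those get re-evaluated:

mutable struct Node
    op::Any                # called function, or the constant value at a leaf
    inputs::Vector{Node}   # mutable: subgraphs can be replaced later
end

evaluate(n::Node) = isempty(n.inputs) ? n.op : n.op(map(evaluate, n.inputs)...)

x = Node(2, Node[])
y = Node(3, Node[])
graph = Node(*, [Node(2, Node[]), Node(+, [x, y])])   # 2 * (x + y)
evaluate(graph)   # 10
x.op = 4          # mutate a leaf in place
evaluate(graph)   # 14, without rebuilding the graph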

I think that, essentially, this amounts to a beefed-up version of what Tracker.Call does:

> f(x, y) = x + y
> g(x, y) = 2f(x, y)
> y, graph = forward(g, 2, 3)
10, <graph stuff>

> backward(graph, 1)
(2.0, 2.0)

> convert(Expr, graph)
Expr
  head: Symbol call
  args: Array{Any}((3,))
    1: Symbol *
    2: Int64 2
    3: Expr
      head: Symbol call
      args: Array{Any}((3,))
        1: Symbol +
        2: Int64 2
        3: Int64 3

> graph  # unsure about this?
DynamicComputationGraph
  info: FunCall(Symbol("##g#359"))
  head: Symbol call
  args: Array{Any}((3,))
    1: Symbol *
    2: Int64 2
    3: DynamicComputationGraph
      info: FunCall(Symbol("##f#358"))
      head: Symbol call
      args: Array{Any}((3,))
        1: Symbol +
        2: Int64 2
        3: Int64 3
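
Spelled out, that last dump would correspond to a node type roughly like this (field names taken from the dump; not an actual implementation):

struct DynamicComputationGraph
    info::Any            # metadata about the traced call, e.g. FunCall
    head::Symbol         # expression head, here :call
    args::Vector{Any}    # constants, symbols, and nested subgraphs
end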
phipsgabler commented 5 years ago

Cassette.jl looks more promising to me. Unlike Flux.Tracker, it doesn't rely on method dispatch, and gradient collection isn't baked in as deeply. There used to be a plan to reimplement ReverseDiff with it in Capstan.jl -- although that has, as far as I can see, been abandoned, the features it relied on are still there and very suitable. In particular, there's a tagging mechanism, although it isn't well documented.
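
For concreteness, here is the kind of minimal Cassette-based tracing I have in mind (the context and hook definitions are illustrative, not the actual code):

using Cassette

Cassette.@context TraceCtx

# Record every intermediate call; ctx.metadata accumulates the flat trace.
Cassette.prehook(ctx::TraceCtx, f, args...) = push!(ctx.metadata, (f, args))

g(x, y) = 2 * (x + y)

trace = Any[]
Cassette.overdub(TraceCtx(metadata = trace), g, 2, 3)   # returns 10
# trace now contains (+, (2, 3)) and (*, (2, 5)), plus lowered intrinsics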

I can do the following with relatively short code:

julia> fib(n) = n ≤ 1 ? 1 : fib(n-1) + fib(n-2)
fib (generic function with 1 method)

julia> DynamicComputationGraphs.track(fib, 3)
(3, (getfield getfield(DynamicComputationGraphs, Symbol("##2#3")){typeof(fib),Tuple{Int64}}(fib, (3,)) f)
  (getfield getfield(DynamicComputationGraphs, Symbol("##2#3")){typeof(fib),Tuple{Int64}}(fib, (3,)) args)
  (fib 3)
    (<= 3 1)
      (sle_int 3 1)
    (- 3 1)
      (sub_int 3 1)
    (fib 2)
      (<= 2 1)
        (sle_int 2 1)
      (- 2 1)
        (sub_int 2 1)
      (fib 1)
        (<= 1 1)
          (sle_int 1 1)
      (- 2 2)
        (sub_int 2 2)
      (fib 0)
        (<= 0 1)
          (sle_int 0 1)
      (+ 1 1)
        (add_int 1 1)
    (- 3 2)
      (sub_int 3 2)
    (fib 1)
      (<= 1 1)
        (sle_int 1 1)
    (+ 2 1)
      (add_int 2 1)
)
trappmartin commented 5 years ago

I don't have a strong opinion on using Flux.Tracker or not. I remember Cassette.jl has been around for quite some time but never managed to reach a "usable" state. If you feel that it is stable enough now, then we can use Cassette.jl instead of Tracker.

trappmartin commented 5 years ago

We (Turing folks) just discussed the future of ADs in Julia. It seems there is going to be a unifying interface defined by https://github.com/JuliaDiff/ChainRules.jl in the next month or so (which depends on Cassette). I think we should keep this in mind and maybe make DynamicComputationGraphs.jl use ChainRules.jl for AD-related functionality in the future.
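
For illustration, this is roughly what a rule looks like under the current ChainRulesCore API (which postdates this discussion and may still change; mydouble is a made-up example function):

using ChainRulesCore

mydouble(x) = 2x   # hypothetical function, for illustration only

# An rrule returns the primal result plus a pullback that maps output
# cotangents to input cotangents (first slot: tangent of the function itself).
function ChainRulesCore.rrule(::typeof(mydouble), x)
    y = mydouble(x)
    mydouble_pullback(ȳ) = (NoTangent(), 2ȳ)
    return y, mydouble_pullback
end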

But first, let's have a prototype implementation of the discussed functionalities.

phipsgabler commented 5 years ago

Ok, I'll have a look at that; currently I am using DiffRules (which is nice, because it returns expressions that I can directly splice into @eval), but I'm happy to switch.
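
For example (this is the actual DiffRules API; the spliced-in dsin is just a made-up target):

using DiffRules

# diffrule returns the derivative as an Expr; splicing it into @eval
# turns it into an ordinary function definition.
dx = DiffRules.diffrule(:Base, :sin, :x)   # :(cos(x))
@eval dsin(x) = $dx
dsin(0.0)   # 1.0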

I think I now have a working version of both forward and backward mode for scalar functions using Cassette (a very inefficient one, though, I guess) in the cassette.jl file. Next I'm going to look at two things: how to use tapes and multiple variables, and how to factor the code of both out into some kind of general "accumulating interpretation interface".

phipsgabler commented 5 years ago

This stuff is now described better in the readme.