dfdx / XGrad.jl

eXpression gradients in Julia

Why we stopped developing this approach? #21

Closed Sixzero closed 3 years ago

Sixzero commented 3 years ago

Hey guys! I was thinking about an approach exactly like this for building an autodiff solution. I don't know why this approach was abandoned, and thought the easiest thing would be to just ask why it is no longer developed.

P.S.: As for me, I don't see any blocking problem with this approach, and it should provide the same or even better speed than competing solutions.

dfdx commented 3 years ago

Well, the answer actually depends on what you mean by "solution exactly like this".

Any reverse-mode AD method consists of two steps: a forward pass and a reverse pass. Sometimes they are combined, as in Zygote, but all the packages I have worked on keep them separate. The reverse pass is mostly the same everywhere: take the computational graph (e.g. a Tape) produced by the forward pass and add differentiation nodes to it. But the forward pass, i.e. building the computational graph, may be done quite differently.
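For concreteness, here is a minimal sketch of the two passes with a hand-built tape. This is purely illustrative: the `Op` struct and the primitive set are invented for this example and are not XGrad's or Yota's actual representation.

```julia
# Minimal sketch of a tape-based reverse pass (illustrative only).
# The forward pass records each primitive call on a tape; the reverse
# pass walks the tape backwards, accumulating adjoints.

struct Op
    f::Symbol           # primitive name, e.g. :sin or :mul
    out::Int            # index of the output variable
    args::Vector{Int}   # indices of the input variables
end

# Trace f(x) = sin(x) * x by hand: v1 = x; v2 = sin(v1); v3 = v2 * v1
vals = Float64[2.0, sin(2.0), sin(2.0) * 2.0]
tape = [Op(:sin, 2, [1]), Op(:mul, 3, [2, 1])]

# Reverse pass: seed the output adjoint with 1 and propagate backwards.
adj = zeros(length(vals)); adj[end] = 1.0
for op in reverse(tape)
    if op.f == :sin
        adj[op.args[1]] += adj[op.out] * cos(vals[op.args[1]])
    elseif op.f == :mul
        a, b = op.args
        adj[a] += adj[op.out] * vals[b]
        adj[b] += adj[op.out] * vals[a]
    end
end

# adj[1] now holds d(sin(x)*x)/dx = x*cos(x) + sin(x) evaluated at x = 2.0
```

The point of the separation: once you have the tape, the reverse pass is a mechanical walk like the loop above; all the design differences between packages sit in how the tape gets built.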

XGrad's approach is to build the computational graph from the source code. It's very similar to what Tangent does for Python. But there's a caveat: Julia code is much harder to analyze statically. For example, if you see something like this in Python:

from A import foo

foo(x, y)

You just need to go to module A and find the definition of foo() to infer its code. But in Julia, with its multiple dispatch, the actual implementation of foo() depends on the types of its arguments x and y. And to find out their types, the easiest approach is simply to execute all the code line by line. This is what XGrad does; note, however, that this is no longer source-to-source, but rather a kind of tracing.
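To illustrate why argument types matter, here is a small example using `which`, a real Base reflection function; the `foo` methods themselves are hypothetical:

```julia
# In Julia, the method behind a call depends on the argument types, so
# `foo(x, y)` cannot be resolved from the call site's source alone.
foo(x::Int, y)     = x + y   # one method
foo(x::Float64, y) = x * y   # a completely different one

# `which(f, types)` returns the Method that would run for the given
# argument types -- but you need concrete types to ask the question.
m1 = which(foo, Tuple{Int, Int})
m2 = which(foo, Tuple{Float64, Int})
println(m1.sig)    # the selected method's signature
println(m1 == m2)  # false: different types pick different code
```

Each `Method` object also carries `file` and `line` fields, which is what source-level tools have to rely on to locate the code.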

Tracing the source code directly is hard. Even if you manage to find the exact source code of the function in question (which often fails), there are dozens of ways to write the same thing. Consider the following equivalent notations, which a good AD implementation should take into account:

x = ...
for i = 1:100
    x += 1
end

x = ...
for i = 1:100
    x = x + 1
end

x = ...
i = 1
while i <= 100
    x += 1
    i += 1
end

x = ...
i = 1
while true
    x += 1
    i += 1
    if i > 100
        break
    end
end

x = ...
i = 1
while true
    x += 1
    i += 1
    i > 100 && break
end

All of these notations do the same thing but have completely different ASTs. In general, the Julia AST is quite rich; unless you write a macro that expects a fixed expression structure, it's no fun at all to deal with all its peculiarities.
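The AST divergence is easy to see with `Meta.parse` from the standard library:

```julia
# Two semantically equivalent loops parse to entirely different ASTs.
e1 = Meta.parse("for i = 1:100; x += 1; end")
e2 = Meta.parse("while i <= 100; x += 1; i += 1; end")
println(e1.head)  # :for
println(e2.head)  # :while

# Even `x += 1` vs `x = x + 1` differ at the AST level:
println(Meta.parse("x += 1").head)     # :+=
println(Meta.parse("x = x + 1").head)  # :(=)
```

An AST-level tool has to normalize all of these forms before it can even start adding derivative nodes.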

But at the same time, Julia has a much simpler and relatively stable intermediate representation (IR) for code. Using tools like Cassette and IRTools protects you from IR API changes and gives you a solid foundation for code tracing and transformation, keeping most of what you would get from the AST. That's why I migrated from the pseudo-source-to-source tracing in XGrad to true IR-level tracing in Yota. I made this transition a few years ago, and so far I'm 100% happy with the decision.
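You can get a feel for this uniformity with `code_lowered` from Base (a simpler view than what Cassette/IRTools operate on, but it makes the same point): after lowering, the surface-level loop syntax disappears.

```julia
# A `for` loop and an equivalent `while` loop lower to the same kind
# of goto-based IR, so an IR-level tool sees one uniform form.
function f_for(n)
    x = 0
    for i = 1:n
        x += 1
    end
    x
end

function f_while(n)
    x = 0; i = 1
    while i <= n
        x += 1; i += 1
    end
    x
end

ir_for   = string(code_lowered(f_for,   (Int,))[1])
ir_while = string(code_lowered(f_while, (Int,))[1])
# Both bodies contain plain gotos rather than :for / :while nodes.
println(occursin("goto", ir_for) && occursin("goto", ir_while))
```

Handling one goto-based form is far less work than handling every loop spelling shown above.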

Please let me know if this answers your question or if you have any follow-up questions.

Sixzero commented 3 years ago

I like how clearly you state things.

On the other hand, in the last month I tried to create an IR-based source-to-source AD tool, and I can say it is not that easy; the IR is less human-readable than Julia AST code.

I created an MVP AST-level AD tool, and things felt pretty straightforward to get working in a short amount of time; nothing seems too far-fetched.

Better indexing and memory usage come really easily when transpiling the AST into a reverse AST.

Yota is a great library.

Funny thing is that I would have chosen the xgrad name for my library too. I also use the letter "x" to make names a little different and distinguishable, since it is not widely used as a starting character and can be pressed easily on the keyboard. :) Also, I am curious who you are, since you created this tool over 2 years ago, so you are probably at least two years ahead of me in progression.

dfdx commented 3 years ago

It's great to hear that you are making progress with the AST-level AD tool! A couple more things that caused me a lot of trouble in the past:

Function code extraction. Example:

g(x) = x + 1

function f(x)
    g(x) * g(x)
end

If g() is not a primitive, you need to trace through it. To do so, you first need to determine the types of its arguments, and then find the correct method's code. Julia usually attaches source line information to methods, but it's not always accurate. If you want to get serious about it, I recommend checking that you can reliably extract code from:

Code evaluation scope

To find the correct method, you need to know the function's argument types. The easiest way to find them is to evaluate everything before this line. You must be very careful not to pollute the global state with the variables you create, and to evaluate expressions in the right context, especially the right module. E.g., here's a dumb mistake:

julia> x = 3
3

julia> Base.eval(:x)
ERROR: UndefVarError: x not defined
Stacktrace:
 [1] top-level scope
 [2] eval at ./boot.jl:331 [inlined]
 [3] eval(::Symbol) at ./Base.jl:38
 [4] top-level scope at REPL[6]:1

julia> Base.eval(Main, :x)
3

Non-function-call constructs

When you work with simple function calls, e.g. Expr(:call, f, args...), things are pretty easy. But real code also contains:
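For illustration, here are a few non-call constructs one meets in practice and the Expr heads Julia parses them to (these are standard Julia AST forms, not anything XGrad-specific):

```julia
# Non-:call expression forms a tracer also has to handle:
println(Meta.parse("a.b").head)        # :. (field access)
println(Meta.parse("a[1]").head)       # :ref (indexing)
println(Meta.parse("x = 1").head)      # :(=) (assignment)
println(Meta.parse("a ? b : c").head)  # :if (ternary)
println(Meta.parse("[1, 2, 3]").head)  # :vect (array literal)
println(Meta.parse("f.(x)").head)      # :. again, but here a broadcast call
```

Each of these needs its own handling (or its own lowering rule) before an AST-level AD tool can treat the program as a graph of primitive calls.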

Anyway, good luck with it and let me know if I can be helpful for you!


Regarding the name, "x" is a reference to the word "eXpression", i.e. "x grad == eXpression GRADients". Similarly, I have another funny-named library Espresso.jl for expression transformation that you may find useful (see also MacroTools.jl).