dfdx / XGrad.jl

eXpression gradients in Julia

Resolving MethodErrors #12

Open cscherrer opened 6 years ago

cscherrer commented 6 years ago

I'm trying to use XGrad for Soss, and running into some problems.

Starting with something like this:

f = quote 
    ℓ = 0.0
    μ = θ[1]
    ℓ += logpdf(Normal(0, 5), μ)
    σ = softplus(θ[2])
    ℓ += abs(σ - θ[2])
    ℓ += logpdf(Truncated(Cauchy(0, 3), 0, Inf), σ)
    for x = DATA
        ℓ += logpdf(Normal(μ, σ), x)
    end
    ℓ
end

I get

Main> xdiff(f, θ=θ, DATA=DATA)
ERROR: MethodError: no method matching parse!(::Espresso.ExGraph, ::Espresso.ExH{:+=})
Closest candidates are:
  parse!(::Espresso.ExGraph, ::Espresso.ExH{:tuple}) at /home/chad/.julia/v0.6/Espresso/src/exgraph.jl:275
  parse!(::Espresso.ExGraph, ::Espresso.ExH{:'}) at /home/chad/.julia/v0.6/Espresso/src/exgraph.jl:261
  parse!(::Espresso.ExGraph, ::Espresso.ExH{:.}) at /home/chad/.julia/v0.6/Espresso/src/exgraph.jl:244
  ...
Stacktrace:
 [1] parse!(::Espresso.ExGraph, ::Expr) at /home/chad/.julia/v0.6/Espresso/src/exgraph.jl:176
 [2] collect_to!(::Array{Symbol,1}, ::Base.Generator{Array{Any,1},Espresso.##95#96{Espresso.ExGraph}}, ::Int64, ::Int64) at ./array.jl:508
 [3] collect(::Base.Generator{Array{Any,1},Espresso.##95#96{Espresso.ExGraph}}) at ./array.jl:476
 [4] parse!(::Espresso.ExGraph, ::Espresso.ExH{:block}) at /home/chad/.julia/v0.6/Espresso/src/exgraph.jl:269
 [5] #ExGraph#85(::Bool, ::Dict{Any,Any}, ::Array{Any,1}, ::Type{T} where T, ::Expr) at /home/chad/.julia/v0.6/Espresso/src/exgraph.jl:26
 [6] (::Core.#kw#Type)(::Array{Any,1}, ::Type{Espresso.ExGraph}, ::Expr) at ./<missing>:0
 [7] #xdiff#29(::Dict{Any,Any}, ::Array{Any,1}, ::Function, ::Expr) at /home/chad/.julia/v0.6/XGrad/src/xdiff.jl:230
 [8] (::XGrad.#kw##xdiff)(::Array{Any,1}, ::XGrad.#xdiff, ::Expr) at ./<missing>:0
 [9] eval(::Module, ::Any) at ./boot.jl:235

I'm not sure how to teach it about +=, so I change it to this:

f = quote 
    ℓ = 0.0
    μ = θ[1]
    ℓ = ℓ + logpdf(Normal(0, 5), μ)
    σ = softplus(θ[2])
    ℓ = ℓ + abs(σ - θ[2])
    ℓ = ℓ + logpdf(Truncated(Cauchy(0, 3), 0, Inf), σ)
    for x = DATA
        ℓ = ℓ + logpdf(Normal(μ, σ), x)
    end
    ℓ
end

That helps a little:

Main> xdiff(f, θ=θ, DATA=DATA)
ERROR: MethodError: no method matching function_name(::Type{Distributions.Normal})
Closest candidates are:
  function_name(::Function) at reflection.jl:861
Stacktrace:
 [1] canonical(::Module, ::Symbol) at /home/chad/.julia/v0.6/Espresso/src/utils.jl:169
 [2] parse!(::Espresso.ExGraph, ::Espresso.ExH{:call}) at /home/chad/.julia/v0.6/Espresso/src/exgraph.jl:229
 [3] parse!(::Espresso.ExGraph, ::Expr) at /home/chad/.julia/v0.6/Espresso/src/exgraph.jl:176
 [4] collect(::Base.Generator{Array{Any,1},Espresso.##91#92{Espresso.ExGraph}}) at ./array.jl:475
 [5] parse!(::Espresso.ExGraph, ::Espresso.ExH{:call}) at /home/chad/.julia/v0.6/Espresso/src/exgraph.jl:232
 [6] parse!(::Espresso.ExGraph, ::Expr) at /home/chad/.julia/v0.6/Espresso/src/exgraph.jl:176
 [7] collect_to!(::Array{Symbol,1}, ::Base.Generator{Array{Any,1},Espresso.##91#92{Espresso.ExGraph}}, ::Int64, ::Int64) at ./array.jl:508
 [8] collect(::Base.Generator{Array{Any,1},Espresso.##91#92{Espresso.ExGraph}}) at ./array.jl:476
 [9] parse!(::Espresso.ExGraph, ::Espresso.ExH{:call}) at /home/chad/.julia/v0.6/Espresso/src/exgraph.jl:232
 [10] parse!(::Espresso.ExGraph, ::Expr) at /home/chad/.julia/v0.6/Espresso/src/exgraph.jl:176
 [11] parse!(::Espresso.ExGraph, ::Espresso.ExH{:(=)}) at /home/chad/.julia/v0.6/Espresso/src/exgraph.jl:202
 [12] parse!(::Espresso.ExGraph, ::Expr) at /home/chad/.julia/v0.6/Espresso/src/exgraph.jl:176
 [13] collect_to!(::Array{Symbol,1}, ::Base.Generator{Array{Any,1},Espresso.##95#96{Espresso.ExGraph}}, ::Int64, ::Int64) at ./array.jl:508
 [14] collect(::Base.Generator{Array{Any,1},Espresso.##95#96{Espresso.ExGraph}}) at ./array.jl:476
 [15] parse!(::Espresso.ExGraph, ::Espresso.ExH{:block}) at /home/chad/.julia/v0.6/Espresso/src/exgraph.jl:269
 [16] #ExGraph#85(::Bool, ::Dict{Any,Any}, ::Array{Any,1}, ::Type{T} where T, ::Expr) at /home/chad/.julia/v0.6/Espresso/src/exgraph.jl:26
 [17] (::Core.#kw#Type)(::Array{Any,1}, ::Type{Espresso.ExGraph}, ::Expr) at ./<missing>:0
 [18] #xdiff#29(::Dict{Any,Any}, ::Array{Any,1}, ::Function, ::Expr) at /home/chad/.julia/v0.6/XGrad/src/xdiff.jl:230
 [19] (::XGrad.#kw##xdiff)(::Array{Any,1}, ::XGrad.#xdiff, ::Expr) at ./<missing>:0
 [20] eval(::Module, ::Any) at ./boot.jl:235

I thought maybe I could use normlogpdf instead, but that's not any better:

Main> xdiff(:(normlogpdf(μ,σ,x)),μ=0.0,σ=1.0, x=0.1)
ERROR: LHS of normlogpdf(μ::Real, σ::Real, x::Number) = normlogpdf(zval(μ, σ, x)) - log(σ) is neither variable, nor tuple
Stacktrace:
 [1] parse!(::Espresso.ExGraph, ::Espresso.ExH{:(=)}) at /home/chad/.julia/v0.6/Espresso/src/exgraph.jl:213
 [2] #ExGraph#85(::Bool, ::Dict{Any,Any}, ::Array{Any,1}, ::Type{T} where T, ::Expr) at /home/chad/.julia/v0.6/Espresso/src/exgraph.jl:26
 [3] (::Core.#kw#Type)(::Array{Any,1}, ::Type{Espresso.ExGraph}, ::Expr) at ./<missing>:0
 [4] make_subgraph(::Espresso.ExGraph, ::Espresso.ExNode{:call}) at /home/chad/.julia/v0.6/Espresso/src/graph_utils.jl:117
 [5] #150 at ./<missing>:0 [inlined]
 [6] next at ./generator.jl:45 [inlined]
 [7] all(::Base.##179#181, ::Base.Generator{Set{Symbol},Espresso.##150#151{Espresso.ExGraph}}) at ./reduce.jl:598
 [8] Dict(::Base.Generator{Set{Symbol},Espresso.##150#151{Espresso.ExGraph}}) at ./dict.jl:147
 [9] inline_nodes at /home/chad/.julia/v0.6/Espresso/src/graph_utils.jl:144 [inlined]
 [10] forward_pass!(::Espresso.ExGraph) at /home/chad/.julia/v0.6/XGrad/src/xdiff.jl:49
 [11] _xdiff(::Espresso.ExGraph) at /home/chad/.julia/v0.6/XGrad/src/xdiff.jl:205
 [12] #xdiff#29(::Dict{Any,Any}, ::Array{Any,1}, ::Function, ::Expr) at /home/chad/.julia/v0.6/XGrad/src/xdiff.jl:231
 [13] (::XGrad.#kw##xdiff)(::Array{Any,1}, ::XGrad.#xdiff, ::Expr) at ./<missing>:0
 [14] eval(::Module, ::Any) at ./boot.jl:235
dfdx commented 6 years ago

Let me explain these issues and how to fix them.

ERROR: MethodError: no method matching parse!(::Espresso.ExGraph, ::Espresso.ExH{:+=})

As you correctly assumed, XGrad (or, more specifically, Espresso) can't parse +=. But there's a good reason for this: Espresso assumes expressions are mathematically valid, and in math there is no mutation. So rewriting it as:

ℓ = 0.0
ℓ = ℓ + logpdf(Normal(0, 5), μ)

doesn't help, since you still modify ℓ.

The easiest way to fix it is to use new names for variables, e.g. something like:

μ = θ[1]
ℓ = logpdf(Normal(0, 5), μ)
σ = softplus(θ[2])
ℓ2 = ℓ + abs(σ - θ[2])
ℓ3 = ℓ2 + logpdf(Truncated(Cauchy(0, 3), 0, Inf), σ)
...

I thought about automatic renaming so that you don't have to do it manually, but the implications are unclear at the moment.

For the same reason there's still no first-class support for loops in Espresso: without mutating variables, loops are nearly useless. (I think we will add support for them anyway, but the design depends on use cases, and I haven't found a good one yet.) Your example may be rewritten (automatically, so that the user doesn't have to worry about it) into something like:

d = ...
ℓ_full = sum(logpdf.(d, x))
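
Spelled out for your example, that could look like this (a sketch; in Distributions.jl, logpdf broadcasts over a vector of observations):

d = Normal(μ, σ)
ℓ_full = sum(logpdf.(d, DATA))   # replaces the for x = DATA loop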

I'm also worried about distribution types like Normal or Cauchy: although it's convenient to use a custom type to represent a distribution in Julia, it's not something automatic differentiation can work with. For example:

N = Normal(0, 5)
ℓ = logpdf(N, μ)

If μ is a number, dℓ/dμ is also a number. But what is dℓ/dN? In some cases XGrad can dive into the structure of the Normal type and correctly derive derivatives w.r.t. its parameters, but we might need to use a more functional approach (e.g. normlogpdf). Anyway, let's see how it works.


I propose to go in 2 steps:

  1. Create an expression that is actually differentiable, i.e. can be parsed and includes only defined functions with known derivatives.
  2. Write a transformation that takes user input in a convenient form (e.g. with loops, variable mutation, or whatever) and produces an expression suitable for differentiation.

Does it sound reasonable to you?

cscherrer commented 6 years ago

Thanks, yes, this sounds manageable. One of the big benefits of source transformation is the flexibility in what I can pass to the next stage.

I know Julia does some kind of SSA transform -- I think it might be required by LLVM -- but maybe this isn't available to the user. But it should be easy enough to implement.

And I can rewrite to normlogpdf etc, but that by itself doesn't work either. Did you see the last code above?

dfdx commented 6 years ago

I know Julia does some kind of SSA transform -- I think it might be required by LLVM -- but maybe this isn't available to the user. But it should be easy enough to implement.

Yes and yes. I'll take a look at automatic renaming over the weekend.

And I can rewrite to normlogpdf etc, but that by itself doesn't work either. Did you see the last code above?

Ah, I didn't realize it was a self-contained example. Let me first explain why this happens, because you may encounter other errors like this.

As I've mentioned somewhere else, xdiff works by rewriting expressions according to known derivative rules. If there's no rule for a function (which is the case with normlogpdf), xdiff tries to find and analyze its code. Unfortunately, code extraction is non-trivial in Julia and currently has some limitations. normlogpdf is defined in StatsFuns.jl and hits one of these limitations.
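
For intuition, a derivative rule just tells xdiff how to propagate the incoming adjoint ds through a single call. Illustrative sketches in the @diffrule form used below (simplified, not copied from XGrad's actual rule table):

@diffrule sin(x)  x  cos(x) * ds
@diffrule log(x)  x  ds / x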

It's easy to fix by defining a differentiation rule for normlogpdf. Thanks to the code transformation approach, you don't have to derive it yourself; you can use XGrad itself. Here's the derivative (see below for how I made it):

function ∇normlogpdf(ds, μ, σ, x)
    dtmp687!dtmp687 = ds
    normlogpdf_tmp702_708 = 2
    dtmp687!dnormlogpdf_tmp690_694 = -dtmp687!dtmp687
    dtmp687!dσ__1 = dtmp687!dnormlogpdf_tmp690_694 / σ
    tmp723 = σ * σ
    tmp714 = 2.0
    normlogpdf_tmp690_694 = log(σ)
    normlogpdf_tmp699_705 = 1.8378770664093456
    zval_tmp695_697 = x - μ
    tmp721 = -zval_tmp695_697
    normlogpdf_tmp688_692 = zval_tmp695_697 / σ
    tmp717 = sign(normlogpdf_tmp688_692)
    tmp715 = abs(normlogpdf_tmp688_692)
    tmp716 = tmp714 .* tmp715
    tmp718 = tmp716 .* tmp717
    dtmp687!dnormlogpdf_tmp688_692 = tmp718 .* normlogpdf_tmp688_692
    dtmp687!dx = dtmp687!dnormlogpdf_tmp688_692 / σ
    dtmp687!dzval_tmp695_697 = dtmp687!dnormlogpdf_tmp688_692 / σ
    dtmp687!dμ = -dtmp687!dzval_tmp695_697
    tmp722 = tmp721 * dtmp687!dnormlogpdf_tmp688_692
    dtmp687!dσ__2 = tmp722 / tmp723
    dtmp687!dσ = dtmp687!dσ__1 .+ dtmp687!dσ__2
    normlogpdf_tmp698_704 = abs2(normlogpdf_tmp688_692)
    normlogpdf_tmp700_706 = normlogpdf_tmp698_704 + normlogpdf_tmp699_705
    normlogpdf_tmp701_707 = -normlogpdf_tmp700_706
    normlogpdf_tmp689_693 = normlogpdf_tmp701_707 / normlogpdf_tmp702_708
    tmp687 = normlogpdf_tmp689_693 - normlogpdf_tmp690_694
    tmp727 = (tmp687, dtmp687!dμ, dtmp687!dσ, dtmp687!dx)
end

where ds stands for the derivative of the final variable w.r.t. the output of this function. Then we can define rules for all 3 arguments:

@diffrule normlogpdf(μ, σ, x) μ ∇normlogpdf(ds, μ, σ, x)[2]
@diffrule normlogpdf(μ, σ, x) σ ∇normlogpdf(ds, μ, σ, x)[3]
@diffrule normlogpdf(μ, σ, x) x ∇normlogpdf(ds, μ, σ, x)[4]

and voilà!

julia> xdiff(:(normlogpdf(μ,σ,x)),μ=0.0,σ=1.0, x=0.1)
quote
    dtmp661!dtmp661 = 1.0
    tmp667 = ∇normlogpdf(dtmp661!dtmp661, μ, σ, x)
    dtmp661!dx = tmp667[4]
    tmp665 = ∇normlogpdf(dtmp661!dtmp661, μ, σ, x)
    tmp661 = normlogpdf(μ, σ, x)
    dtmp661!dσ = tmp665[3]
    tmp663 = ∇normlogpdf(dtmp661!dtmp661, μ, σ, x)
    dtmp661!dμ = tmp663[2]
    tmp669 = (tmp661, dtmp661!dμ, dtmp661!dσ, dtmp661!dx)
end

This may look like a set of hacks and monkey patching, but normally (e.g. in machine learning) the workflow is a bit different: you define a set of primitives like normlogpdf and rules for them, and then everybody else just uses them. Adding new primitives isn't always trivial, but it's much easier than in many other languages / platforms (e.g. see Adding a New Op in TensorFlow).


Here's how I derived the code for ∇normlogpdf (you may need the latest master of Espresso and XGrad to reproduce it). First, I copied the definitions of the dependent functions from StatsFuns, with empty lines between them so the underlying libraries can correctly locate the code:

zval(μ::Real, σ::Real, x::Number) = (x - μ) / σ

normlogpdf(z::Number) = -(abs2(z) + 1.8378770664093454836)/2

normlogpdf(μ::Real, σ::Real, x::Number) = normlogpdf(zval(μ, σ, x)) - log(σ)

Then I ran xdiff the usual way:

xdiff(:(normlogpdf(μ,σ,x)),μ=0.0,σ=1.0, x=0.1)

and got this expression:

    dtmp687!dtmp687 = 1.0
    normlogpdf_tmp702_708 = 2
    dtmp687!dnormlogpdf_tmp690_694 = -dtmp687!dtmp687
    dtmp687!dσ__1 = dtmp687!dnormlogpdf_tmp690_694 / σ
    tmp723 = σ * σ
    tmp714 = 2.0
    normlogpdf_tmp690_694 = log(σ)
    normlogpdf_tmp699_705 = 1.8378770664093456
    zval_tmp695_697 = x - μ
    tmp721 = -zval_tmp695_697
    normlogpdf_tmp688_692 = zval_tmp695_697 / σ
    tmp717 = sign(normlogpdf_tmp688_692)
    tmp715 = abs(normlogpdf_tmp688_692)
    tmp716 = tmp714 .* tmp715
    tmp718 = tmp716 .* tmp717
    dtmp687!dnormlogpdf_tmp688_692 = tmp718 .* normlogpdf_tmp688_692
    dtmp687!dx = dtmp687!dnormlogpdf_tmp688_692 / σ
    dtmp687!dzval_tmp695_697 = dtmp687!dnormlogpdf_tmp688_692 / σ
    dtmp687!dμ = -dtmp687!dzval_tmp695_697
    tmp722 = tmp721 * dtmp687!dnormlogpdf_tmp688_692
    dtmp687!dσ__2 = tmp722 / tmp723
    dtmp687!dσ = dtmp687!dσ__1 .+ dtmp687!dσ__2
    normlogpdf_tmp698_704 = abs2(normlogpdf_tmp688_692)
    normlogpdf_tmp700_706 = normlogpdf_tmp698_704 + normlogpdf_tmp699_705
    normlogpdf_tmp701_707 = -normlogpdf_tmp700_706
    normlogpdf_tmp689_693 = normlogpdf_tmp701_707 / normlogpdf_tmp702_708
    tmp687 = normlogpdf_tmp689_693 - normlogpdf_tmp690_694
    tmp727 = (tmp687, dtmp687!dμ, dtmp687!dσ, dtmp687!dx)

Finally, we replace the first term (dtmp687!dtmp687) - the starting point of the derivation - with the special symbol ds and wrap everything in a function:

function ∇normlogpdf(ds, μ, σ, x)
    dtmp687!dtmp687 = ds
    ...
end

If this still looks like black magic to you but you need more rules, don't hesitate to simply ping me :)

cscherrer commented 6 years ago

I know Julia does some kind of SSA transform -- I think it might be required by LLVM -- but maybe this isn't available to the user. But it should be easy enough to implement.

Yes and yes. I'll take a look at automatic renaming over the weekend.

Oh, I meant for me to rewrite it that way. But it would be better if this can go into XGrad directly; that should make it more broadly useful anyway. Thanks!

A couple of other questions. You write

normlogpdf(z::Number) = -(abs2(z) + 1.8378770664093454836)/2

instead of using log2π directly. Is this a general limitation - only literals are known to be constant?

Also, for this bit:

@diffrule normlogpdf(μ, σ, x) μ ∇normlogpdf(ds, μ, σ, x)[2]
@diffrule normlogpdf(μ, σ, x) σ ∇normlogpdf(ds, μ, σ, x)[3]
@diffrule normlogpdf(μ, σ, x) x ∇normlogpdf(ds, μ, σ, x)[4]

Does this mean the code for ∇normlogpdf will be executed three times? Oh, and is there a way to generate code that will efficiently compute both the value and the gradient?

Thanks for all your help with this :)

dfdx commented 6 years ago

Is this a general limitation - only literals are known to be constant?

I hoped to fix it before you noticed :) It's a bug; I fixed it in the past, but it snuck back. I added an issue for it.

Does this mean the code for ∇normlogpdf will be executed three times?

In this specific case, yes, but you shouldn't worry about it right now. I'm simplifying things a bit to keep the cognitive load low and the snippets understandable, but in general it's easy to make XGrad optimize your code (at least to eliminate common subexpressions).
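
For instance, hand-optimizing the generated code above to reuse a single ∇normlogpdf call would give something like this (written by hand here, not actual XGrad output):

dtmp661!dtmp661 = 1.0
tmp663 = ∇normlogpdf(dtmp661!dtmp661, μ, σ, x)   # computed once instead of three times
tmp661 = normlogpdf(μ, σ, x)
tmp669 = (tmp661, tmp663[2], tmp663[3], tmp663[4])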

Oh, and is there a way to generate code that will efficiently compute both the value and the gradient?

It's the default behavior, e.g. in the generated code above you can see:

tmp727 = (tmp687, dtmp687!dμ, dtmp687!dσ, dtmp687!dx)

tmp687 is the (generated) name of the resulting variable, and the 3 other elements are the derivatives of that result w.r.t. each of the input parameters.
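
For completeness, a hypothetical usage sketch (assuming the @diffrule definitions and ∇normlogpdf above are loaded, and μ, σ and x are bound in the calling scope):

dex = xdiff(:(normlogpdf(μ,σ,x)), μ=0.0, σ=1.0, x=0.1)
val, dμ, dσ, dx = eval(dex)   # one evaluation yields the value and all three derivatives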

dfdx commented 6 years ago

I added an issue for it.

Fixed it and tested it with a constant:

const log2π = 1.8378770664093454836

Check out master of Espresso and XGrad to try it out.
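
With the fix, the StatsFuns-style definition should now parse with the named constant in place of the literal, e.g.:

const log2π = 1.8378770664093454836
normlogpdf(z::Number) = -(abs2(z) + log2π) / 2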

cscherrer commented 6 years ago

Great, thank you Andrei!

dfdx commented 6 years ago

Automatic variable renaming is in place. Example:

using Espresso

ex = quote
    x = 0
    x = x + a
    x = x * b
    y = x * x
end

to_expr(ExGraph(ex))

gives:

quote
    x = 0
    x2 = x + a
    x3 = x2 * b
    y = x3 * x3
end

Things like x += 1 aren't supported, though. Maybe I'll rewrite them to x = x + 1 before parsing, but it's hard to predict all possible use cases, so I have doubts. However, in your restricted use case you can easily rewrite them as:

using Espresso

rules = [:(_x += _y) => :(_x = _x + _y),
         :(_x *= _y) => :(_x = _x * _y)]
ex = quote
    a = 0
    a += 1
    a *= a
end
rewrite_all(ex, rules)

which gives:

quote
    a = 0
    a = a + 1
    a = a * a
end
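
Chaining the two transformations gives a fully mechanical preprocessing step (a sketch reusing rules and ex from this comment; loops would still need separate handling):

ex2 = rewrite_all(ex, rules)   # += and *= become plain assignments
to_expr(ExGraph(ex2))          # then reassigned variables get fresh names
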
cscherrer commented 6 years ago

Oh that's great! Easy for me to replace them all before calling it. Thanks!
