JuliaDiff / ChainRules.jl

forward and reverse mode automatic differentiation primitives for Julia Base + StdLibs

rrule for casting LinearAlgebra.QRCompactWYQ into a Matrix #516

Open rkube opened 2 years ago

rkube commented 2 years ago

I'm trying to work with the Q matrix from a QR factorization within Zygote. In an incomplete QR factorization for m>=n, the QRCompactWYQ matrix Q has size=(m,m) but only the first n columns are relevant: https://docs.julialang.org/en/v1/stdlib/LinearAlgebra/#LinearAlgebra.qr

Using this in Zygote behaves as follows:

using LinearAlgebra
using Zygote
using Random

Random.seed!(1234)

V = rand(Float32, (6,4))
Q, _ = qr(V)

Zygote.gradient(A -> sum(A), Matrix(Q))
Zygote.gradient(A -> sum(A), Q)
Zygote.gradient(A -> sum(Matrix(A)), Q)

ERROR: LoadError: DimensionMismatch("variable with size(x) == (6, 6) cannot have a gradient with size(dx) == (6, 4)")
Stacktrace:
 [1] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}})(dx::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/Voykb/src/projection.jl:209
 [2] Array_pullback
   @ ~/.julia/packages/ChainRules/qd40H/src/rulesets/Base/array.jl:9 [inlined]
 [3] ZBack
   @ ~/.julia/packages/Zygote/nsu1Y/src/compiler/chainrules.jl:140 [inlined]
 [4] Pullback
   @ ~/source/gpuplayground/src/qr_project_mwe.jl:16 [inlined]
 [5] (::Zygote.var"#50#51"{typeof(∂(#5))})(Δ::Float32)
   @ Zygote ~/.julia/packages/Zygote/nsu1Y/src/compiler/interface.jl:41
 [6] gradient(f::Function, args::LinearAlgebra.QRCompactWYQ{Float32, Matrix{Float32}})
   @ Zygote ~/.julia/packages/Zygote/nsu1Y/src/compiler/interface.jl:76
 [7] top-level scope

The first two calls to gradient work fine. The last call fails because the incoming gradient is size=(6,4) but the matrix is of size=(6,6).

I tried modifying the rrule from here, https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/array.jl line 7 like this:

function ChainRules.rrule(::typeof(Matrix), x::LinearAlgebra.QRCompactWYQ) 
  project_x = ProjectTo(x)
  Array_pullback(ȳ) = (NoTangent(), project_x(ȳ))
  return T(x), Array_pullback
end

but that didn't work. Is there a good way of making calls to Matrix(Q) in a backwards pass? @mcabbott

mcabbott commented 2 years ago

I think this rule for Array really doesn't expect that the sizes might not match. The two sizes are

size(Matrix(Q)) == size(Array{Float64}(Q)) == size(Q.factors) == size(V) == (6,4)

size(Q) == size(Q .+ 100) == size(hcat(Q)) == size(collect(Q)) == size(Q * collect(Q)) == (6,6)

i.e. every other operation I can think of obeys size(Q), it's just the Array conversion which is weird.

So I guess it needs a special rule, but what should this return? It could just be a matrix of the larger size, but perhaps it would more usefully be some Tangent{QRCompactWYQ}(factors = ..., T = ...)?

I think ::typeof(Matrix) won't work, BTW, it needs to be ::Type{...}. So the simplest idea is:

function ChainRules.rrule(::Type{T}, Q::LinearAlgebra.QRCompactWYQ) where {T<:Array} 
    T(Q), dy -> (NoTangent(), hcat(dy, falses(size(Q,1), size(Q,2)-size(Q.factors,2))))
end;
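
The padding step inside that pullback can be checked on its own, without any AD package. The following is only a sketch using LinearAlgebra from the standard library; the names thin, dy, ncols_missing and dQ are illustrative and not part of ChainRules, and dy is a stand-in for the cotangent that sum(Matrix(Q)) would send back.

```julia
using LinearAlgebra

V = rand(Float32, 6, 4)
Q, _ = qr(V)

# Matrix(Q) has the size of V, while size(Q) reports the square size.
thin = Matrix(Q)

# Stand-in cotangent with the size of Matrix(Q) (illustrative, not produced by AD here).
dy = ones(Float32, size(thin))

# Pad with zero columns until the result matches size(Q),
# mirroring the hcat(dy, falses(...)) call in the rule above.
ncols_missing = size(Q, 2) - size(Q.factors, 2)
dQ = hcat(dy, falses(size(Q, 1), ncols_missing))

@assert size(dQ) == size(Q)
```

Because falses gives a Bool array, hcat promotes it to the element type of dy, so the padded result stays Float32.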

rkube commented 2 years ago

Cool, that works. Here is a gist that incorporates code from https://github.com/JuliaDiff/ChainRules.jl/pull/469 into a minimal working example: https://gist.github.com/rkube/b965267944115af7d13b3f00e7533572 This code gives the same results as comparable code for PyTorch.

But the syntax to define the pullback for typecasting differs from what is explained in the ChainRulesCore.jl docs https://juliadiff.org/ChainRulesCore.jl/stable/

This is for a normal function

function foo(x) 
...
end

function rrule(::typeof(foo), args...; kwargs...)
    ...
    return y, pullback
end

And the syntax for typecasting is

function rrule(::Type{T}...) where {T}

Why is there a difference? It would be great to have this mentioned in the docs.

mcabbott commented 2 years ago

Might indeed be worth mentioning in the docs, if you can figure out where.

Something similar applies without AD:

julia> Matrix(s::String) = fill(s, 2, 2);

julia> Matrix("zz")
2×2 Matrix{String}:
 "zz"  "zz"
 "zz"  "zz"

julia> Matrix{String}("zz")
ERROR: MethodError: no method matching Matrix{String}(::String)

julia> (::Type{T})(s::String) where {T<:Matrix} = fill(s, 3, 3);

julia> Matrix{String}("zz")
3×3 Matrix{String}:
 "zz"  "zz"  "zz"
 "zz"  "zz"  "zz"
 "zz"  "zz"  "zz"