rkube opened 2 years ago
I think this rule for Array really doesn't expect that the sizes might not match. The two sizes are:
size(Matrix(Q)) == size(Array{Float64}(Q)) == size(Q.factors) == size(V) == (6,4)
size(Q) == size(Q .+ 100) == size(hcat(Q)) == size(collect(Q)) == size(Q * collect(Q)) == (6,6)
i.e. every other operation I can think of obeys size(Q); it's just the Array conversion which is weird.
So I guess it needs a special rule, but what should it return? It could just be a matrix of the larger size, but perhaps it would more usefully be some Tangent{QRCompactWYQ}(factors = ..., T = ...)?
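The mismatch can be reproduced with LinearAlgebra alone. This is a minimal sketch assuming a random 6×4 input (the sizes used above): it materializes both the full factor with `collect(Q)` and the thin factor with `Matrix(Q)`, and checks that the thin factor is just the leading columns of the full one.

```julia
using LinearAlgebra

A = randn(6, 4)     # assumed tall input: m = 6 rows, n = 4 columns
Q = qr(A).Q         # a LinearAlgebra.QRCompactWYQ

full_Q = collect(Q) # materializes the full m×m orthogonal factor
thin_Q = Matrix(Q)  # materializes only the thin m×n factor

@show size(Q) size(full_Q) size(thin_Q) size(Q.factors)

# The thin factor is the leading n columns of the full one
@assert thin_Q ≈ full_Q[:, 1:size(A, 2)]
```

So the Array conversion behaves like "take the first n columns of the m×m Q", which is what a pullback for it has to undo.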
I think ::typeof(Matrix) won't work, BTW; it needs to be ::Type{...}. So the simplest idea is:
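using LinearAlgebra, ChainRules, ChainRulesCore  # NoTangent comes from ChainRulesCore
# Pad the thin (m×n) cotangent with zero columns so it matches the full m×m size(Q)
function ChainRules.rrule(::Type{T}, Q::LinearAlgebra.QRCompactWYQ) where {T<:Array}
    T(Q), dy -> (NoTangent(), hcat(dy, falses(size(Q,1), size(Q,2)-size(Q.factors,2))))
end;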
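One way to see why zero-padding is the right pullback: `Matrix(Q)` selects the first n columns of the full m×m factor, and the adjoint of selecting columns is padding them back with zeros. The minimal sketch below checks the defining identity ⟨dy, Matrix(Q)⟩ = ⟨pad(dy), collect(Q)⟩ using plain LinearAlgebra. `pad_cotangent` is a hypothetical helper (not part of ChainRules) that mirrors the `hcat(dy, falses(...))` body of the rule above.

```julia
using LinearAlgebra

# Hypothetical helper (not part of ChainRules) mirroring the pullback body of the
# rule above: pad an m×n cotangent with zero columns up to the full size of Q.
function pad_cotangent(dy::AbstractMatrix, Q)
    nrows = size(Q, 1)
    extra = size(Q, 2) - size(Q.factors, 2)   # columns that Matrix(Q) drops
    return hcat(dy, zeros(eltype(dy), nrows, extra))
end

A = randn(6, 4)
Q = qr(A).Q
thin_Q = Matrix(Q)               # forward output, 6×4
dy = randn(size(thin_Q)...)      # an arbitrary incoming cotangent, 6×4
dQ = pad_cotangent(dy, Q)        # cotangent w.r.t. the full Q, 6×6

# Adjoint identity: <dy, Matrix(Q)> == <pad(dy), collect(Q)>
lhs = dot(dy, thin_Q)
rhs = dot(dQ, collect(Q))
@assert isapprox(lhs, rhs; atol = 1e-10)
```

The sketch uses `zeros(eltype(dy), …)` where the rule uses `falses`; after `hcat` promotion both pad with the same zero values.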
Cool, that works. Here is a gist that incorporates code from https://github.com/JuliaDiff/ChainRules.jl/pull/469 into a minimal working example: https://gist.github.com/rkube/b965267944115af7d13b3f00e7533572 This code gives the same results as comparable PyTorch code.
However, the syntax for defining the rrule of a type conversion differs from what is explained in the ChainRulesCore.jl docs: https://juliadiff.org/ChainRulesCore.jl/stable/
For a normal function, the docs use:
function foo(x)
...
end
function rrule(::typeof(foo), args...; kwargs...)
...
return y, pullback
end
And the syntax for a type conversion is:
function rrule(::Type{T}, args...; kwargs...) where {T}
Why is there a difference? It would be great to have this mentioned in the docs.
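The difference comes from Julia's dispatch rather than from ChainRules itself. Every function, such as `foo` or `sin`, has its own singleton type, so `::typeof(foo)` matches exactly that function. A type such as `Matrix` does not: `typeof(Matrix)` is just `UnionAll`, shared with every other parametric type, and `typeof(Matrix{Float64})` is `DataType`. The minimal sketch below uses hypothetical demo functions (`match_fn`, `match_typeof`, `match_Type`, not part of any package) to show which arguments each signature catches.

```julia
# Hypothetical demo functions (not part of any package).

# Each function has its own singleton type, so this matches sin and nothing else.
match_fn(::typeof(sin)) = true
match_fn(::Any) = false

# typeof(Matrix) === UnionAll, so this signature is really ::UnionAll.
match_typeof(::typeof(Matrix)) = true
match_typeof(::Any) = false

# ::Type{T} dispatches on the type object itself, for any T <: Array.
match_Type(::Type{T}) where {T<:Array} = true
match_Type(::Any) = false

for t in (Matrix, Matrix{Float64}, Dict)
    println(rpad(string(t), 20), "typeof: ", match_typeof(t), "   Type{T}: ", match_Type(t))
end
println("match_fn(sin) = ", match_fn(sin), ", match_fn(cos) = ", match_fn(cos))
```

So `::typeof(Matrix)` misses fully parameterized types such as `Matrix{Float64}` while matching unrelated parametric types such as `Dict`, whereas `::Type{T} where {T<:Array}` catches `Matrix`, `Matrix{Float64}` and `Array{Float64}` alike, which is why the QRCompactWYQ rule above is written that way.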
Might indeed be worth mentioning in the docs, if you can figure out where.
Something similar applies without AD:
julia> Matrix(s::String) = fill(s, 2, 2);
julia> Matrix("zz")
2×2 Matrix{String}:
"zz" "zz"
"zz" "zz"
julia> Matrix{String}("zz")
ERROR: MethodError: no method matching Matrix{String}(::String)
julia> (::Type{T})(s::String) where {T<:Matrix} = fill(s, 3, 3);
julia> Matrix{String}("zz")
3×3 Matrix{String}:
"zz" "zz" "zz"
"zz" "zz" "zz"
"zz" "zz" "zz"
I'm trying to work with the Q matrix from a QR factorization within Zygote. For an m×n input with m >= n, the QRCompactWYQ matrix Q has size=(m,m), but only the first n columns are relevant: https://docs.julialang.org/en/v1/stdlib/LinearAlgebra/#LinearAlgebra.qr
Using this in Zygote behaves as follows:
The first two calls to gradient work fine. The last call fails because the incoming gradient has size=(6,4), but the matrix has size=(6,6).
I tried modifying the rrule from https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/array.jl (line 7) like this:
but that didn't work. Is there a good way to handle calls to Matrix(Q) in the backward pass? @mcabbott