FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

AD through custom AbstractMatrix fails #815

Closed simonmandlik closed 4 years ago

simonmandlik commented 4 years ago

Is it possible with Zygote to build custom AbstractMatrix subtypes with custom adjoints? I would like a matrix that behaves like a standard dense matrix, except that before left-multiplying a data matrix it fills each missing element of that data matrix with the value for its row from a vector of parameters:

using Flux, Zygote
using ChainRulesCore
import ChainRulesCore: rrule
import Base: *

struct MyMatrix{T, U <: AbstractMatrix{T}, V <: AbstractVector{T}}
    W::U
    b::V
end

A::MyMatrix * B::AbstractMatrix{Union{Missing, Float64}} = A.W * _fill_in(A.b, B)

_fill_in(b, B) = _fill_mask(b, B)[1]
function rrule(::typeof(_fill_in), b, B)
    X, m = _fill_mask(b, B)
    X, Δ -> (NO_FIELDS, @thunk(_fill_in_db(Δ, m)), @thunk(_fill_in_dB(Δ, .!m)))
end

_fill_in_db(Δ, m) = (db = deepcopy(Δ); db[m] .= 0; sum(db, dims=2))
_fill_in_dB(Δ, m) = (dB = deepcopy(Δ); dB[m] .= 0; dB)

function _fill_mask(b, B)
    m = .!ismissing.(B)
    X = repeat(b, 1, size(B, 2))
    X[m] = B[m]
    X, m
end

With this definition, everything works as expected:

julia> B = [1.0 2.0; missing 3.0]
2×2 Array{Union{Missing, Float64},2}:
 1.0       2.0
  missing  3.0

julia> W, b = rand(2,2), rand(2)
([0.7271867875139644 0.9966562367102048; 0.7544782519268483 0.9148648966908355], [0.1202546404757574, 0.5564718739401007])

julia> MyMatrix(W, b) * B
2×2 Array{Float64,2}:
 1.2818   4.44434
 1.26357  4.25355

julia> gradient(A -> sum(A*B), MyMatrix(W, b))
((W = [3.0 3.556471873940101; 3.0 3.556471873940101], b = [0.0; 1.9115211334010402]),)

julia> gradient(B -> sum(MyMatrix(W,b)*B), B)
([1.4816650394408126 1.4816650394408126; 0.0 1.9115211334010402],)
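The gradient with respect to b above can be sanity-checked without Zygote. The following is a minimal plain-Julia sketch, assuming small hand-picked values for W and b (not the random ones from the session above); `fill_mask` and `loss` are illustrative names that mirror `_fill_mask` and the `sum(A*B)` objective. It compares the hand-derived pullback against central finite differences.

```julia
# Same logic as _fill_mask in the post: fill missing entries of B row-wise from b.
function fill_mask(b, B)
    m = .!ismissing.(B)               # true where B holds an observed value
    X = repeat(b, 1, size(B, 2))      # every column of X starts out as b
    X[m] = B[m]                       # overwrite with the observed values
    X, m
end

# Illustrative values (assumptions, chosen so the result is easy to read).
W = [1.0 2.0; 3.0 4.0]
b = [0.5, -1.0]
B = [1.0 2.0; missing 3.0]

loss(bv) = sum(W * fill_mask(bv, B)[1])

# Hand-derived gradient: the cotangent of X is W' * ones; only the positions
# that were missing in B depend on b, and each row of X shares one entry of b.
X, m = fill_mask(b, B)
Δ = W' * ones(size(W, 1), size(B, 2))
missing_mask = .!m
db = vec(sum(Δ .* missing_mask, dims=2))

# Central finite differences as an independent check.
h = 1e-6
fd = [(loss(b .+ h .* e) - loss(b .- h .* e)) / (2h) for e in ([1.0, 0.0], [0.0, 1.0])]
```

Here `db` comes out as `[0.0, 6.0]`: the first row of B has no missing entry, so b[1] has no influence, while b[2] picks up the column sum of W[:, 2], the same pattern as in the Zygote output above.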

However, if I change the struct definition so that it subtypes AbstractMatrix:

struct MyMatrix{T, U <: AbstractMatrix{T}, V <: AbstractVector{T}} <: AbstractMatrix{T}
    W::U
    b::V
end

this stops working and gives this error:

julia> gradient(B -> sum(MyMatrix(W,b)*B), B)
ERROR: MethodError: no method matching copy(::Missing)
...

julia> gradient(A -> sum(A*B), MyMatrix(W, b))
ERROR: MethodError: no method matching copy(::Missing)
...

Is it possible to create custom AbstractMatrix subtypes in this way or does Zygote make it impossible?

simonmandlik commented 4 years ago

Probably related to https://github.com/FluxML/Zygote.jl/issues/811

simonmandlik commented 4 years ago

The problem was that Zygote already has an adjoint defined for the multiplication of two matrices. I tried to implement a rule specifically for * using rrule, which fails:

rrule(::typeof(*), A::MyMatrix, B::AbstractMatrix) = ...

Once it is defined with @adjoint instead, it works as expected:

Zygote.@adjoint A::MyMatrix * B::AbstractMatrix = ...

so I suspect this is a duplicate of https://github.com/FluxML/Zygote.jl/issues/811
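The dispatch behaviour that appears to be at play can be sketched in plain Julia. This is only an analogy under stated assumptions: `rule`, `PlainWrapper`, and `MatrixWrapper` are hypothetical stand-ins and not part of Zygote or ChainRulesCore. The point is that a generic method written for any pair of matrices starts to apply as soon as a type subtypes AbstractMatrix, and that a more specific method takes over again.

```julia
# Hypothetical stand-ins, not Zygote's real API: `rule` mimics how a
# differentiation rule for `*` could be selected by dispatch on argument types.

struct PlainWrapper                               # like the first MyMatrix (no supertype)
    W::Matrix{Float64}
end

struct MatrixWrapper <: AbstractMatrix{Float64}   # like the second MyMatrix
    W::Matrix{Float64}
end
# Minimal AbstractMatrix interface so the type can be used and displayed.
Base.size(A::MatrixWrapper) = size(A.W)
Base.getindex(A::MatrixWrapper, i::Int, j::Int) = A.W[i, j]

# Fallback: no dedicated rule, so the method body itself would be differentiated.
rule(A, B) = :trace_method_body
# Generic rule for any two matrices, analogous to a built-in matmul adjoint.
rule(A::AbstractMatrix, B::AbstractMatrix) = :generic_matmul_rule

B = [1.0 2.0; 3.0 4.0]
before_plain  = rule(PlainWrapper(B), B)    # :trace_method_body
before_matrix = rule(MatrixWrapper(B), B)   # :generic_matmul_rule

# A more specific method wins over the generic one, which is the role the
# @adjoint definition plays in the comment above.
rule(A::MatrixWrapper, B::AbstractMatrix) = :custom_rule
after_matrix = rule(MatrixWrapper(B), B)    # :custom_rule
```

In this picture, the non-subtyped struct never hits the generic matrix rule, so the custom rule for `_fill_in` is reached; the subtyped struct does hit it until something more specific is defined.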