CarloLucibello / GraphNeuralNetworks.jl

Graph Neural Networks in Julia
https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/
MIT License

Gradient of edge weights is nothing with fused e_mul_xj #113

Closed. learning-chip closed this issue 2 years ago

learning-chip commented 2 years ago

#107 breaks Zygote autodiff: Zygote.gradient() returns nothing for the fused kernel, while it returns the correct gradient for the unfused one. This bug further breaks GNN training, with hard-to-understand errors like MethodError: no method matching vec(::Nothing).

To reproduce

using GraphNeuralNetworks
using SparseArrays
import Random: seed!
using Zygote

n = 32
seed!(0)
A = sprand(n, n, 0.1)
b = rand(1, n)
g = GNNGraph(A)
A_val = reshape(A.nzval, 1, :)

"""SpMV followed by a scalar loss function"""
function forward_fused(g, b, A_val)
    out = propagate(
        e_mul_xj, g, +; xj=b, e=A_val
        )
    return sum(abs2, out)
end

function forward_unfused(g, b, A_val)
    out = propagate(
        (xi, xj, e) -> e .* xj, g, +; xj=b, e=A_val
        )
    return sum(abs2, out)
end

forward_fused(g, b, vec(A_val)) == forward_unfused(g, b, A_val)  # true, forward passes agree

grad_builtin = gradient(A -> sum(abs2, b * A), A)[1];  # returns a sparse CSC matrix containing the gradient

grad_gnn1 = gradient(
    A_vals -> forward_unfused(g, b, A_vals), 
    A_val
)[1]

isequal(vec(grad_gnn1), grad_builtin.nzval)  # true, gradient agrees with the reference

# edge features not flattened here, so the "fused function" does not actually invoke the fused kernel
grad_gnn2 = gradient(
    A_vals -> forward_fused(g, b, A_vals), 
    A_val
)[1]

isequal(vec(grad_gnn2), grad_builtin.nzval)   # true, gradient agrees with the reference

# passing the flattened edge features, which activates fusion
grad_gnn3 = gradient(
    A_vals -> forward_fused(g, b, A_vals), 
    vec(A_val)
)[1]  # bug: returns nothing

Package version

CarloLucibello commented 2 years ago

Unfortunately, it seems hard to support gradients with respect to edge_weights when doing fused operations.
I don't know how to make the construction of a matrix out of a vector differentiable.

CarloLucibello commented 2 years ago

We have to solve this issue:

julia> using SparseArrays, Zygote

julia> s, t, w = [1,2], [2,3], [0.5,0.5]
([1, 2], [2, 3], [0.5, 0.5])

julia> gradient(w -> sum(sparse(s,t,w)), w)
ERROR: Need an adjoint for constructor SparseMatrixCSC{Float64, Int64}. Gradient is of type FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{SparseMatrixCSC{Float64, Int64}, Nothing, false})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/lib/lib.jl:324
  [3] (::Zygote.var"#1786#back#229"{Zygote.Jnew{SparseMatrixCSC{Float64, Int64}, Nothing, false}})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/SparseArrays/src/sparsematrix.jl:31 [inlined]
  [5] (::typeof(∂(SparseMatrixCSC{Float64, Int64})))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
  [6] Pullback
    @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/SparseArrays/src/sparsematrix.jl:44 [inlined]
  [7] (::typeof(∂(SparseMatrixCSC)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
  [8] Pullback
    @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/SparseArrays/src/sparsematrix.jl:961 [inlined]
  [9] (::typeof(∂(sparse!)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
 [10] Pullback
    @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/SparseArrays/src/sparsematrix.jl:798 [inlined]
 [11] (::typeof(∂(sparse)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
 [12] Pullback
    @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/SparseArrays/src/sparsematrix.jl:987 [inlined]
 [13] (::typeof(∂(sparse)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
 [14] Pullback
    @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/SparseArrays/src/sparsematrix.jl:983 [inlined]
 [15] (::typeof(∂(sparse)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
 [16] Pullback
    @ ./REPL[7]:1 [inlined]
 [17] (::typeof(∂(#6)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
 [18] (::Zygote.var"#57#58"{typeof(∂(#6))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface.jl:41
 [19] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface.jl:76
 [20] top-level scope
    @ REPL[7]:1
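
One possible direction, sketched here purely for illustration (it is not the fix that eventually landed; see the ChainRules.jl pull request linked at the end of this thread): hand-write a reverse rule for sparse(s, t, w) whose pullback gathers the matrix cotangent at the stored positions, which is the adjoint of the scatter-add that sparse performs.

using ChainRulesCore, SparseArrays, Zygote

# Hypothetical sketch of a reverse rule for sparse(s, t, w): sparse scatter-adds
# w[k] into position (s[k], t[k]), so the pullback for w gathers the matrix
# cotangent at those same positions.
function ChainRulesCore.rrule(::typeof(sparse), s::AbstractVector{<:Integer},
                              t::AbstractVector{<:Integer}, w::AbstractVector)
    A = sparse(s, t, w)
    function sparse_pullback(ΔA)
        M = unthunk(ΔA)
        Δw = [M[i, j] for (i, j) in zip(s, t)]
        return NoTangent(), NoTangent(), NoTangent(), Δw
    end
    return A, sparse_pullback
end

s, t, w = [1, 2], [2, 3], [0.5, 0.5]
gradient(w -> sum(sparse(s, t, w)), w)[1]  # expected: [1.0, 1.0], one per stored entry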
learning-chip commented 2 years ago
gradient(w -> sum(sparse(s,t,w)), w)

This is definitely doable with PyTorch sparse tensors. I am not familiar enough with Zygote's autodiff rules (still learning), but let me just post the PyTorch (1.9.1) code for reference...

import torch

# gradient w.r.t. to value of sparse matrix
A_indices = torch.tensor([[0, 1], [1, 2]])
A_vals = torch.tensor([0.5, 0.5], requires_grad=True)
A_coo = torch.sparse_coo_tensor(A_indices, A_vals, (3, 3))  # or torch.sparse_csr_tensor()
B = torch.sparse.mm(A_coo, A_coo)  # some sparse linear algebra
loss = B.coalesce().values().pow(2).sum()
loss.backward()

A_vals.grad  # tensor([0.2500, 0.2500])

# gradient w.r.t. sparse matrix itself
A_new = A_coo.detach().requires_grad_(True)  # new leaf node
loss2 = torch.sparse.mm(A_new, A_new).coalesce().values().pow(2).sum()
loss2.backward()

A_new.grad  # a sparse tensor with the same pattern as A_coo and values tensor([0.2500, 0.2500])

torch.equal(A_vals.grad, A_new.grad.coalesce().values())  # True
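
As a quick sanity check of those numbers on the Julia side, the same reference gradient can be computed today with a dense matrix, which Zygote already differentiates; indices shifted to 1-based. This is only an illustrative aside: the sparse version is what a rule like the sketch above would enable.

using Zygote

A = zeros(3, 3); A[1, 2] = 0.5; A[2, 3] = 0.5  # dense counterpart of A_coo
dA = gradient(A -> sum(abs2, A * A), A)[1]     # zero entries contribute nothing to the loss
dA[1, 2], dA[2, 3]  # (0.25, 0.25), matching A_vals.grad in the PyTorch snippet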
CarloLucibello commented 2 years ago

Differentiability of sparse is being taken care of in https://github.com/JuliaDiff/ChainRules.jl/pull/579
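
With a differentiation rule for sparse in place, the fused reproduction case at the top of this issue should, in principle, return a gradient matching the reference instead of nothing, since the two forward passes already agree. A hedged expectation, not verified here:

grad_gnn3 = gradient(A_vals -> forward_fused(g, b, A_vals), vec(A_val))[1]
isapprox(grad_gnn3, grad_builtin.nzval)  # expected true once sparse construction is differentiable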