EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Mutation of gradient input in MatMul example for Square Matrices #382

Closed: maximilian-gelbrecht closed this issue 2 years ago

maximilian-gelbrecht commented 2 years ago

Looking at the example for a matrix multiply and playing with it, I noticed the following behaviour, which depends on the matrix size.

Here is the example, adapted to the standard LinearAlgebra.mul!:

using LinearAlgebra, Enzyme

begin 
    A = rand(7, 3)
    B = rand(3, 5)

    R = zeros(size(A,1), size(B,2))
    ∂z_∂R = rand(size(R)...)  # Some gradient/tangent passed to us
    ∂z_∂R_copy = deepcopy(∂z_∂R)

    ∂z_∂A = zero(A)
    ∂z_∂B = zero(B)
end

Enzyme.autodiff(mul!, Const, Duplicated(R, ∂z_∂R), Duplicated(A, ∂z_∂A), Duplicated(B, ∂z_∂B))

∂z_∂R ≈ ∂z_∂R_copy # true

R ≈ A * B            &&
∂z_∂A ≈ ∂z_∂R * B'   &&  
∂z_∂B ≈ A' * ∂z_∂R       # true

This works as intended.

If I change to square matrices, the gradient input ∂z_∂R is mutated to zeros, while the computed gradients ∂z_∂A and ∂z_∂B are still correct:

begin 
    A = rand(3, 3)
    B = rand(3, 3)

    R = zeros(size(A,1), size(B,2))
    ∂z_∂R = rand(size(R)...)  # Some gradient/tangent passed to us
    ∂z_∂R_copy = deepcopy(∂z_∂R)

    ∂z_∂A = zero(A)
    ∂z_∂B = zero(B)
end

Enzyme.autodiff(mul!, Const, Duplicated(R, ∂z_∂R), Duplicated(A, ∂z_∂A), Duplicated(B, ∂z_∂B))

∂z_∂R ≈ ∂z_∂R_copy # false 
∂z_∂R ≈ zero(R) # true 

R ≈ A * B            &&
∂z_∂A ≈ ∂z_∂R * B'   &&  
∂z_∂B ≈ A' * ∂z_∂R        # false 

R ≈ A * B            &&
∂z_∂A ≈ ∂z_∂R_copy * B'   &&  
∂z_∂B ≈ A' * ∂z_∂R_copy       # true

Is this behaviour intentional or a bug?

(Running Enzyme v0.10.4)

wsmoses commented 2 years ago

This behavior is intended and is how mutation support works. In reverse mode, the derivatives of the outputs are propagated back to the derivatives of the inputs. Specifically, when reverse-mode differentiating the store into R, Enzyme propagates dR to the inputs of that store and then zeros dR.
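As a rough illustration of that rule (not the code Enzyme actually generates), here is a hand-written adjoint for the in-place product `mul!(R, A, B)` in plain Julia. The function name `mul_adjoint!` is hypothetical. It accumulates the seed `dR` into `dA` and `dB` and then zeros `dR`, which matches what the square-matrix run above shows.

```julia
using LinearAlgebra

# Hypothetical hand-written reverse pass for the store R = A * B done by mul!(R, A, B).
# It mimics the rule described above; it is not taken from Enzyme.
function mul_adjoint!(dA, dB, dR, A, B)
    mul!(dA, dR, B', true, true)   # dA += dR * B'
    mul!(dB, A', dR, true, true)   # dB += A' * dR
    fill!(dR, 0)                   # R was overwritten, so its adjoint is consumed
    return nothing
end

A = rand(3, 3)
B = rand(3, 3)
dR = rand(3, 3)          # seed, i.e. ∂z/∂R
dR_copy = copy(dR)
dA = zero(A)
dB = zero(B)

mul_adjoint!(dA, dB, dR, A, B)

@show dA ≈ dR_copy * B'  # true
@show dB ≈ A' * dR_copy  # true
@show iszero(dR)         # true: the seed has been zeroed
```

The five-argument `mul!(C, X, Y, α, β)` computes `C = X*Y*α + C*β` in place, so passing `true, true` accumulates into the existing gradient buffers instead of overwriting them.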

This is required, for example, when the store happens inside a loop, like below.

for i = 1 : 10
  R = A * B
end

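Only the last store matters, so the correct reverse pass is:

dA = zero(A)
dB = zero(B)
for i = 10:-1:1      # iterate in the reverse order of the original loop
   dA += dR * B'
   dB += A' * dR
   dR = zero(dR)     # R was overwritten, so earlier iterations get no adjoint
end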

If the zeroing of dR were not there, dA (and likewise dB) would be 10 times its correct value!
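The loop argument can be checked numerically without Enzyme. The sketch below is plain Julia; the helper `reverse_sweep` and its `zero_seed` flag are hypothetical. It runs the reverse sweep for ten repeated stores `R = A * B`, once with the zeroing and once without, and compares each result with the true gradient, which is just `dR * B'` because only the last store survives.

```julia
using LinearAlgebra

# Hypothetical reverse sweep for:  for i = 1:10; R = A * B; end
# Every iteration overwrites R, so only the final store affects the output.
function reverse_sweep(A, B, dR_seed; zero_seed::Bool)
    dR = copy(dR_seed)
    dA = zero(A)
    dB = zero(B)
    for i = 10:-1:1
        dA .+= dR * B'
        dB .+= A' * dR
        zero_seed && fill!(dR, 0)   # consume the adjoint of the overwritten R
    end
    return dA, dB
end

A = rand(3, 3)
B = rand(3, 3)
dR = rand(3, 3)

dA_ok, _  = reverse_sweep(A, B, dR; zero_seed = true)
dA_bad, _ = reverse_sweep(A, B, dR; zero_seed = false)

@show dA_ok ≈ dR * B'           # true: the correct gradient
@show dA_bad ≈ 10 * (dR * B')   # true: ten times too large
```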

This zeroing behavior applies only to reverse mode; in forward mode the original derivative inputs aren't modified.

wsmoses commented 2 years ago

Oh, I see what you mean by the issue: the square case has the intended behavior, but the non-square one does not. My guess is that the non-square version makes an internal copy inside the linear algebra routine (meaning the dR actually being used is different), but I will take a look nonetheless.

maximilian-gelbrecht commented 2 years ago

Okay, got it. Probably not that important; I was just confused by the difference in behaviour. Maybe the documentation example should be changed, though, as it currently shows unintended behaviour.

wsmoses commented 2 years ago

Reduced reproducer (mm.jl):

using Enzyme
using LinearAlgebra

A = rand(1, 1)
B = rand(1, 1)

R = zeros(size(A,1), size(B,2))
∂z_∂R = [0.8;;]
∂z_∂R_copy = deepcopy(∂z_∂R)

∂z_∂A = zero(A)
∂z_∂B = zero(B)

function mul(R, A, B)
    BLAS.gemm!('N', 'N', 1.0, A, B, 0.0, R)   # R = A * B, in place
    nothing
end
Enzyme.API.printall!(true)
Enzyme.autodiff(mul, Const, Duplicated(R, ∂z_∂R), Duplicated(A, ∂z_∂A), Duplicated(B, ∂z_∂B))

@show ∂z_∂R # should be 0, is [0.8;;]

wsmoses commented 2 years ago
using Enzyme
using LinearAlgebra
using LinearAlgebra.BLAS

@inline uptr(x) = Base.reinterpret(Ptr{Float64}, x)

# SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
# *     .. Scalar Arguments ..
#       DOUBLE PRECISION ALPHA,BETA
#       INTEGER K,LDA,LDB,LDC,M,N
#       CHARACTER TRANSA,TRANSB
# *     .. Array Arguments ..
#       DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
function imul(C, A, m)
    ka = 1
    kb = 1
    n = 1
    A = uptr(A) # Ref(2.0)
    # A = Ref(2.0)
    B = Ref(2.0)   # B is a fixed scalar here; the B passed to mul is not used
    # Direct dgemm_ call: C = 1.0 * A * B + 0.0 * C with M = m and N = K = 1.
    # The two trailing Clong arguments are the hidden Fortran lengths of TRANSA/TRANSB.
    ccall((LinearAlgebra.BLAS.@blasfunc(dgemm_), LinearAlgebra.BLAS.libblastrampoline), Cvoid,
          (Ref{UInt8}, Ref{UInt8}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{LinearAlgebra.BLAS.BlasInt},
           Ref{LinearAlgebra.BLAS.BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{LinearAlgebra.BLAS.BlasInt},
           Ptr{Float64}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{Float64}, Ptr{Float64},
           Ref{LinearAlgebra.BLAS.BlasInt}, Clong, Clong),
          'N', 'N', m, n,
          ka, 1.0, A, 1,
          B, 1, 0.0, uptr(C),
          1, 1, 1)
    nothing
end

function mul(C, A, B, m)
    imul(C, A, m)
end

A = rand(1, 1)
B = rand(1, 1)

R = zeros(size(A,1), size(B,2))
∂z_∂R = [0.8;;]
∂z_∂R_copy = deepcopy(∂z_∂R)

∂z_∂A = zero(A)
∂z_∂B = zero(B)

@inline ptr(x) = Base.reinterpret(Core.LLVMPtr{Float64, 0}, Base.unsafe_convert(Ptr{Float64}, x))

GC.@preserve R A B ∂z_∂R ∂z_∂A ∂z_∂B begin
    mul(ptr(R), ptr(A), ptr(B), 1)

    Enzyme.API.printall!(true)
    Enzyme.autodiff(mul, Const, 
        Duplicated(ptr(R), ptr(∂z_∂R)),
        Duplicated(ptr(A), ptr(∂z_∂A)), Duplicated(ptr(B), ptr(∂z_∂B)), Const(1))

    @show ∂z_∂R

end