EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Implementing Pullbacks example checks fail #467

Closed axla-io closed 2 years ago

axla-io commented 2 years ago

The last checks in the implementing pullbacks guide:

R ≈ A * B            &&
∂z_∂A ≈ ∂z_∂R * B'   &&  # equivalent to Zygote.pullback(*, A, B)[2](∂z_∂R)[1]
∂z_∂B ≈ A' * ∂z_∂R       # equivalent to Zygote.pullback(*, A, B)[2](∂z_∂R)[2]

These checks return false. This is not the intended behavior, right?

wsmoses commented 2 years ago

What version of Julia / Enzyme are you using (and which commit, if relevant)?

andrewjradcliffe commented 2 years ago

I also get the same behavior; it seems to be related to setindex!

using Enzyme

# Fixed
function mymul!(R, A, B)
    @assert axes(A,2) == axes(B,1)
    @inbounds for j in axes(B, 2), i in axes(A, 1)
        @inbounds @simd for k in axes(A,2)
            R[i,j] += A[i,k] * B[k,j]
        end
    end
    nothing
end;

A = rand(5, 3)
B = rand(3, 7)

R = zeros(size(A,1), size(B,2));
∂z_∂R = rand(size(R)...)
pz = deepcopy(∂z_∂R);
∂z_∂A = zero(A);
∂z_∂B = zero(B);

autodiff(mymul!, Const, Duplicated(R, ∂z_∂R), Duplicated(A, ∂z_∂A), Duplicated(B, ∂z_∂B))

R ≈ A * B            &&
    ∂z_∂A ≈ ∂z_∂R * B'   &&
    ∂z_∂B ≈ A' * ∂z_∂R
∂z_∂R == pz

# Another way to reproduce the issue
function mymul2!(R, A, B)
    @assert axes(A,2) == axes(B,1)
    @inbounds for j in axes(B, 2), i in axes(A, 1)
        s = zero(eltype(R))
        @inbounds @simd for k in axes(A,2)
            s += A[i,k] * B[k,j]
        end
        R[i,j] = s
    end
    nothing
end;
∂z_∂R2 = deepcopy(pz);
∂z_∂A2 = zero(A);
∂z_∂B2 = zero(B);
R2 = zero(R);
autodiff(mymul2!, Const, Duplicated(R2, ∂z_∂R2), Duplicated(A, ∂z_∂A2), Duplicated(B, ∂z_∂B2))
∂z_∂A ≈ ∂z_∂A2
∂z_∂B ≈ ∂z_∂B2
∂z_∂R2 == zero(R2)

Environment

julia> versioninfo()
Julia Version 1.9.0-DEV.1454
Commit e7a5c36205* (2022-09-28 15:01 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: 96 × Intel(R) Xeon(R) Gold 6336Y CPU @ 2.40GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, icelake-server)
  Threads: 96 on 96 virtual cores

Enzyme version/commit: [7da242da] Enzyme v0.10.6 https://github.com/EnzymeAD/Enzyme.jl.git#main

Perhaps the issue is related to master, or something else entirely; upon using Enzyme, I get

julia> using Enzyme
Internal error: encountered unexpected error in runtime:
AssertionError(msg="argextype only works on argument-position values")
argextype at ./compiler/optimize.jl:341
argextype at ./compiler/optimize.jl:323 [inlined]
argextype at ./compiler/optimize.jl:323 [inlined]
stmt_effect_flags at ./compiler/optimize.jl:249
finish at ./compiler/optimize.jl:403
optimize at ./compiler/optimize.jl:493 [inlined]
_typeinf at ./compiler/typeinfer.jl:259
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:937
abstract_call_method at ./compiler/abstractinterpretation.jl:613
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:152
abstract_call_known at ./compiler/abstractinterpretation.jl:1842
abstract_call at ./compiler/abstractinterpretation.jl:1913
abstract_call at ./compiler/abstractinterpretation.jl:1892
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2043
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2258
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2459
typeinf_local at ./compiler/abstractinterpretation.jl:2634
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2740
_typeinf at ./compiler/typeinfer.jl:232
typeinf at ./compiler/typeinfer.jl:215
typeinf_ext at ./compiler/typeinfer.jl:1061
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1094
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1090
jfptr_typeinf_ext_toplevel_13157.clone_1 at ~/aradclif/projects/julia/usr/lib/julia/sys.so (unknown line)
jl_apply at ~/aradclif/projects/julia/src/julia.h:1866 [inlined]
jl_type_infer at ~/aradclif/projects/julia/src/gf.c:317
jl_generate_fptr_impl at ~/aradclif/projects/julia/src/jitlayers.cpp:416
jl_compile_method_internal at ~/aradclif/projects/julia/src/gf.c:2156 [inlined]
jl_compile_method_internal at ~/aradclif/projects/julia/src/gf.c:2097
_jl_invoke at ~/aradclif/projects/julia/src/gf.c:2439 [inlined]
ijl_apply_generic at ~/aradclif/projects/julia/src/gf.c:2629
jl_apply at ~/aradclif/projects/julia/src/julia.h:1866 [inlined]
jl_module_run_initializer at ~/aradclif/projects/julia/src/toplevel.c:75
ijl_init_restored_modules at ~/aradclif/projects/julia/src/dump.c:2871
_include_from_serialized at ./loading.jl:928
_require_search_from_serialized at ./loading.jl:1136
_require at ./loading.jl:1412
_require_prelocked at ./loading.jl:1298
macro expansion at ./loading.jl:1278 [inlined]
macro expansion at ./lock.jl:267 [inlined]
require at ./loading.jl:1241
jfptr_require_45015.clone_1 at ~/aradclif/projects/julia/usr/lib/julia/sys.so (unknown line)
jl_apply at ~/aradclif/projects/julia/src/julia.h:1866 [inlined]
call_require at ~/aradclif/projects/julia/src/toplevel.c:466 [inlined]
eval_import_path at ~/aradclif/projects/julia/src/toplevel.c:503
jl_toplevel_eval_flex at ~/aradclif/projects/julia/src/toplevel.c:731
jl_toplevel_eval_flex at ~/aradclif/projects/julia/src/toplevel.c:856
ijl_toplevel_eval_in at ~/aradclif/projects/julia/src/toplevel.c:971
eval at ./boot.jl:370 [inlined]
eval_user_input at ~/aradclif/projects/julia/usr/share/julia/stdlib/v1.9/REPL/src/REPL.jl:152
repl_backend_loop at ~/aradclif/projects/julia/usr/share/julia/stdlib/v1.9/REPL/src/REPL.jl:248
#start_repl_backend#46 at ~/aradclif/projects/julia/usr/share/julia/stdlib/v1.9/REPL/src/REPL.jl:233
start_repl_backend##kw at ~/aradclif/projects/julia/usr/share/julia/stdlib/v1.9/REPL/src/REPL.jl:230 [inlined]
#run_repl#59 at ~/aradclif/projects/julia/usr/share/julia/stdlib/v1.9/REPL/src/REPL.jl:376
run_repl at ~/aradclif/projects/julia/usr/share/julia/stdlib/v1.9/REPL/src/REPL.jl:362
jfptr_run_repl_57428.clone_1 at ~/aradclif/projects/julia/usr/lib/julia/sys.so (unknown line)
#1013 at ./client.jl:421
jfptr_YY.1013_26046.clone_1 at ~/aradclif/projects/julia/usr/lib/julia/sys.so (unknown line)
jl_apply at ~/aradclif/projects/julia/src/julia.h:1866 [inlined]
jl_f__call_latest at ~/aradclif/projects/julia/src/builtins.c:774
#invokelatest#2 at ./essentials.jl:807 [inlined]
invokelatest at ./essentials.jl:804 [inlined]
run_main_repl at ./client.jl:405
exec_options at ./client.jl:322
_start at ./client.jl:522
jfptr__start_32195.clone_1 at ~/aradclif/projects/julia/usr/lib/julia/sys.so (unknown line)
jl_apply at ~/aradclif/projects/julia/src/julia.h:1866 [inlined]
true_main at ~/aradclif/projects/julia/src/jlapi.c:567
jl_repl_entrypoint at ~/aradclif/projects/julia/src/jlapi.c:711
main at ~/aradclif/projects/julia/cli/loader_exe.c:59
__libc_start_main at /lib64/libc.so.6 (unknown line)
unknown function (ip: 0x4010a8)

axla-io commented 2 years ago

I'm running Julia 1.8.1 with Enzyme v0.10.6 on an M1 Mac.

axla-io commented 2 years ago

Actually, this is very interesting: if I run the example by @andrewjradcliffe, both versions work.

wsmoses commented 2 years ago

@axla-io When calling reverse mode, ∂z_∂R is modified in place and zeroed. Can you cache the previous value and confirm whether the check still fails? Admittedly, this should probably be changed in the guide.

@andrewjradcliffe main is currently broken. Can you retry with the latest release (and/or retry once main is fixed)?

andrewjradcliffe commented 2 years ago

The guide's example does indeed return false. Substituting a deepcopy of ∂z_∂R (cached before the autodiff call) into the comparison returns true, as one would expect. I still get the above Internal error: encountered unexpected error in runtime: ... upon using Enzyme, but that seems to be orthogonal to the topic of this issue. (Tested with Enzyme v0.10.7 and Enzyme_jll v0.0.41.)

wsmoses commented 2 years ago

@andrewjradcliffe What version of Julia?

andrewjradcliffe commented 2 years ago

Re-tested today with Enzyme v0.10.8 and

Julia Version 1.9.0-DEV.1515
Commit 92e68c8707* (2022-10-06 01:06 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: 96 × Intel(R) Xeon(R) Gold 6336Y CPU @ 2.40GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, icelake-server)
  Threads: 96 on 96 virtual cores

The result is the same as yesterday. As far as the correctness of the pullback is concerned, there seem to be no problems (perhaps the guide should be updated, as you noted above).

wsmoses commented 2 years ago

Yeah, the expected behavior is that it modifies the differential return in place (otherwise, loops that mutate the same memory would produce wrong values).

Would you like to update the guide so that, for clarity, it keeps a copy of the derivative taken before the call, which doesn't get modified, to compare against?
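
One way to see why the seed ends up zeroed is to write the reverse pass of the mymul2! variant from earlier in this thread by hand. The sketch below does not use Enzyme and is not Enzyme's generated code; it only illustrates the general reverse-mode rule for a store that overwrites its target. The names mymul2_pullback!, dR, dA, dB and seed are hypothetical and introduced purely for illustration.

```julia
using LinearAlgebra  # matrix product, adjoint, and ≈ on arrays

# Forward pass: same structure as mymul2! above (each R[i,j] is overwritten).
function mymul2!(R, A, B)
    @assert axes(A, 2) == axes(B, 1)
    for j in axes(B, 2), i in axes(A, 1)
        s = zero(eltype(R))
        for k in axes(A, 2)
            s += A[i, k] * B[k, j]
        end
        R[i, j] = s
    end
    nothing
end

# Hand-written reverse pass (hypothetical helper, not part of Enzyme).
# dR holds the seed on entry; dA and dB accumulate the gradients.
function mymul2_pullback!(dR, dA, dB, A, B)
    for j in reverse(axes(B, 2)), i in reverse(axes(A, 1))
        ds = dR[i, j]        # adjoint of the store `R[i,j] = s` flows into s
        dR[i, j] = zero(ds)  # the overwritten slot carries no gradient any more
        for k in axes(A, 2)
            dA[i, k] += ds * B[k, j]
            dB[k, j] += A[i, k] * ds
        end
    end
    nothing
end

A = rand(5, 3); B = rand(3, 7)
R = zeros(size(A, 1), size(B, 2))
dR = rand(size(R)...)
seed = copy(dR)              # cache the seed before the reverse pass consumes it
dA = zero(A); dB = zero(B)

mymul2!(R, A, B)
mymul2_pullback!(dR, dA, dB, A, B)

@assert R ≈ A * B
@assert dA ≈ seed * B'       # compare against the cached seed, not dR
@assert dB ≈ A' * seed
@assert all(iszero, dR)      # the seed buffer itself has been zeroed
```

If the slot were not zeroed after being read, a loop that stored to the same R[i,j] more than once would propagate the same seed value for every store, which is the kind of wrong value in mutation mentioned above. Comparing against a cached copy of the seed sidesteps the issue in the guide's final check.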

andrewjradcliffe commented 2 years ago

Certainly; I'll likely get around to it tomorrow afternoon.

wsmoses commented 2 years ago

bump @andrewjradcliffe