LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License
483 stars 58 forks source link

Add enzyme support for loss functions from LossFunctions.jl #736

Closed avik-pal closed 3 months ago

avik-pal commented 3 months ago

The loss functions file still needs to be tested for gradient correctness.

codecov[bot] commented 3 months ago

Codecov Report

Attention: Patch coverage is 70.90909% with 16 lines in your changes missing coverage. Please review.

Project coverage is 77.49%. Comparing base (236284c) to head (c4bf97f).

Files Patch % Lines
src/enzymerules.jl 0.00% 12 Missing :warning:
src/contrib/training.jl 0.00% 2 Missing :warning:
src/layers/basic.jl 80.00% 2 Missing :warning:

:exclamation: There is a different number of reports uploaded between BASE (236284c) and HEAD (c4bf97f). Click for more details.

HEAD has 18 uploads less than BASE

| Flag | BASE (236284c) | HEAD (c4bf97f) |
|------|------|------|
||31|13|
Additional details and impacted files

```diff
@@             Coverage Diff             @@
##             main     #736       +/-   ##
===========================================
- Coverage   96.81%   77.49%   -19.33%
===========================================
  Files          53       54        +1
  Lines        2702     2724       +22
===========================================
- Hits         2616     2111      -505
- Misses         86      613      +527
```

:umbrella: View full report in Codecov by Sentry.
:loudspeaker: Have feedback on the report? Share it here.

avik-pal commented 3 months ago

@wsmoses what does this mean?

ERROR: AssertionError: Enzyme Internal Error: Illegal calling convention fixup
; Function Attrs: alwaysinline noreturn
define void @julia_custom_rule_method_error_7985(i64 zeroext %0, [1 x float] addrspace(11)* nocapture nofree noundef nonnull readonly align 4 dereferenceable(4) %1, [2 x {} addrspace(10)*] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(16) %2, [2 x {} addrspace(10)*] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(16) %3) local_unnamed_addr #43 !dbg !1995 {
top:
  %4 = call {}*** @julia.get_pgcstack()
  %5 = getelementptr inbounds [1 x float], [1 x float] addrspace(11)* %1, i64 0, i64 0
  %unbox.unpack = load float, float addrspace(11)* %5, align 4, !tbaa !48, !alias.scope !59, !noalias !62
  %unbox1.elt = getelementptr inbounds [2 x {} addrspace(10)*], [2 x {} addrspace(10)*] addrspace(11)* %2, i64 0, i64 0
  %unbox1.unpack = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %unbox1.elt, align 8, !tbaa !48, !alias.scope !59, !noalias !62
  %unbox1.elt14 = getelementptr inbounds [2 x {} addrspace(10)*], [2 x {} addrspace(10)*] addrspace(11)* %2, i64 0, i64 1
  %unbox1.unpack15 = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %unbox1.elt14, align 8, !tbaa !48, !alias.scope !59, !noalias !62
  %unbox2.elt = getelementptr inbounds [2 x {} addrspace(10)*], [2 x {} addrspace(10)*] addrspace(11)* %3, i64 0, i64 0
  %unbox2.unpack = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %unbox2.elt, align 8, !tbaa !48, !alias.scope !59, !noalias !62
  %unbox2.elt17 = getelementptr inbounds [2 x {} addrspace(10)*], [2 x {} addrspace(10)*] addrspace(11)* %3, i64 0, i64 1
  %unbox2.unpack18 = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %unbox2.elt17, align 8, !tbaa !48, !alias.scope !59, !noalias !62
  %current_task320 = getelementptr inbounds {}**, {}*** %4, i64 -14
  %current_task3 = bitcast {}*** %current_task320 to {}**
  %ptls_field21 = getelementptr inbounds {}**, {}*** %4, i64 2
  %6 = bitcast {}*** %ptls_field21 to i64***
  %ptls_load2223 = load i64**, i64*** %6, align 8, !tbaa !44
  %7 = getelementptr inbounds i64*, i64** %ptls_load2223, i64 2
  %safepoint = load i64*, i64** %7, align 8, !tbaa !48
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint), !dbg !1996
  fence syncscope("singlethread") seq_cst
  %box = call noalias nonnull dereferenceable(40) {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task3, i64 noundef 40, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 123185482555216 to {}*) to {} addrspace(10)*)) #44, !dbg !1997
  %8 = bitcast {} addrspace(10)* %box to i8 addrspace(10)*, !dbg !1997
  %args.sroa.0.0..sroa_cast = bitcast {} addrspace(10)* %box to float addrspace(10)*, !dbg !1997
  store float %unbox.unpack, float addrspace(10)* %args.sroa.0.0..sroa_cast, align 8, !dbg !1997, !tbaa !75, !alias.scope !76, !noalias !2000
  %args.sroa.312.0..sroa_idx = getelementptr inbounds i8, i8 addrspace(10)* %8, i64 8, !dbg !1997
  %args.sroa.312.0..sroa_cast = bitcast i8 addrspace(10)* %args.sroa.312.0..sroa_idx to {} addrspace(10)* addrspace(10)*, !dbg !1997
  store {} addrspace(10)* %unbox1.unpack, {} addrspace(10)* addrspace(10)* %args.sroa.312.0..sroa_cast, align 8, !dbg !1997, !tbaa !75, !alias.scope !76, !noalias !2000
  %args.sroa.5.0..sroa_idx = getelementptr inbounds i8, i8 addrspace(10)* %8, i64 16, !dbg !1997
  %args.sroa.5.0..sroa_cast = bitcast i8 addrspace(10)* %args.sroa.5.0..sroa_idx to {} addrspace(10)* addrspace(10)*, !dbg !1997
  store {} addrspace(10)* %unbox1.unpack15, {} addrspace(10)* addrspace(10)* %args.sroa.5.0..sroa_cast, align 8, !dbg !1997, !tbaa !75, !alias.scope !76, !noalias !2000
  %args.sroa.7.0..sroa_idx = getelementptr inbounds i8, i8 addrspace(10)* %8, i64 24, !dbg !1997
  %args.sroa.7.0..sroa_cast = bitcast i8 addrspace(10)* %args.sroa.7.0..sroa_idx to {} addrspace(10)* addrspace(10)*, !dbg !1997
  store {} addrspace(10)* %unbox2.unpack, {} addrspace(10)* addrspace(10)* %args.sroa.7.0..sroa_cast, align 8, !dbg !1997, !tbaa !75, !alias.scope !76, !noalias !2000
  %args.sroa.9.0..sroa_idx = getelementptr inbounds i8, i8 addrspace(10)* %8, i64 32, !dbg !1997
  %args.sroa.9.0..sroa_cast = bitcast i8 addrspace(10)* %args.sroa.9.0..sroa_idx to {} addrspace(10)* addrspace(10)*, !dbg !1997
  store {} addrspace(10)* %unbox2.unpack18, {} addrspace(10)* addrspace(10)* %args.sroa.9.0..sroa_cast, align 8, !dbg !1997, !tbaa !75, !alias.scope !76, !noalias !2000
  %box6 = call noalias nonnull dereferenceable(24) {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task3, i64 noundef 24, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 123189147238144 to {}*) to {} addrspace(10)*)) #44, !dbg !1999
  %9 = bitcast {} addrspace(10)* %box6 to { {} addrspace(10)*, {} addrspace(10)*, i64 } addrspace(10)*, !dbg !1999
  %.repack = getelementptr inbounds { {} addrspace(10)*, {} addrspace(10)*, i64 }, { {} addrspace(10)*, {} addrspace(10)*, i64 } addrspace(10)* %9, i64 0, i32 0, !dbg !1999
  store {} addrspace(10)* addrspacecast ({}* inttoptr (i64 123189002965984 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspace(10)* %.repack, align 8, !dbg !1999, !tbaa !80, !alias.scope !84, !noalias !2003
  %.repack24 = getelementptr inbounds { {} addrspace(10)*, {} addrspace(10)*, i64 }, { {} addrspace(10)*, {} addrspace(10)*, i64 } addrspace(10)* %9, i64 0, i32 1, !dbg !1999
  store {} addrspace(10)* %box, {} addrspace(10)* addrspace(10)* %.repack24, align 8, !dbg !1999, !tbaa !80, !alias.scope !84, !noalias !2003
  %.repack26 = getelementptr inbounds { {} addrspace(10)*, {} addrspace(10)*, i64 }, { {} addrspace(10)*, {} addrspace(10)*, i64 } addrspace(10)* %9, i64 0, i32 2, !dbg !1999
  store i64 %0, i64 addrspace(10)* %.repack26, align 8, !dbg !1999, !tbaa !80, !alias.scope !84, !noalias !2003
  %10 = addrspacecast {} addrspace(10)* %box6 to {} addrspace(12)*, !dbg !1999
  call void @ijl_throw({} addrspace(12)* %10) #45, !dbg !1999
  unreachable, !dbg !1999
}

args = LLVM.Value[LLVM.AddrSpaceCastInst(%42 = addrspacecast [1 x float] addrspace(10)* %41 to [1 x float] addrspace(11)*, !dbg !59), LLVM.ConstantInt(0x000000000e292da0), LLVM.AddrSpaceCastInst(%21 = addrspacecast [2 x {} addrspace(10)*] addrspace(10)* %20 to [2 x {} addrspace(10)*] addrspace(11)*, !dbg !59), LLVM.AddrSpaceCastInst(%31 = addrspacecast [2 x {} addrspace(10)*] addrspace(10)* %30 to [2 x {} addrspace(10)*] addrspace(11)*, !dbg !59)]
i = 1
args[i] = LLVM.AddrSpaceCastInst(%42 = addrspacecast [1 x float] addrspace(10)* %41 to [1 x float] addrspace(11)*, !dbg !59)
party = LLVM.IntegerType(i64)
ctype = LLVM.PointerType([1 x float] addrspace(11)*)
tape = LLVM.IntegerType(i64)
val =   %42 = addrspacecast [1 x float] addrspace(10)* %41 to [1 x float] addrspace(11)*, !dbg !59
prev = i64 undef
lidxs = UInt32[]
ridxs = UInt32[]
tape_type(tape) = UInt64
convert(LLVMType, tape_type(tape)) = LLVM.IntegerType(i64)
wsmoses commented 3 months ago

Julia's GC did something I didn't expect and we threw an early error rather than risk GC segfault

wsmoses commented 3 months ago

Could you open an issue with an MWE (minimal working example)?

avik-pal commented 3 months ago

This appears only when I define a custom rule. I will isolate it and open an issue.

avik-pal commented 3 months ago

BTW this is a very pathological use case that uses type unstable code (and shouldn't occur unless the user is actively trying to mess with the training functions)