FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Create a flag to use Enzyme as the AD in training/etc. #2443

Closed wsmoses closed 2 months ago

wsmoses commented 3 months ago

Motivation and description

Now that all the internal Flux tests pass, we should start setting up for integration. Having such a flag would make it easier for myself and others to test things out, debug, etc.

Possible Implementation

No response

wsmoses commented 3 months ago

cc @CarloLucibello @ToucheSir

CarloLucibello commented 3 months ago

I think the basic interface needed is a nice gradient function. This code is still not working, though, on both CPU and CUDA GPU:

using CUDA # for GPU training
using Flux, Enzyme
using Random, Statistics

_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
make_zero(model) = fmap(_make_zero, model)

function gradient_ez(f, x...)
    args = []
    for x in x
        if x isa Number
            push!(args, Active(x))
        else
            push!(args, Duplicated(x, make_zero(x)))
        end
    end
    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
    return g
end

batch_size = 128
feature_size = 784
num_classes = 10
epochs = 100
device = Flux.cpu        # CPU training
# device = Flux.gpu      # GPU training

X = randn(Float32, feature_size, batch_size) |> device
y = Flux.onehotbatch(rand(1:num_classes, batch_size), 1:num_classes) |> device

model = Chain(Dense(feature_size => 32, relu),
              Dense(32, num_classes)) |> device

opt_state = Flux.setup(Adam(1e-3), model)

loss(model, x, y) = Flux.logitcrossentropy(model(x), y)
accuracy(model, x, y) = mean(Flux.onecold(model(x)) .== Flux.onecold(y))

function report(epoch)
    @info "Epoch: $epoch" loss=loss(model, X, y) accuracy=accuracy(model, X, y)
end

report(0)
for epoch in 1:epochs
    g = gradient_ez(model -> loss(model, X, y), model)[1]     # Enzyme gradient
    # g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
    Flux.update!(opt_state, model, g)
    report(epoch)
end

We should add tests for the loss functions. This one is failing:

gradient_ez(ŷ -> Flux.logitcrossentropy(ŷ, y), randn(Float32, num_classes, batch_size))
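A minimal sketch of such a test, assuming Flux, Enzyme and Zygote are installed and that the NNlib type-stability fix mentioned later in this thread is available. It compares Enzyme's reverse-mode gradient of Flux.logitcrossentropy with Zygote's on random logits; the sizes simply mirror the training script above.

```julia
using Flux, Enzyme, Zygote, Test

# Illustrative sizes, mirroring the training script above.
num_classes, batch_size = 10, 128
y = Flux.onehotbatch(rand(1:num_classes, batch_size), 1:num_classes)
ŷ = randn(Float32, num_classes, batch_size)

# Enzyme: ŷ is Duplicated, so its gradient is accumulated into the zeroed buffer dŷ;
# the labels are Const because we do not differentiate with respect to them.
dŷ = zero(ŷ)
Enzyme.autodiff(Reverse, Flux.logitcrossentropy, Active, Duplicated(ŷ, dŷ), Const(y))

# Zygote reference gradient for the same loss.
gz = Zygote.gradient(ŷ -> Flux.logitcrossentropy(ŷ, y), ŷ)[1]

@testset "logitcrossentropy: Enzyme vs Zygote" begin
    @test dŷ ≈ gz
end
```

The same pattern (Duplicated for the input being differentiated, Const for everything else, Zygote as the reference) could be repeated for the other loss functions.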
wsmoses commented 3 months ago

Here is a modification of your code above that should be more performant and stable (closures are bad).

In any case, it still has the same issue; I will investigate.

# using CUDA # for GPU training
using Flux, Enzyme
using Random, Statistics

_make_zero!(x::AbstractArray) = x .= 0
_make_zero!(x) = x
make_zero!(model) = fmap(_make_zero!, model)

batch_size = 128
feature_size = 784
num_classes = 10
epochs = 100
device = Flux.cpu        # CPU training
# device = Flux.gpu      # GPU training

X = randn(Float32, feature_size, batch_size) |> device
y = Flux.onehotbatch(rand(1:num_classes, batch_size), 1:num_classes) |> device

model = Chain(Dense(feature_size => 32, relu),
              Dense(32, num_classes)) |> device

opt_state = Flux.setup(Adam(1e-3), model)

loss(model, x, y) = Flux.logitcrossentropy(model(x), y)
accuracy(model, x, y) = mean(Flux.onecold(model(x)) .== Flux.onecold(y))

function report(epoch)
    @info "Epoch: $epoch" loss=loss(model, X, y) accuracy=accuracy(model, X, y)
end

report(0)
g = deepcopy(model)
for epoch in 1:epochs
    make_zero!(g)
    Enzyme.autodiff(Reverse, loss, Duplicated(model, g), Const(X), Const(y))
    # g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
    Flux.update!(opt_state, model, g)
    report(epoch)
end

wsmoses commented 3 months ago

Yeah this works now with the NNlib type stability fix https://github.com/FluxML/NNlib.jl/pull/584

darsnack commented 3 months ago

The previous "interface" was to import the corresponding AD package and just call e.g. Tracker.withgradient.

The most recent attempt was supposed to be DI.jl (DifferentiationInterface.jl), but the choice to focus on arrays and single inputs means we can't use it.

To me the best option would be a Flux.gradient (and Flux.withgradient) that uses ADTypes.jl (only to avoid further fragmentation). Alternatively, a small package that wraps Enzyme.autodiff + make_zero in a Zygote-like interface (similar to what's above).
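To make the ADTypes.jl option concrete, here is a rough sketch of selecting the backend by dispatching on ADTypes structs. The helper name ad_gradient is hypothetical (it is not a Flux or ADTypes function), it only handles a single input, and it assumes ADTypes.jl, Enzyme.jl and Zygote.jl are installed. The Enzyme method is essentially the "Enzyme.autodiff + make_zero" wrapper mentioned above.

```julia
using ADTypes: AutoEnzyme, AutoZygote
using Enzyme, Zygote

# Hypothetical helper: not part of Flux. It only illustrates choosing the AD
# backend through dispatch on ADTypes.jl backend structs.
ad_gradient(f, ::AutoZygote, x) = Zygote.gradient(f, x)[1]

function ad_gradient(f, ::AutoEnzyme, x)
    dx = Enzyme.make_zero(x)   # zeroed shadow with the same (possibly nested) structure as x
    # Reverse mode accumulates ∂f/∂x into dx; the function itself is treated as constant.
    Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
    return dx
end

# Usage: both backends should give the gradient of sum(abs2, x), i.e. 2x.
f(x) = sum(abs2, x)
g_zygote = ad_gradient(f, AutoZygote(), [1.0, 2.0])
g_enzyme = ad_gradient(f, AutoEnzyme(), [1.0, 2.0])
```

A real Flux.gradient built this way would also need to handle multiple arguments and non-array leaves, which this sketch does not attempt.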

But I suspect a dedicated doc page on using Enzyme with Flux would be easier to get through quickly.

wsmoses commented 3 months ago

Sure, I think docs would be a great first step. I don't really know how to use Flux or where that would go best, so I'll leave that to you.

At the same time, if we're already doing API design: for training, it would be nice not to have to constantly reallocate the gradient buffer (with make_zero). I don't know whether you have an in-place zeroing function for models, but one would be highly beneficial here.

CarloLucibello commented 3 months ago

> it would be nice to not have to constantly reallocate the gradient buffer

I edited the code in your post to zero the gradient in place. A slight problem in make_zero! is that it sets to zero the arrays but not the scalar fields, so those are going to be accumulated. That can be fixed later, and in principle it is not even a problem since scalars are not updated by the optimizer.
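Why the buffer has to be zeroed at all: Enzyme's reverse mode adds into the shadow of a Duplicated argument instead of overwriting it. A minimal illustration with a plain vector rather than a Flux model, assuming a recent Enzyme.jl:

```julia
using Enzyme

f(x) = sum(abs2, x)          # ∂f/∂x = 2x
x  = [1.0, 2.0]
dx = zero(x)                 # gradient buffer (the shadow of x)

Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
@show dx                     # [2.0, 4.0]

# A second call without resetting the buffer adds to it instead of overwriting it.
Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
@show dx                     # [4.0, 8.0]

# Resetting in place (what make_zero! does for every array in the model) avoids this.
fill!(dx, 0)
Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
@show dx                     # [2.0, 4.0]
```

Any field that make_zero! skips behaves like the second call above: its shadow keeps growing from one training step to the next.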

CarloLucibello commented 3 months ago

On GPU I get the following error:

┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:59
ERROR: Enzyme compilation failed due to illegal type analysis.
Current scope:
; Function Attrs: mustprogress willreturn
[... LLVM IR of @preprocess_julia_fill__33038 omitted ...]
Type analysis state:
[... omitted ...]
Illegal updateAnalysis prev:{[-1]:Integer} new: {[-1]:Float@float}
val: %bitcast_coercion = bitcast float %1 to i32, !dbg !603
origin= %bitcast_coercion = bitcast float %1 to i32, !dbg !603
MethodInstance for fill!(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Float32)

Caused by:
Stacktrace:
 [1] reinterpret
   @ ./essentials.jl:581
 [2] fill!
   @ ~/.julia/packages/CUDA/jdJ7Z/src/array.jl:829

Stacktrace:
  [1] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:1690
  [2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/2FwRI/src/api.jl:154
  [3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:3177
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5070
  [5] codegen
    @ ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:4477 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5755
  [7] _thunk
    @ ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5755 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5793 [inlined]
  [9] (::Enzyme.Compiler.var"#554#555"{…})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5859
 [10] JuliaContext(f::Enzyme.Compiler.var"#554#555"{…}; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:52
 [11] JuliaContext(f::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:42
 [12] #s2027#553
    @ ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5811 [inlined]
 [13] 
    @ Enzyme.Compiler ./none:0
 [14] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [15] autodiff
    @ ~/.julia/packages/Enzyme/2FwRI/src/Enzyme.jl:286 [inlined]

wsmoses commented 3 months ago

You'll need https://github.com/JuliaGPU/CUDA.jl/pull/2371 and then https://github.com/JuliaPackaging/Yggdrasil/pull/8666. It then hits a cublasscal issue, which I stopped investigating to go get dinner.

mcabbott commented 3 months ago

> I think the basic interface needed is a nice gradient function.

Enzyme's own gradient should now do this, as make_zero understands nested structures:

julia> sh = [1f0, 2f0]; nt = (a=sh, b=sh, c=copy(sh));

julia> Enzyme.gradient(Reverse, x -> sum(map(sum, x)), nt)
(a = Float32[2.0, 2.0], b = Float32[2.0, 2.0], c = Float32[1.0, 1.0])

(jl_o1ZBlk) pkg> st Enzyme
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_o1ZBlk/Project.toml`
  [7da242da] Enzyme v0.12.4

The example above doesn't work for me, but I believe function gradient_ez(f, x...) can be deleted, leaving just this:

for epoch in 1:epochs
    g = Enzyme.gradient(Reverse, m -> loss(m, X, y), model) # Enzyme gradient
    # g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
    Flux.update!(opt_state, model, g)
    report(epoch)
end

> A slight problem in make_zero! is that it sets to zero the arrays but not the scalar fields, so those are going to be accumulated. That can be fixed later, and in principle it is not even a problem since scalars are not updated by the optimizer.

Right. For those coming from Zygote, it's slightly odd that the gradient contains numbers for non-differentiable things. But I believe Optimisers.jl's idea of what parameters can be updated is narrow enough that it will only use true gradient numbers from Enzyme.jl.

wsmoses commented 3 months ago

This should be resolved by https://github.com/FluxML/Flux.jl/pull/2446

Like I say in that PR

""" I have no opinions on the design/API and I will give this PR to you all to make it however you feel (and I will go back to staring at CUDA).

I will note that perf atm is unclear and is worth investigating. However, before we do that, having a good way to run/test things is critical, hence this PR. """

wsmoses commented 3 months ago

Edit: I accidentally reran on CPU; please ignore the output below.

CUDA works on the simple example now. It does require either CUDA#master on already merged branches or, hopefully, a backport release from CUDA.jl via https://github.com/JuliaGPU/CUDA.jl/pull/2375, as well as an Enzyme_jll bump.

wmoses@beast:~/git/Flux.jl ((HEAD detached at origin/master)) $ cat orig.jl 
using CUDA # for GPU training
using Flux, Enzyme
using Random, Statistics

_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
make_zero(model) = fmap(_make_zero, model)

function gradient_ez(f, x...)
    args = []
    for x in x
        if x isa Number
            push!(args, Active(x))
        else
            push!(args, Duplicated(x, make_zero(x)))
        end
    end
    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
    return g
end

batch_size = 128
feature_size = 784
num_classes = 10
epochs = 100
# device = Flux.cpu        # CPU training
device = Flux.gpu      # GPU training

X = randn(Float32, feature_size, batch_size) |> device
y = Flux.onehotbatch(rand(1:num_classes, batch_size), 1:num_classes) |> device

model = Chain(Dense(feature_size => 32, relu),
              Dense(32, num_classes)) |> device

opt_state = Flux.setup(Adam(1e-3), model)

loss(model, x, y) = Flux.logitcrossentropy(model(x), y)
accuracy(model, x, y) = mean(Flux.onecold(model(x)) .== Flux.onecold(y))

function report(epoch)
    @info "Epoch: $epoch" loss=loss(model, X, y) accuracy=accuracy(model, X, y)
end

report(0)
for epoch in 1:epochs
    g = gradient_ez(model -> loss(model, X, y), model)[1]     # Enzyme gradient
    # g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
    Flux.update!(opt_state, model, g)
    report(epoch)
end
wmoses@beast:~/git/Flux.jl ((HEAD detached at origin/master)) $ ~/git/Enzyme.jl/julia-1.10.2/bin/julia --project orig.jl 
┌ Warning: Package cuDNN not found in current path.
│ - Run `import Pkg; Pkg.add("cuDNN")` to install the cuDNN package, then restart julia.
│ - If cuDNN is not installed, some Flux functionalities will not be available when running on the GPU.
└ @ FluxCUDAExt ~/git/Flux.jl/ext/FluxCUDAExt/FluxCUDAExt.jl:57
┌ Info: Epoch: 0
│   loss = 2.7904227f0
└   accuracy = 0.125
┌ Info: Epoch: 1
│   loss = 2.5142982f0
└   accuracy = 0.15625
┌ Info: Epoch: 2
│   loss = 2.2610319f0
└   accuracy = 0.203125
┌ Info: Epoch: 3
│   loss = 2.029134f0
└   accuracy = 0.28125
┌ Info: Epoch: 4
│   loss = 1.8172197f0
└   accuracy = 0.3515625
┌ Info: Epoch: 5
│   loss = 1.6268556f0
└   accuracy = 0.4375
┌ Info: Epoch: 6
│   loss = 1.4554112f0
└   accuracy = 0.546875
┌ Info: Epoch: 7
│   loss = 1.3014916f0
└   accuracy = 0.6640625
┌ Info: Epoch: 8
│   loss = 1.163165f0
└   accuracy = 0.7890625
┌ Info: Epoch: 9
│   loss = 1.0413302f0
└   accuracy = 0.8515625
┌ Info: Epoch: 10
│   loss = 0.93555194f0
└   accuracy = 0.8515625
┌ Info: Epoch: 11
│   loss = 0.84206563f0
└   accuracy = 0.8828125
┌ Info: Epoch: 12
│   loss = 0.7600569f0
└   accuracy = 0.90625
┌ Info: Epoch: 13
│   loss = 0.6874082f0
└   accuracy = 0.921875
┌ Info: Epoch: 14
│   loss = 0.6230737f0
└   accuracy = 0.9296875
┌ Info: Epoch: 15
│   loss = 0.5663827f0
└   accuracy = 0.9609375
┌ Info: Epoch: 16
│   loss = 0.5165455f0
└   accuracy = 0.96875
┌ Info: Epoch: 17
│   loss = 0.4719535f0
└   accuracy = 0.96875
┌ Info: Epoch: 18
│   loss = 0.4319139f0
└   accuracy = 0.9765625
┌ Info: Epoch: 19
│   loss = 0.39577293f0
└   accuracy = 0.984375
┌ Info: Epoch: 20
│   loss = 0.36347917f0
└   accuracy = 0.984375
┌ Info: Epoch: 21
│   loss = 0.33449084f0
└   accuracy = 0.9921875
┌ Info: Epoch: 22
│   loss = 0.30846184f0
└   accuracy = 0.9921875
┌ Info: Epoch: 23
│   loss = 0.28476223f0
└   accuracy = 0.9921875
┌ Info: Epoch: 24
│   loss = 0.26318714f0
└   accuracy = 1.0
┌ Info: Epoch: 25
│   loss = 0.24353352f0
└   accuracy = 1.0
┌ Info: Epoch: 26
│   loss = 0.22557218f0
└   accuracy = 1.0
┌ Info: Epoch: 27
│   loss = 0.20921068f0
└   accuracy = 1.0
┌ Info: Epoch: 28
│   loss = 0.19429381f0
└   accuracy = 1.0
┌ Info: Epoch: 29
│   loss = 0.18054952f0
└   accuracy = 1.0
┌ Info: Epoch: 30
│   loss = 0.16796987f0
└   accuracy = 1.0
┌ Info: Epoch: 31
│   loss = 0.1563463f0
└   accuracy = 1.0
┌ Info: Epoch: 32
│   loss = 0.14567412f0
└   accuracy = 1.0
┌ Info: Epoch: 33
│   loss = 0.13588753f0
└   accuracy = 1.0
┌ Info: Epoch: 34
│   loss = 0.12687433f0
└   accuracy = 1.0
┌ Info: Epoch: 35
│   loss = 0.11857266f0
└   accuracy = 1.0
┌ Info: Epoch: 36
│   loss = 0.11093213f0
└   accuracy = 1.0
┌ Info: Epoch: 37
│   loss = 0.103871785f0
└   accuracy = 1.0
┌ Info: Epoch: 38
│   loss = 0.09736837f0
└   accuracy = 1.0
┌ Info: Epoch: 39
│   loss = 0.09138645f0
└   accuracy = 1.0
┌ Info: Epoch: 40
│   loss = 0.08586908f0
└   accuracy = 1.0
┌ Info: Epoch: 41
│   loss = 0.080786735f0
└   accuracy = 1.0
┌ Info: Epoch: 42
│   loss = 0.07610354f0
└   accuracy = 1.0
┌ Info: Epoch: 43
│   loss = 0.07179588f0
└   accuracy = 1.0
┌ Info: Epoch: 44
│   loss = 0.06783663f0
└   accuracy = 1.0
┌ Info: Epoch: 45
│   loss = 0.06419177f0
└   accuracy = 1.0
┌ Info: Epoch: 46
│   loss = 0.060845155f0
└   accuracy = 1.0
┌ Info: Epoch: 47
│   loss = 0.057761367f0
└   accuracy = 1.0
┌ Info: Epoch: 48
│   loss = 0.0549154f0
└   accuracy = 1.0
┌ Info: Epoch: 49
│   loss = 0.05228231f0
└   accuracy = 1.0
┌ Info: Epoch: 50
│   loss = 0.049845647f0
└   accuracy = 1.0
┌ Info: Epoch: 51
│   loss = 0.047589153f0
└   accuracy = 1.0
┌ Info: Epoch: 52
│   loss = 0.045498513f0
└   accuracy = 1.0
┌ Info: Epoch: 53
│   loss = 0.04355742f0
└   accuracy = 1.0
┌ Info: Epoch: 54
│   loss = 0.04175187f0
└   accuracy = 1.0
┌ Info: Epoch: 55
│   loss = 0.04007356f0
└   accuracy = 1.0
┌ Info: Epoch: 56
│   loss = 0.038507923f0
└   accuracy = 1.0
┌ Info: Epoch: 57
│   loss = 0.037045095f0
└   accuracy = 1.0
┌ Info: Epoch: 58
│   loss = 0.035674226f0
└   accuracy = 1.0
┌ Info: Epoch: 59
│   loss = 0.034392048f0
└   accuracy = 1.0
┌ Info: Epoch: 60
│   loss = 0.033194654f0
└   accuracy = 1.0
┌ Info: Epoch: 61
│   loss = 0.032058075f0
└   accuracy = 1.0
┌ Info: Epoch: 62
│   loss = 0.030996136f0
└   accuracy = 1.0
┌ Info: Epoch: 63
│   loss = 0.02999451f0
└   accuracy = 1.0
┌ Info: Epoch: 64
│   loss = 0.029050402f0
└   accuracy = 1.0
┌ Info: Epoch: 65
│   loss = 0.02815985f0
└   accuracy = 1.0
┌ Info: Epoch: 66
│   loss = 0.027319008f0
└   accuracy = 1.0
┌ Info: Epoch: 67
│   loss = 0.02652272f0
└   accuracy = 1.0
┌ Info: Epoch: 68
│   loss = 0.025767544f0
└   accuracy = 1.0
┌ Info: Epoch: 69
│   loss = 0.025051065f0
└   accuracy = 1.0
┌ Info: Epoch: 70
│   loss = 0.024369944f0
└   accuracy = 1.0
┌ Info: Epoch: 71
│   loss = 0.023721226f0
└   accuracy = 1.0
┌ Info: Epoch: 72
│   loss = 0.023103705f0
└   accuracy = 1.0
┌ Info: Epoch: 73
│   loss = 0.022514593f0
└   accuracy = 1.0
┌ Info: Epoch: 74
│   loss = 0.021952922f0
└   accuracy = 1.0
┌ Info: Epoch: 75
│   loss = 0.021417053f0
└   accuracy = 1.0
┌ Info: Epoch: 76
│   loss = 0.020906389f0
└   accuracy = 1.0
┌ Info: Epoch: 77
│   loss = 0.0204159f0
└   accuracy = 1.0
┌ Info: Epoch: 78
│   loss = 0.01994732f0
└   accuracy = 1.0
┌ Info: Epoch: 79
│   loss = 0.01949887f0
└   accuracy = 1.0
┌ Info: Epoch: 80
│   loss = 0.01906871f0
└   accuracy = 1.0
┌ Info: Epoch: 81
│   loss = 0.018656129f0
└   accuracy = 1.0
┌ Info: Epoch: 82
│   loss = 0.018260362f0
└   accuracy = 1.0
┌ Info: Epoch: 83
│   loss = 0.017879806f0
└   accuracy = 1.0
┌ Info: Epoch: 84
│   loss = 0.017513612f0
└   accuracy = 1.0
┌ Info: Epoch: 85
│   loss = 0.017161498f0
└   accuracy = 1.0
┌ Info: Epoch: 86
│   loss = 0.01682241f0
└   accuracy = 1.0
┌ Info: Epoch: 87
│   loss = 0.016495718f0
└   accuracy = 1.0
┌ Info: Epoch: 88
│   loss = 0.016181245f0
└   accuracy = 1.0
┌ Info: Epoch: 89
│   loss = 0.015877243f0
└   accuracy = 1.0
┌ Info: Epoch: 90
│   loss = 0.0155781405f0
└   accuracy = 1.0
┌ Info: Epoch: 91
│   loss = 0.01528422f0
└   accuracy = 1.0
┌ Info: Epoch: 92
│   loss = 0.014997441f0
└   accuracy = 1.0
┌ Info: Epoch: 93
│   loss = 0.014718127f0
└   accuracy = 1.0
┌ Info: Epoch: 94
│   loss = 0.014446221f0
└   accuracy = 1.0
┌ Info: Epoch: 95
│   loss = 0.014181806f0
└   accuracy = 1.0
┌ Info: Epoch: 96
│   loss = 0.013925277f0
└   accuracy = 1.0
┌ Info: Epoch: 97
│   loss = 0.013677116f0
└   accuracy = 1.0
┌ Info: Epoch: 98
│   loss = 0.013437184f0
└   accuracy = 1.0
┌ Info: Epoch: 99
│   loss = 0.013204632f0
└   accuracy = 1.0
┌ Info: Epoch: 100
│   loss = 0.012979296f0
└   accuracy = 1.0

mashu commented 3 months ago

@CarloLucibello this gradient_ez is very useful. Thanks! Would it also be possible to have an option to run Enzyme from within Zygote? Or, similar to the gradient_ez example, could you show how to add a Zygote.@adjoint so that one custom Flux layer is differentiated with Enzyme while the rest still uses Zygote?

I am thinking of some way we could transition smoothly without switching completely to one or the other.
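One possible way to mix the two, sketched under assumptions: instead of Zygote.@adjoint, define a ChainRulesCore.rrule (which Zygote picks up) whose pullback calls Enzyme.autodiff. Here mylayer is a hypothetical, parameter-free stand-in for a custom layer; a real layer with trainable fields would also need to return a tangent for those fields, which this sketch does not handle.

```julia
using ChainRulesCore, Enzyme, Zygote

# Hypothetical stand-in for a custom layer's forward pass (no trainable parameters).
mylayer(x) = sum(abs2, x)

# Zygote uses this rrule for `mylayer`; the pullback delegates to Enzyme.
function ChainRulesCore.rrule(::typeof(mylayer), x::AbstractArray{<:Real})
    y = mylayer(x)
    function mylayer_pullback(ȳ)
        dx = zero(x)    # fresh zeroed shadow, since Enzyme accumulates into it
        Enzyme.autodiff(Enzyme.Reverse, mylayer, Enzyme.Active, Enzyme.Duplicated(x, dx))
        # Chain rule: scale the local gradient by the incoming cotangent.
        return ChainRulesCore.NoTangent(), ChainRulesCore.unthunk(ȳ) .* dx
    end
    return y, mylayer_pullback
end

# Everything outside `mylayer` is still differentiated by Zygote.
g = Zygote.gradient(x -> 2 * mylayer(x) + sum(x), [1.0, 2.0])[1]
# g ≈ 4 .* x .+ 1 == [5.0, 9.0]
```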

gdalle commented 2 months ago

The most recent attempt was supposed to be DI.jl, but the choice to focus on arrays and single inputs means we can't use it.

@darsnack I'd actually love to revisit the dream of DI + Flux one of these days.

To me the best option would be a Flux.gradient (and Flux.withgradient) that uses ADTypes.jl (only to avoid further fragmentation). Alternatively, a small package that wraps Enzyme.autodiff + make_zero in a Zygote-like interface (similar to what's above).
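A minimal sketch of the ADTypes idea, assuming a single array argument: the backend object (AutoZygote() or AutoEnzyme() from ADTypes.jl) selects the implementation by dispatch. my_gradient and sqloss are hypothetical names, not existing Flux functions, and the Enzyme branch follows the make_zero + Duplicated pattern of gradient_ez above. Nested Flux models would need more care than this single-array version.

```julia
using ADTypes, Enzyme, Zygote

# Hypothetical unified entry point; the backend is chosen by an ADTypes object.
my_gradient(f, ::AutoZygote, x) = Zygote.gradient(f, x)[1]

function my_gradient(f, ::AutoEnzyme, x)
    dx = Enzyme.make_zero(x)    # fresh zeroed shadow with the same structure as x
    Enzyme.autodiff(Enzyme.Reverse, Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, dx))
    return dx
end

sqloss(x) = sum(abs2, x)        # toy loss (assumption)
x = [1.0, 2.0, 3.0]

my_gradient(sqloss, AutoZygote(), x)   # [2.0, 4.0, 6.0]
my_gradient(sqloss, AutoEnzyme(), x)   # [2.0, 4.0, 6.0]
```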

Why not create a package named DifferentiationInterfaceForFlux or something, which relies on DI but tests compatibility with Flux layers and makes it part of its API? In other words, if I change something in DI that removes compatibility with Flux layers, the glue package could still be frozen to its current version until it gets resolved.