FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

DifferentiationInterface testing #2469

Open gdalle opened 1 month ago

gdalle commented 1 month ago

Hi there! I'm heading towards multi-argument and non-array support in DI, and I'd like to start testing Lux layers. For this I would need two things:

Do you think you could help me out?

CarloLucibello commented 1 month ago

You can take a look at the tests we added for Enzyme

https://github.com/FluxML/Flux.jl/blob/master/test/ext_enzyme/enzyme.jl

e.g. begin with

using Flux

x = rand(Float32, 2, 1)
model = Chain(Dense(2=>3, relu), Dense(3=>2))
g = gradient(model -> sum(model(x)), model)[1]

We impose few restrictions on gradients: they can be nested structs or named structs.
For instance, the ones returned by Enzyme and the ones returned by Zygote are compared by

function test_grad(g1, g2; broken=false)
    fmap_with_path(g1, g2) do kp, x, y
        :state ∈ kp && return # ignore RNN and LSTM state
        if x isa AbstractArray{<:Number}
            # @show kp
            @test x ≈ y rtol=1e-2 atol=1e-6 broken=broken
        end
        return x
    end
end

where fmap_with_path is defined in Functors.jl. So what we need is a gradient for each numerical array leaf in the original object. These leaves should be reachable through the same "path", e.g. g.layers[1].weight.
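
To make the "same path" requirement concrete, here is a minimal sketch that reuses the model from the snippet above. It assumes the default Zygote-backed `gradient` from Flux, and the size checks are only illustrative; they are not part of the Flux test suite.

```julia
using Flux

x = rand(Float32, 2, 1)
model = Chain(Dense(2 => 3, relu), Dense(3 => 2))
g = gradient(m -> sum(m(x)), model)[1]

# The gradient mirrors the nesting of the model, so the path that reaches
# a parameter in `model` reaches its gradient in `g`.
W1  = model.layers[1].weight   # 3×2 parameter array
dW1 = g.layers[1].weight       # gradient array found at the same path

# Each numerical array leaf should have a gradient of matching shape.
same_shape_W1 = size(dW1) == size(W1)
same_shape_b2 = size(g.layers[2].bias) == size(model.layers[2].bias)
```

A backend that returns its gradient in a different container would need to be converted to this nested layout before a comparison like test_grad can walk both objects together.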

gdalle commented 1 month ago

I'm having issues when comparing the true gradients with finite differences. Depending on the random seed I get unpredictable failures. Is that a problem in the Flux test suite as well @CarloLucibello? I didn't find a way to pass an rng to the network constructors; do I have to seed! the global rng? For now I have increased atol and rtol but it's hard to know the right threshold.

CarloLucibello commented 1 month ago

We do Random.seed!(0) in runtests.jl and we don't see test failures, but I would have expected the tests to be robust. Can you identify the frail ones? Maybe the ones with RNNs?
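
As a rough illustration of seeding the global RNG before a finite-difference comparison, here is a small sketch that uses only the Julia standard library. The helper `fd_gradient` and the test function `f` are hypothetical and belong to neither Flux nor DifferentiationInterface; the tolerances are the ones used by test_grad above.

```julia
using Random

# Hypothetical helper: central finite differences, one coordinate at a time.
function fd_gradient(f, x::AbstractVector; h=1e-3)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

f(x) = sum(abs2, x) / 2        # the analytical gradient of f is x itself

Random.seed!(0)                # seed the global RNG so the input is reproducible
x = rand(Float32, 4)
Random.seed!(0)
x_again = rand(Float32, 4)     # same seed, same draw

g_fd = fd_gradient(f, x)
matches = isapprox(g_fd, x; rtol=1e-2, atol=1e-6)
```

With Float32 inputs, round-off in the function values is amplified by the division by `2h`, which is one plausible reason a tight tolerance can fail for some seeds. The thread does not confirm that this is the cause of the failures reported here.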

gdalle commented 1 month ago

I'll try! Which backends should I aim to test? Zygote, Enzyme and Tracker?

CarloLucibello commented 1 month ago

We don't support Tracker anymore. Primarily Zygote, and experimentally Enzyme.

gdalle commented 1 month ago

I added a random seed in https://github.com/gdalle/DifferentiationInterface.jl/pull/371, tests seem to pass for Zygote with the same tolerances as you. I'll notify you if I see random failures further down the road.

Any idea why Enzyme fails on two scenarios only (see the PR for details)?