LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
MIT License

DifferentiationInterface testing #769

Open gdalle opened 1 month ago

gdalle commented 1 month ago

Hi @avik-pal! I'm heading towards multi-argument and non-array support in DI, and I'd like to start testing Lux layers. For this I would need two things:

Do you think you could help me out?

avik-pal commented 1 month ago
gdalle commented 1 month ago

Thanks! In the file you linked, the source of truth seems to be Zygote? What would you use to validate the Zygote gradients themselves?

avik-pal commented 1 month ago

Currently I compute the gradients with Zygote and then test them against other backends, depending on the device:

  1. For CPU
    1. Tracker
    2. ReverseDiff
    3. ForwardDiff (if the array sizes are < 100)
    4. FiniteDifferences
    5. Enzyme (currently only tested in that file, but test coverage is being expanded here)
  2. For GPU
    1. Tracker
    2. ForwardDiff in certain situations depending on the problem

Tracker and Zygote hit very different code paths in LuxLib (Zygote is the optimized one with often handwritten rules). In case of conflict/mismatch, the general assumption is that Tracker (on GPU) or FiniteDifferences (on CPU) is the source of truth.
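A minimal sketch of this kind of cross-check, under stated assumptions: the toy `loss`, the hand-written `analytic_grad` (standing in for the result of an optimized backend), and the helper `central_fd_grad` are all hypothetical and are not taken from Lux or LuxTestUtils. It only illustrates comparing a candidate gradient against a finite-difference reference within tolerances.

```julia
using LinearAlgebra  # provides isapprox for arrays

# Hypothetical stand-in for a model loss; not a Lux layer.
loss(w) = sum(abs2, tanh.(w))

# Hand-written gradient: d/dw tanh(w)^2 = 2 tanh(w) (1 - tanh(w)^2).
analytic_grad(w) = 2 .* tanh.(w) .* (1 .- tanh.(w) .^ 2)

# Central finite differences, perturbing one coordinate at a time.
function central_fd_grad(f, w; h=cbrt(eps(eltype(w))))
    g = similar(w)
    for i in eachindex(w)
        wp = copy(w); wp[i] += h
        wm = copy(w); wm[i] -= h
        g[i] = (f(wp) - f(wm)) / (2h)
    end
    return g
end

w = collect(range(-1.0, 1.0; length=5))
g_ref = central_fd_grad(loss, w)   # reference ("source of truth")
g_ad = analytic_grad(w)            # candidate gradient under test

# The tolerances here are assumptions chosen for this toy problem.
agrees = isapprox(g_ad, g_ref; atol=1e-6, rtol=1e-6)
```

In a real test, `g_ad` would come from the backend under test and `g_ref` from the backend trusted for that device.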

avik-pal commented 1 month ago

I have been meaning to try out FiniteDiff to validate the GPU gradients but haven't had the time to set it up.

gdalle commented 1 month ago

Hi @avik-pal, there is now a first version of Lux tests with ComponentArrays.jl encoding. A few points remain unclear to me:

avik-pal commented 1 month ago

> When you say the ground truth is finite differences, do you mean FiniteDifferences.jl or FiniteDiff.jl? I only found the latter in LuxTestUtils.jl

I meant FiniteDifferences.jl originally. But you are looking at the new release, where I migrated to FiniteDiff 😅

> How do you handle the flattening and unflattening of ps for finite differences? The only code I found is this one; should I copy it inside DITest?

That +

> Because of DI's single-argument limitation, at the moment I close over (model, x, st) for the loss. Should I deepcopy(st) between two calls, in order to avoid state evolution?

No, states cannot be mutated (it is a bug in Lux if that happens for any of the layers). One particular case to be careful about is TaskLocalRNG. We do print a warning, but it is easy to miss. Since we cannot copy a TaskLocalRNG, it is impossible to guarantee the same results across multiple calls there. The recommended solution is to use Xoshiro in regular code and StableRNG in test code.
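A minimal sketch of why closing over the state is safe when it holds an explicitly seeded rng. `toy_layer` is a hypothetical stand-in for a stochastic layer and does not reproduce Lux's actual implementation; the only assumption carried over from above is that the layer never mutates `st`.

```julia
using Random

# Hypothetical stochastic layer: draws a mask from the rng stored in its state
# and returns the output without mutating st.
function toy_layer(x, st)
    rng = copy(st.rng)                  # work on a copy so st.rng is left untouched
    mask = rand(rng, length(x)) .> 0.5
    return x .* mask
end

st = (rng = Xoshiro(0),)                # explicitly seeded rng
x = ones(8)

# Closing over (x, st): repeated calls see the same state and give the same result.
f = () -> sum(toy_layer(x, st))
same = f() == f()
```

Because each call starts from an identical copy of the rng, no `deepcopy(st)` is needed between calls in this sketch.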

gdalle commented 1 month ago

How do you choose FiniteDiff parameters like the epsilon? Do you keep the package defaults? I'm looking at test failures that I think are due to numerical errors in finite differencing, but I don't know whether I should use even higher atol and rtol (I'm using FiniteDifferences.jl at the time of writing).

Are there layers that contain an rng? I don't see you handling this specifically in the Enzyme tests you pointed me to.

avik-pal commented 1 month ago

> How do you choose FiniteDiff parameters like the epsilon? Keep the package defaults?

Yes, keep the defaults. Normalization layers are especially tricky to test with finite differencing, for the reason you cited. In those cases, I rely on comparing Zygote with one of the other AD backends. For example, Tracker hits the generic code paths while Zygote hits the optimized code paths with custom rrules, so the assumption is that Tracker, without custom rules, gets the gradient right. Alternatively, for smaller systems, comparing against ForwardDiff is also an option.

> Are there layers that contain an rng? I don't see you handling this specifically in the Enzyme tests you pointed me to.

The rng passed to Lux.setup(rng, model) is cached with some ugly tricks to make reproducibility work. Make sure this rng is not a TaskLocalRNG and you are good.

gdalle commented 1 month ago

Have you ever encountered this error with Tracker + ComponentArrays? The tests pass with Zygote, so I'm trying to add more backends.

avik-pal commented 1 month ago

Yes. You should call Tracker.param on the ComponentArray directly, instead of going NamedTuple --> Tracker.param --> ComponentArray.
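A minimal sketch of that ordering, assuming hypothetical parameters `ps` (in practice they would come from Lux.setup(rng, model)). It only shows the construction step and makes no claim about which downstream operations are supported.

```julia
using ComponentArrays, Tracker

# Hypothetical parameters; placeholder values for illustration only.
ps = (weight = [1.0 2.0; 3.0 4.0], bias = [0.5, -0.5])

# Flatten the NamedTuple first ...
ps_ca = ComponentArray(ps)

# ... then track the flat array as a whole, rather than tracking each leaf
# of the NamedTuple and flattening afterwards.
ps_tracked = Tracker.param(ps_ca)
```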