gdalle / DifferentiationInterface.jl

An interface to various automatic differentiation backends in Julia.
https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface
MIT License

Testing NNLib / Lux / Flux #105

Open gdalle opened 6 months ago

gdalle commented 6 months ago

Lower-hanging fruit: NNlib.jl, because it has fewer weird structs and mostly works with arrays.

Cross-referencing:

gdalle commented 6 months ago

Slow

Fast

avik-pal commented 6 months ago

If we want to be adventurous, you can change https://github.com/LuxDL/LuxTestUtils.jl and all downstream CPU tests in Lux will be triggered (and we just need to copy one of the Buildkite files from LuxLib to trigger the CUDA + AMDGPU tests).

gdalle commented 6 months ago

Don't tempt me, Avik

avik-pal commented 6 months ago

On a serious note though, I had to write it to deal mostly with arrays, or at least to convert structures to arrays, see https://github.com/LuxDL/LuxTestUtils.jl/blob/143a51f0d2fb4cbc75ea583c706ff5194be103d2/src/LuxTestUtils.jl#L387-L398, so that could be helpful for writing your test suite. (But it is also terribly inefficient and only tests correctness, and you should definitely not combine @test_gradients with @jet.)

gdalle commented 6 months ago

Are the tests of LuxTestUtils already interesting to run locally, or should we wait for the Downstream CI every time?

avik-pal commented 6 months ago

No, the tests there do practically nothing; it all happens via the downstream CI.

avik-pal commented 6 months ago

But the Lux test suite doesn't take long: about 10 mins on a nicer machine (like the Buildkite ones), while the GitHub Actions runners take longer, around 30 mins.

If you want to test locally, set the RETESTITEMS_NWORKERS environment variable and it will be much faster.

gdalle commented 6 months ago

So the workflow is to:

  1. fork LuxTestUtils.jl and Lux.jl
  2. put my own gradient callers in LuxTestUtils.jl
  3. dev LuxTestUtils.jl into the test environment of Lux.jl
  4. test Lux.jl

right?

avik-pal commented 6 months ago

If you want to test locally, yes.
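For reference, steps 3 and 4 of that workflow could look roughly like the sketch below. The two paths are hypothetical placeholders for wherever the forks are checked out, the worker count is an arbitrary example value, and the sketch assumes Lux.jl keeps its test dependencies in test/Project.toml. Whether Pkg.test picks up the dev'ed copy this way can depend on the Pkg version and the repository layout, so treat it as a starting point rather than a recipe.

```julia
using Pkg

# Hypothetical local paths to the two forks; adjust to your own checkouts.
lux_path = joinpath(homedir(), "forks", "Lux.jl")
testutils_path = joinpath(homedir(), "forks", "LuxTestUtils.jl")

# Optional: run test items on several workers (the value 4 is just an example).
ENV["RETESTITEMS_NWORKERS"] = "4"

# Step 3: dev the forked LuxTestUtils.jl into the test environment of Lux.jl.
# Assumption: the test dependencies live in test/Project.toml.
Pkg.activate(joinpath(lux_path, "test"))
Pkg.develop(path=testutils_path)

# Step 4: run the Lux.jl test suite from the package environment.
Pkg.activate(lux_path)
Pkg.test()
```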

gdalle commented 6 months ago

Any suggestions on dealing with multiple arguments? Is wrapping them in a ComponentVector always gonna work, or are there non-array structs in the mix?

gdalle commented 6 months ago

DifferentiationInterface only accepts a single input

gdalle commented 6 months ago

I'm thinking https://docs.julialang.org/en/v1/base/base/#Base.splat on a ComponentVector
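To make the idea concrete, here is a minimal sketch of turning a multi-argument function into a single-input one. It uses only Base and a plain flat vector instead of a ComponentVector; `flatten_args`, `unflatten` and the toy function `f` are hypothetical names for illustration, not part of DifferentiationInterface or ComponentArrays.

```julia
# Pack several arrays into one flat vector and return a closure that unpacks it.
function flatten_args(args::AbstractArray...)
    lens = collect(map(length, args))   # number of elements in each argument
    stops = cumsum(lens)                # last flat index of each argument
    starts = stops .- lens .+ 1         # first flat index of each argument
    sizes = map(size, args)
    flat = vcat(map(vec, args)...)
    # Rebuild a tuple of arrays with the original shapes from a flat vector.
    unflatten(v) = ntuple(i -> reshape(v[starts[i]:stops[i]], sizes[i]), length(args))
    return flat, unflatten
end

# A toy function of several array arguments.
f(x, w, b) = sum(abs2, w * x .+ b)

x, w, b = [1.0, 2.0], [1.0 0.0; 0.0 1.0], [0.5, -0.5]
flat, unflatten = flatten_args(x, w, b)

# Base.splat(f) turns f(x, w, b) into a function of one tuple (x, w, b),
# so composing it with unflatten gives a function of a single flat vector.
f_single = Base.splat(f) ∘ unflatten

@assert f_single(flat) == f(x, w, b)
```

`f_single` could then be handed to any routine that differentiates a function of one array, and the resulting gradient unpacked with the same `unflatten`.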

avik-pal commented 6 months ago

Based on how the tests are written, for multiple arguments I assume any non-array is non-differentiable (this is a testing package, so I can assume that), so these get filtered out in https://github.com/LuxDL/LuxTestUtils.jl/blob/143a51f0d2fb4cbc75ea583c706ff5194be103d2/src/LuxTestUtils.jl#L357-L383. After that there are two possibilities: 1) the backend supports multiple arguments, in which case it just forwards them; 2) in all other cases it uses a ComponentArray and creates a closure that unflattens the ComponentArray to provide the correct arguments.
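A rough sketch of the filtering step, under the same assumption that only array arguments are differentiable. This is not the LuxTestUtils code; `close_over_nonarrays` and the toy function `g` are hypothetical names, and the ComponentArray flattening of case 2 is left out.

```julia
# Split arguments into array (differentiable) and non-array (fixed) ones.
# Returns a closure that takes only the arrays, plus the arrays themselves.
function close_over_nonarrays(f, args...)
    is_array = map(a -> a isa AbstractArray, args)
    # positions[i] is the index of args[i] among the array arguments
    positions = cumsum(collect(is_array))
    arrays = Tuple(a for (a, keep) in zip(args, is_array) if keep)
    function f_arrays(new_arrays...)
        # Put the new arrays back in their slots; reuse the fixed values elsewhere.
        full = ntuple(i -> is_array[i] ? new_arrays[positions[i]] : args[i], length(args))
        return f(full...)
    end
    return f_arrays, arrays
end

# Toy function with a non-array argument (the activation) in the middle.
g(x, act, w) = sum(act.(w * x))

x = [1.0, 2.0]
w = [1.0 0.0; 0.0 2.0]
g_arrays, arrays = close_over_nonarrays(g, x, abs2, w)

# The closure only sees the arrays; the activation is captured.
@assert g_arrays(arrays...) == g(x, abs2, w)
```

A backend with multi-argument support could be called on `g_arrays` directly (case 1); otherwise `arrays` would still need to be packed into a single container (case 2).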

gdalle commented 6 months ago

I'll see what I can do once our own testing interface stabilizes. Step one would be to replace your gradient calls, but we can actually aim to replace your entire testing macro

gdalle commented 6 months ago

Our function https://gdalle.github.io/DifferentiationInterface.jl/dev/api/#DifferentiationInterfaceTest.test_differentiation does something very similar

avik-pal commented 6 months ago

> I'll see what I can do once our own testing interface stabilizes. Step one would be to replace your gradient calls, but we can actually aim to replace your entire testing macro

Correct. I had planned to replace the API with something like skip = [AutoTracker(), ...] and broken = [AutoReverseDiff(), ...]. But eventually we might use DI.
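As a rough illustration of what a skip / broken keyword API could look like, here is a toy sketch built on the Test standard library. Everything in it is hypothetical: the three backend structs stand in for ADTypes objects such as AutoTracker() or AutoReverseDiff(), and `toy_gradient` and `check_gradients` are made-up helpers, not the planned LuxTestUtils interface.

```julia
using Test

# Hypothetical stand-in backends (not real ADTypes objects).
struct WorkingBackend end
struct SkippedBackend end
struct BrokenBackend end

# Toy gradients of f(x) = sum(abs2, x); BrokenBackend is deliberately wrong.
toy_gradient(::WorkingBackend, x) = 2 .* x
toy_gradient(::SkippedBackend, x) = error("should never be called")
toy_gradient(::BrokenBackend, x) = 3 .* x

# Hypothetical helper: compare each backend's gradient against a reference.
function check_gradients(x, reference; backends, skip=[], broken=[])
    for backend in backends
        if backend in skip
            # Recorded as broken without evaluating the expression.
            @test_skip toy_gradient(backend, x) ≈ reference
        elseif backend in broken
            # Evaluated, and expected to fail.
            @test_broken toy_gradient(backend, x) ≈ reference
        else
            @test toy_gradient(backend, x) ≈ reference
        end
    end
end

x = [1.0, 2.0, 3.0]
results = @testset "toy gradient checks" begin
    check_gradients(x, 2 .* x;
        backends=[WorkingBackend(), SkippedBackend(), BrokenBackend()],
        skip=[SkippedBackend()],
        broken=[BrokenBackend()])
end
```

One upside of `@test_broken` over silently skipping is that it starts reporting an error once the backend gets fixed, which signals that the entry can be removed from the broken list.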