Argonne-National-Laboratory / Checkpointing.jl

Checkpointing for Automatic Differentiation
MIT License

Checkpointing changes derivative output #52

Open · swilliamson7 opened this issue 1 month ago

swilliamson7 commented 1 month ago

Checkpointing seems to be causing me to get an incorrect derivative from Enzyme. In this folder I have the script wrong_derivative.jl, which can be run (it needs the whole repo; it's not minimized), and I get different derivatives depending on whether or not I use Checkpointing. The script first uses Enzyme to compute the derivative of my code and then compares the result to finite differences.
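For readers unfamiliar with this kind of check, here is a minimal sketch of a finite-difference sweep in plain Julia. It is not the code in wrong_derivative.jl: the objective `f`, the helper `fd_sweep`, the point `x0`, and the step sizes are all hypothetical stand-ins, since the real script differentiates a ShallowWaters.jl run. The idea is that forward differences with shrinking steps should level off near the true derivative before roundoff takes over at very small steps, which is the kind of blow-up the last few entries of `diffs` below suggest.

```julia
# Hypothetical scalar objective standing in for the model run.
f(x) = sin(3x) + x^2

# Analytic derivative of the stand-in, playing the role of the AD result.
df(x) = 3cos(3x) + 2x

# Forward finite differences (f(x + h) - f(x)) / h for a range of step sizes h.
function fd_sweep(f, x; steps = [10.0^(-k) for k in 1:13])
    fx = f(x)
    return [(f(x + h) - fx) / h for h in steps]
end

x0 = 0.5
diffs = fd_sweep(f, x0)
exact = df(x0)

# Compare each finite-difference estimate with the reference derivative.
for (k, d) in enumerate(diffs)
    println("h = 1e-$k: fd = $d, error = $(abs(d - exact))")
end
```

A derivative that agrees with the plateau of moderate-step estimates (here roughly h = 1e-4 to 1e-8) passes the check; one that matches none of them, like the Checkpointing result reported below, points to a problem in the differentiated code path.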

To switch Checkpointing on or off, change these lines: https://github.com/swilliamson7/ShallowWaters.jl/blob/3ee9521e379e24c05738edbce05e8a04d00b88de/wrong_derivative_bug/wrong_derivative.jl#L76-L77. It is currently set up not to use Checkpointing, which gives this output:

julia> diffs
13-element Vector{Any}:
  -0.13229150858422145
  -0.1321590203831306
  -0.1320253227665641
  -0.13188559394170624
   ⋮
  -6.129973371571395
  30.958481147536077
 587.9892250959529

julia> enzyme_deriv
-0.13161094f0

Here diffs holds the finite-difference estimates and enzyme_deriv is the same derivative computed with Enzyme. If I switch my code to use Checkpointing instead, I see

julia> diffs
13-element Vector{Any}:
  -0.13229150858422145
  -0.1321590203831306
  -0.1320253227665641
  -0.13188559394170624
   ⋮
  -6.129973371571395
  30.958481147536077
 587.9892250959529

julia> enzyme_deriv
-0.015146092f0

The finite-difference results are unchanged, but the derivative Enzyme computes (-0.015146092f0, versus -0.13161094f0 without Checkpointing) no longer matches them. Any help figuring out what's happening would be much appreciated!

swilliamson7 commented 1 month ago

Pinging @michel2323 about this 🙃