Closed by niklasschmitz 3 years ago
cc #107
Cf. also https://github.com/JuliaMolSim/DFTK.jl/issues/47 (in particular a trick to reduce to 6 instead of 9 DOF, although we probably don't care too much about that).
I'd focus on ForwardDiff for now. We should be able to work around the errors. Michael did the work of making things work for IntervalArithmetic scalar types, so it should hopefully be similar.
Yes, please let me know if you need any help with that. I should easily be able to find some nice examples to get you going in case you need any.
Thanks, starting with ForwardDiff sounds good to me. As I understand it, there are two options: either 1) using GenericLinearAlgebra fallbacks as for IntervalArithmetic (this currently fails with a stack trace), or 2) directly overloading the relevant LinearAlgebra and FFTW calls for ForwardDiff dual number types (in the spirit of https://github.com/JuliaDiff/ForwardDiff.jl/issues/111#issuecomment-638251496 and https://github.com/JuliaDiff/ForwardDiff.jl/pull/495/files).
My guess would be that both are similar in difficulty, but 2) should be preferable for performance. What are your thoughts?
Performance is not the foremost issue, so following IntervalArithmetic sounds good. However, we've had quite a few issues with the generic FFTs (which are buggy and not actively developed), so if https://github.com/JuliaDiff/ForwardDiff.jl/pull/495/files does the job, then great!
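For reference, here is a rough sketch of the underlying trick (written from scratch, not taken from that PR; fft_dual is a hypothetical helper name). Since the DFT is linear, an FFT of Dual numbers can be assembled from ordinary FFTs of the primal values and of each partial component:
using FFTW, ForwardDiff
using ForwardDiff: Dual, value, partials

# Transform the primal values and each partial separately, then reassemble
# Complex{Dual} entries. Illustration only, not the PR's implementation.
function fft_dual(x::AbstractArray{Dual{T,V,N}}) where {T,V,N}
    y  = fft(value.(x))                        # FFT of the primal values
    dy = ntuple(i -> fft(partials.(x, i)), N)  # FFT of each partial component
    out = similar(y, Complex{Dual{T,V,N}})
    for k in eachindex(y)
        re = Dual{T}(real(y[k]), ntuple(i -> real(dy[i][k]), N))
        im = Dual{T}(imag(y[k]), ntuple(i -> imag(dy[i][k]), N))
        out[k] = complex(re, im)
    end
    out
end

# Consistency check: d/da sum(abs2, fft(a .* x)) = 2a * sum(abs2, fft(x))
x = randn(8)
ForwardDiff.derivative(a -> sum(abs2, fft_dual(a .* x)), 2.0) ≈ 4 * sum(abs2, fft(x))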
Regarding the stack trace: only some FFT sizes work for the generic implementation we have, and unless you specify an fft_size explicitly, the PlaneWaveBasis constructor will auto-adjust it. Therefore the effective fft_size used in https://gist.github.com/niklasschmitz/e7030b3f6341bcf56538a87d0b91d5e1#file-stress-genericlinearalgebra-jl-L30 and the one used in https://gist.github.com/niklasschmitz/e7030b3f6341bcf56538a87d0b91d5e1#file-stress-genericlinearalgebra-jl-L16 don't agree. The solution is to explicitly pass a fixed fft_size to both constructors, e.g. just say PlaneWaveBasis(model, Ecut; kgrid=kgrid, fft_size=[32, 32, 32]) in both lines.
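For concreteness, a sketch of that fix, using the constructor call quoted above (model, model_dual, Ecut and kgrid are placeholders for the two constructions in the gist):
# Pin the FFT grid explicitly so both bases end up on the same fft_size
# (placeholder variable names; constructor signature as quoted above).
fft_size = [32, 32, 32]
basis      = PlaneWaveBasis(model,      Ecut; kgrid=kgrid, fft_size=fft_size)
basis_dual = PlaneWaveBasis(model_dual, Ecut; kgrid=kgrid, fft_size=fft_size)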
But I agree with Antoine. The generic FFT stuff only works so-so, so if we can avoid it, that would probably be the better solution long-term.
Tricky bug, nice catch!
Some updates on both ForwardDiff approaches. I have iterated on the examples as discussed:
- Refactored the model_atomic setup into a make_basis helper.
- Approach 1 (generic fallbacks, fft_generic.jl): afterwards I got an answer consistent with FiniteDiff, see https://github.com/niklasschmitz/DFTK.jl/pull/1/files (comparison sketched after this list).
- Approach 2 (direct overloads for Dual types): the fft normalizations are still an issue (they involve s*ScaledPlan where s is a Dual), see https://github.com/niklasschmitz/DFTK.jl/pull/2/files
- The inclusion of the AtomicNonLocal() term currently leads to NaN derivative results with ForwardDiff in both approaches, while the other terms seem to work without further errors, at least.
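The FiniteDiff consistency check mentioned above can be sketched as follows (recompute_energy is a hypothetical stand-in for the post-SCF energy evaluation done in the linked PRs):
using ForwardDiff, FiniteDiff
# Compare the ForwardDiff derivative of the (post-SCF) energy with a finite-difference
# derivative in the scalar lattice parameter a; recompute_energy(a) is a placeholder.
stress_ad(a) = ForwardDiff.derivative(recompute_energy, a)
stress_fd(a) = FiniteDiff.finite_difference_derivative(recompute_energy, a)
# stress_ad(a0) ≈ stress_fd(a0)  # expected to agree up to finite-difference error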
Cool, that's great news! So we can actually use finite differences to debug the AtomicNonLocal term. Some ideas how to debug: for example, try c-q4.hgh instead of si-q4.hgh.
Regarding the stacktraces in the second PR ... it appears, at least for reverse diff, that this happens already in the PWBasis setup. I don't really fully get why at first glance. Let's discuss tomorrow.
https://github.com/JuliaLang/julia/issues/27705 has a snippet for yielding an error when a NaN is produced
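A simple variant of that idea (not necessarily the exact snippet from the linked issue) is to assert intermediate arrays as they are produced:
# Hypothetical helper: raise an error as soon as a NaN shows up, instead of letting it
# propagate silently through the rest of the term evaluation.
assert_no_nan(x, label="") = (any(isnan, x) && error("NaN produced: $label"); x)
# usage: proj = assert_no_nan(some_intermediate, "AtomicNonLocal projection")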
We've found the NaN in AtomicNonLocal: it came in due to a bug/inconsistency of ForwardDiff for norm(a::StaticArray) on a vector of zeros, also shown here: https://github.com/JuliaDiff/ForwardDiff.jl/issues/243#issuecomment-369948031.
The first, very ad-hoc fix I applied was to manually fall back to the norm on Vector, which gave correct directional derivatives at [0., 0., 0.]: https://github.com/niklasschmitz/DFTK.jl/pull/2#discussion_r654171218. This fixes the stress of AtomicNonLocal for both ForwardDiff approaches, which now also each agree with FiniteDiff.
On Approach 2 I also re-enabled the fft normalizations and added the required additional Dual rule for ScaledPlan. After this, both ForwardDiff approaches above finally agree on the stress of the example system!
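The kind of ScaledPlan rule meant here can be sketched as follows (an illustration of the idea only, not the method actually added in the PR; apply_scaled is a hypothetical helper name):
using AbstractFFTs
# When the normalization factor of a ScaledPlan is a Dual (e.g. because it depends on
# the differentiated lattice through the unit-cell volume), apply the inner plan first
# and multiply by the Dual scale afterwards, widening the element type only at the end.
apply_scaled(p::AbstractFFTs.ScaledPlan, x) = (p.p * x) .* p.scale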
Interesting. Actually this is a structural zero, i.e. it comes about from recip_lattice * zeros(3). So norm always gets called on a vector of Duals 0 + 0ε, so the non-differentiability of norm at zero is not an issue (at least for forward mode). Can you check whether it's OK with ChainRules? If yes, we might as well do a quick workaround here and wait for the next generation of forward-mode AD tools.
This is the current behavior of norm at zero for (Zygote+ChainRules, ForwardDiff) x (Vector, SVector):
using Zygote
using ForwardDiff
using StaticArrays
using LinearAlgebra
x = zeros(3)
Zygote.gradient(norm, x)[1]
# 3-element Vector{Float64}:
# 0.0
# 0.0
# 0.0
ForwardDiff.gradient(norm, x)
# 3-element Vector{Float64}:
# 0.0
# 0.0
# 1.0
y = @SVector zeros(3)
Zygote.gradient(norm, y)[1]
# 3-element SVector{3, Float64} with indices SOneTo(3):
# 0.0
# 0.0
# 0.0
ForwardDiff.gradient(norm, y)
# 3-element SVector{3, Float64} with indices SOneTo(3):
# NaN
# NaN
# NaN
# [f6369f11] ForwardDiff v0.10.18
# [90137ffa] StaticArrays v1.2.3
# [e88e6eb3] Zygote v0.6.13
For our use case all results are OK except the NaN, since it doesn't cancel out in the subsequent multiplication by zero, although I'm surprised by ForwardDiff.gradient(norm, x) giving preference to the last input dimension. Zygote picks up the dedicated rulesets for norm in ChainRules. Calling ChainRules directly also works (in particular, the frule gives a consistent 0.0 sensitivity for all input dimensions):
using ChainRules # [082447d4] ChainRules v0.8.13
ChainRules.unthunk(ChainRules.rrule(norm, x)[2](1.0)[2])
# 3-element Vector{Float64}:
# 0.0
# 0.0
# 0.0
ChainRules.unthunk(ChainRules.rrule(norm, y)[2](1.0)[2])
# 3-element SVector{3, Float64} with indices SOneTo(3):
# 0.0
# 0.0
# 0.0
function onehot(i, n)
x = zeros(n)
x[i] = 1.0
x
end
ChainRules.frule((ChainRules.NoTangent(), onehot(1,3),), norm, x) # (0.0, 0.0)
ChainRules.frule((ChainRules.NoTangent(), onehot(2,3),), norm, x) # (0.0, 0.0)
ChainRules.frule((ChainRules.NoTangent(), onehot(3,3),), norm, x) # (0.0, 0.0)
ChainRules.frule((ChainRules.NoTangent(), onehot(1,3),), norm, y) # (0.0, 0.0)
ChainRules.frule((ChainRules.NoTangent(), onehot(2,3),), norm, y) # (0.0, 0.0)
ChainRules.frule((ChainRules.NoTangent(), onehot(3,3),), norm, y) # (0.0, 0.0)
So a next-gen forward-mode AD picking up on ChainRules should indeed fix the problem. As for quick workarounds, I'm thinking of either
- falling back to the norm on Vector (as I did while debugging nonlocal.jl; sketched below), or
- adding a custom frule (which might need some thinking about how it gets picked up under broadcasting of norm, too).

Yeah, that just looks like a ForwardDiff bug, so either work around it locally or fix it upstream.
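The Vector fallback mentioned above could be as small as the following (hypothetical helper, not DFTK's actual code):
using LinearAlgebra, StaticArrays
# Route the norm of a static vector through the plain-Vector method, which gives
# finite (if somewhat arbitrary) partials at zero instead of NaN; for the structural
# zeros discussed above these partials get multiplied by zero anyway.
dual_safe_norm(x::StaticArray) = norm(Vector(x))
dual_safe_norm(x) = norm(x)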
Opening this to keep track of progress on obtaining stresses via autodiff.
Goal
Calculate the stress as the total derivative of the total energy with respect to the lattice parameters via automatic differentiation. As this falls under the scope of the Hellmann-Feynman theorem, we do not need to differentiate through the full SCF solve, but only through a post-processing step on the final solution scfres. We start with the following minimal example of silicon with a single scalar lattice parameter a.
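The minimal example itself is not reproduced here; the following is a rough, hypothetical sketch of that kind of setup (DFTK API as used at the time; names and details may differ from the linked gists):
using DFTK

a0 = 10.26  # silicon lattice constant in Bohr
lattice(a) = a / 2 * [0. 1. 1.; 1. 0. 1.; 1. 1. 0.]
Si = ElementPsp(:Si, psp=load_psp("hgh/lda/si-q4.hgh"))
atoms = [Si => [ones(3)/8, -ones(3)/8]]
make_basis(a) = PlaneWaveBasis(model_LDA(lattice(a), atoms), 15;
                               kgrid=[4, 4, 4], fft_size=[32, 32, 32])

# Solve the SCF once at the unperturbed lattice; by Hellmann-Feynman the stress then
# only needs the energy recomputed on make_basis(a) from the fixed scfres
# (density and orbitals), differentiated at a = a0.
scfres = self_consistent_field(make_basis(a0))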
Approach
We plan to try ForwardDiff.jl, ReverseDiff.jl and Zygote.jl. For stresses only (#params < 10) we expect ForwardDiff to perform best. Going further, the reverse modes of ReverseDiff and Zygote are also interesting, as they could jointly evaluate stresses and other derivatives of the total energy (e.g. forces) more efficiently.
Expected challenges:
Progress
no method matching zero(::String) (TODO: understand the stack trace)
Related links
An overview of AD tools in Julia: https://juliadiff.org/
Chris Rackauckas on strengths and weaknesses of different AD packages: https://discourse.julialang.org/t/state-of-automatic-differentiation-in-julia/43083/3
Common patterns that need rules in Zygote: https://juliadiff.org/ChainRulesCore.jl/stable/writing_good_rules.html