JuliaMolSim / DFTK.jl

Density-functional toolkit
https://docs.dftk.org
MIT License

[WIP] Autodiff stresses #443

Closed niklasschmitz closed 3 years ago

niklasschmitz commented 3 years ago

Opening this to keep track of progress on obtaining stresses via autodiff.

Goal

Calculate the stress as the total derivative of the total energy with respect to the lattice parameters via automatic differentiation. Since this falls under the scope of the Hellmann-Feynman theorem, we do not need to differentiate through the full SCF solve, but only through a post-processing step on the converged solution scfres. We start with the following minimal example of silicon with a single scalar lattice parameter a:

using DFTK
using Test
using FiniteDiff  # needed for the finite-difference reference below

function make_basis(a)
    lattice = a / 2 * [[0 1 1.];
                       [1 0 1.];
                       [1 1 0.]]
    Si = ElementPsp(:Si, psp=load_psp("hgh/lda/Si-q4"))
    atoms = [Si => [ones(3)/8, -ones(3)/8]]
    model = model_atomic(lattice, atoms, symmetries=false)
    kgrid = [1, 1, 1]  # k-point grid (Regular Monkhorst-Pack grid)
    Ecut = 15          # kinetic energy cutoff in Hartree
    PlaneWaveBasis(model, Ecut; kgrid=kgrid)
end

a = 10.26
scfres = self_consistent_field(make_basis(a), tol=1e-8)

function compute_energy(scfres_ref, a)
    basis = make_basis(a)
    energies, H = energy_hamiltonian(basis, scfres_ref.ψ, scfres_ref.occupation; ρ=scfres_ref.ρ)
    energies.total
end

function compute_stress(scfres_ref, a)
    Inf # TODO implement
end
@test compute_stress(scfres, a) ≈ FiniteDiff.finite_difference_derivative(a -> compute_energy(scfres, a), a) # -1.411

Approach

We plan to try ForwardDiff.jl, ReverseDiff.jl and Zygote.jl. For stresses only (fewer than 10 parameters) we expect ForwardDiff to perform best. Going further, the reverse modes of ReverseDiff and Zygote are also interesting, as they could jointly evaluate stresses and other derivatives of the total energy (e.g. forces) more efficiently.
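
For concreteness, a minimal sketch (hypothetical name, assuming the dual-number issues discussed below get resolved) of what the ForwardDiff version of compute_stress above could look like:

using ForwardDiff

function compute_stress_forwarddiff(scfres_ref, a)
    # total derivative of the post-processed energy w.r.t. the lattice parameter a
    ForwardDiff.derivative(a -> compute_energy(scfres_ref, a), a)
end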

Expected challenges:

Progress

Related links

An overview of AD tools in Julia: https://juliadiff.org/
Chris Rackauckas on strengths and weaknesses of different AD packages: https://discourse.julialang.org/t/state-of-automatic-differentiation-in-julia/43083/3
Common patterns that need rules in Zygote: https://juliadiff.org/ChainRulesCore.jl/stable/writing_good_rules.html

niklasschmitz commented 3 years ago

cc #107

antoine-levitt commented 3 years ago

Cf also https://github.com/JuliaMolSim/DFTK.jl/issues/47 (in particular a trick to reduce to 6 instead of 9 DOF, although we probably don't care too much)

antoine-levitt commented 3 years ago

I'd focus on ForwardDiff for now. We should be able to work around the errors. Michael did the work of making DFTK work with IntervalArithmetic scalar types, so this should hopefully be similar.

mfherbst commented 3 years ago

Yes, please let me know if you need any help on that. I should easily be able to find some nice examples to get you going in case you need any.

niklasschmitz commented 3 years ago

Thanks, starting with ForwardDiff sounds good to me. As I understand it, there are two options: either 1) using GenericLinearAlgebra fallbacks as for IntervalArithmetic (currently fails with a stacktrace), or 2) directly overloading the relevant LinearAlgebra and FFTW calls for ForwardDiff dual number types (in the spirit of https://github.com/JuliaDiff/ForwardDiff.jl/issues/111#issuecomment-638251496 and https://github.com/JuliaDiff/ForwardDiff.jl/pull/495/files).

My guess would be that both are similar in difficulty, but 2) should be preferable for performance. What are your thoughts?
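
For intuition on why 2) should be tractable: the FFT is linear, so the derivative of fft(x(a)) with respect to a is simply fft(dx/da), and a Dual rule can apply the same transform to values and partials separately. A quick standalone sanity check of that identity (plain FFTW, not DFTK code):

using FFTW

a = 1.7
x(a)    = [sin(a * n) for n in 0:7]       # some smooth a-dependent signal
dxda(a) = [n * cos(a * n) for n in 0:7]   # its analytic derivative w.r.t. a

h  = 1e-6
fd = (fft(x(a + h)) - fft(x(a - h))) / 2h # central finite difference of fft(x(a))
maximum(abs, fd - fft(dxda(a)))           # tiny (finite-difference error only)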

antoine-levitt commented 3 years ago

Performance is not the foremost issue, so following IntervalArithmetic sounds good; however, we've had quite a few issues with the generic FFTs (which are buggy and not actively developed), so if https://github.com/JuliaDiff/ForwardDiff.jl/pull/495/files does the job then great!

mfherbst commented 3 years ago

Regarding the stacktrace: only some FFT sizes work for the generic implementation we have, and unless you specify an fft_size explicitly, the PlaneWaveBasis constructor will auto-adjust it. Therefore the effective fft_size used in https://gist.github.com/niklasschmitz/e7030b3f6341bcf56538a87d0b91d5e1#file-stress-genericlinearalgebra-jl-L30 and in https://gist.github.com/niklasschmitz/e7030b3f6341bcf56538a87d0b91d5e1#file-stress-genericlinearalgebra-jl-L16 do not agree. The solution is to explicitly pass a fixed fft_size to both constructors, e.g. just say PlaneWaveBasis(model, Ecut; kgrid=kgrid, fft_size=[32, 32, 32]) in both lines.
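
Concretely, applied to the make_basis from the top of this issue, that would look roughly like this (the 32x32x32 grid is just an example size):

function make_basis(a; fft_size=[32, 32, 32])
    lattice = a / 2 * [[0 1 1.];
                       [1 0 1.];
                       [1 1 0.]]
    Si = ElementPsp(:Si, psp=load_psp("hgh/lda/Si-q4"))
    atoms = [Si => [ones(3)/8, -ones(3)/8]]
    model = model_atomic(lattice, atoms, symmetries=false)
    # pinning fft_size keeps the reference basis and the differentiated basis on the same grid
    PlaneWaveBasis(model, 15; kgrid=[1, 1, 1], fft_size=fft_size)
end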

mfherbst commented 3 years ago

But I agree with Antoine. The generic FFT stuff only works "meh", so if we can avoid it, that would probably be the better solution long-term.

antoine-levitt commented 3 years ago

Tricky bug, nice catch!

niklasschmitz commented 3 years ago

Some updates on both ForwardDiff approaches:

I have iterated on the examples as discussed:

  1. using generic arithmetic with FourierTransforms.jl:
  2. adding ForwardDiff.Dual rules for AbstractFFTs / FFTW:

The inclusion of the AtomicNonlocal() term currently leads to NaN derivative results with ForwardDiff in both approaches, while the other terms seem to work without further errors.

mfherbst commented 3 years ago

Cool, that's great news! So we can actually use finite differences to debug the AtomicNonlocal term. Some ideas on how to debug:

Regarding the stacktraces in the second PR ... it appears that, at least for reverse diff, this happens already in the PlaneWaveBasis setup. I don't fully get why at a first glance. Let's discuss tomorrow.

antoine-levitt commented 3 years ago

https://github.com/JuliaLang/julia/issues/27705 has a snippet for yielding an error when a NaN is produced
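
(Not the snippet from that issue, just a simple stand-in serving the same purpose: wrap suspicious intermediates so evaluation aborts at the first non-finite value and the stacktrace points at the offending term.)

function assert_finite(x, label="")
    any(v -> !isfinite(v), x) && error("non-finite value encountered: $label")
    x
end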

niklasschmitz commented 3 years ago

We've found the NaN in AtomicNonlocal: it came in due to a bug/inconsistency of ForwardDiff on norm(a::StaticArray) at zero, also shown here: https://github.com/JuliaDiff/ForwardDiff.jl/issues/243#issuecomment-369948031. The first very ad-hoc fix I applied was to manually fall back to the norm on Vector, which gave correct directional derivatives at [0., 0., 0.]: https://github.com/niklasschmitz/DFTK.jl/pull/2#discussion_r654171218.

This fixes the stress contribution of AtomicNonlocal for both ForwardDiff approaches, which now also each agree with FiniteDiff.

For approach 2 I also re-enabled the fft normalizations and added the required additional Dual rule for ScaledPlan. After this, both ForwardDiff approaches finally agree on the stress of the example system above!

antoine-levitt commented 3 years ago

Interesting. Actually this is a structural zero, i.e. it comes about as recip_lattice * zeros(3). So norm always gets called on a vector of duals 0 + 0ε, and hence the non-differentiability of norm at zero is not an issue (at least for forward mode). Can you check that it's OK with ChainRules? If yes, we might as well do a quick workaround here and wait for the next generation of forward-diff tools.

niklasschmitz commented 3 years ago

This is the current behavior of norm at zero using (Zygote+ChainRules, ForwardDiff) x (Vector, SVector)

using Zygote
using ForwardDiff
using StaticArrays
using LinearAlgebra

x = zeros(3)
Zygote.gradient(norm, x)[1]
# 3-element Vector{Float64}:
#  0.0
#  0.0
#  0.0
ForwardDiff.gradient(norm, x)
# 3-element Vector{Float64}:
#  0.0
#  0.0
#  1.0

y = @SVector zeros(3)
Zygote.gradient(norm, y)[1]
# 3-element SVector{3, Float64} with indices SOneTo(3):
#  0.0
#  0.0
#  0.0
ForwardDiff.gradient(norm, y)
# 3-element SVector{3, Float64} with indices SOneTo(3):
#  NaN
#  NaN
#  NaN

# [f6369f11] ForwardDiff v0.10.18
# [90137ffa] StaticArrays v1.2.3
# [e88e6eb3] Zygote v0.6.13

For our use case all results are OK except the NaN, since it doesn't cancel out in the subsequent multiplication by zero, although I'm surprised by ForwardDiff.gradient(norm, x) giving preference to the last input dimension. Zygote picks up the dedicated rules for norm in ChainRules. Calling ChainRules directly also works (in particular the frule gives a consistent 0.0 sensitivity for all input dimensions):

using ChainRules # [082447d4] ChainRules v0.8.13

ChainRules.unthunk(ChainRules.rrule(norm, x)[2](1.0)[2])
# 3-element Vector{Float64}:
#  0.0
#  0.0
#  0.0

ChainRules.unthunk(ChainRules.rrule(norm, y)[2](1.0)[2])
# 3-element SVector{3, Float64} with indices SOneTo(3):
#  0.0
#  0.0
#  0.0

function onehot(i, n)
    x = zeros(n)
    x[i] = 1.0
    x
end
ChainRules.frule((ChainRules.NoTangent(), onehot(1,3),), norm, x) # (0.0, 0.0)
ChainRules.frule((ChainRules.NoTangent(), onehot(2,3),), norm, x) # (0.0, 0.0)
ChainRules.frule((ChainRules.NoTangent(), onehot(3,3),), norm, x) # (0.0, 0.0)

ChainRules.frule((ChainRules.NoTangent(), onehot(1,3),), norm, y) # (0.0, 0.0)
ChainRules.frule((ChainRules.NoTangent(), onehot(2,3),), norm, y) # (0.0, 0.0)
ChainRules.frule((ChainRules.NoTangent(), onehot(3,3),), norm, y) # (0.0, 0.0)

So a next-gen ForwardDiff picking up on ChainRules should indeed fix the problem. As for quick workarounds, I'm thinking of either

  1. convert structural-zero static vectors to Vector (as I did while debugging nonlocal.jl; sketched below), or
  2. overload norm on (static) vectors of ForwardDiff.Dual to use the corresponding frule (which might need some thinking about how it gets picked up under broadcasting of norm too).
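
A minimal sketch of option 1 (hypothetical helper name, not the fix that eventually landed in DFTK):

using LinearAlgebra
using StaticArrays
using ForwardDiff

# fall back to the Vector code path of norm for static vectors of duals,
# which avoids the NaN partials at zero shown above
norm_fallback(v::SVector{N,<:ForwardDiff.Dual}) where {N} = norm(Vector(v))
norm_fallback(v) = norm(v)

ForwardDiff.gradient(norm_fallback, @SVector zeros(3))  # finite entries, no NaNs
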
antoine-levitt commented 3 years ago

Yeah, that just looks like a ForwardDiff bug, so either work around it locally or fix it upstream.

antoine-levitt commented 3 years ago

done in https://github.com/JuliaMolSim/DFTK.jl/pull/476