Chemellia / ChemistryFeaturization.jl

Interface package for featurizing atomic structures
https://chemistryfeaturization.chemellia.org/dev/
MIT License

Graph AD for Forces #119

Status: Open. DhairyaLGandhi opened this pull request 2 years ago.

DhairyaLGandhi commented 2 years ago

ref #111

x-ref: https://github.com/Chemellia/ChemistryFeaturization.jl/pull/111#issuecomment-922109332, #10

Okay, so we have the following working.

julia> using ChemistryFeaturization, Zygote, Xtals

julia> cr_path = normpath(joinpath(pathof(Xtals), "..", "..", "test", "data", "crystals"))
"/Users/dhairyagandhi/.julia/packages/Xtals/Kf4en/test/data/crystals"

julia> using ChemistryFeaturization.Utils.GraphBuilding

julia> c = Crystal(joinpath(cr_path, "IRMOF-1.cif"))
Name: /Users/dhairyagandhi/.julia/packages/Xtals/Kf4en/test/data/crystals/IRMOF-1.cif
Bravais unit cell of a crystal.
    Unit cell angles α = 90.000000 deg. β = 90.000000 deg. γ = 90.000000 deg.
    Unit cell dimensions a = 25.832000 Å. b = 25.832000 Å, c = 25.832000 Å
    Volume of unit cell: 17237.492730 Å³

    # atoms = 424
    # charges = 0
    chemical formula: Dict(:Zn => 4, :H => 12, :O => 13, :C => 24)
    space Group: P1
    symmetry Operations:
        'x, y, z'

julia> gradient(c) do c
         w, s = build_graph2(c)
         sum(w)
       end
((name = nothing, box = (a = -31.83244099108285, b = -31.832440991082862, c = -31.832440991082862, α = -2.1514963693844834e-14, β = -2.8035019022818323e-14, γ = -3.751262522266638e-14, Ω = nothing, f_to_c = nothing, c_to_f = nothing, reciprocal_lattice = nothing), atoms = (n = nothing, species = nothing, coords = (xf = [-4.78645247969116 4.786452479691137 … 51.735173803203864 51.735173803203864; 4.786452479691156 -4.786452479691122 … 77.23782826208793 -77.23782826208932; 4.786452479691155 4.7864524796911665 … -77.23782826208932 77.23782826208792],)), charges = nothing, bonds = nothing, symmetry = nothing),)

These are the gradients with respect to the crystal's fields, so AD through graph building works 🎉
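One way to sanity-check gradients like these is to compare them against central finite differences. The sketch below does not use the package's build_graph: it uses a hypothetical toy weight function, w_ij = exp(-d_ij²) summed over all ordered atom pairs (positions as columns of a 3×n matrix, no periodic images, no cutoff), together with its hand-derived gradient. In an actual test, the hand-derived gradient would presumably be replaced by the result of Zygote's gradient.

```julia
# Hypothetical toy stand-in for graph edge weights (not the package's build_graph):
# w_ij = exp(-d_ij^2) for every ordered pair i ≠ j, with atom positions as the
# columns of a 3×n floating-point matrix. No periodic images and no cutoff.
function toy_weight_sum(x::AbstractMatrix)
    n = size(x, 2)
    s = 0.0
    for i in 1:n, j in 1:n
        i == j && continue
        s += exp(-sum(abs2, x[:, i] - x[:, j]))
    end
    return s
end

# Hand-derived gradient of toy_weight_sum: each unordered pair appears twice,
# so d/dx_i = sum over j ≠ i of -4 (x_i - x_j) exp(-d_ij^2).
function toy_weight_grad(x::AbstractMatrix)
    n = size(x, 2)
    g = zeros(size(x))
    for i in 1:n, j in 1:n
        i == j && continue
        r = x[:, i] - x[:, j]
        g[:, i] .+= -4 .* r .* exp(-sum(abs2, r))
    end
    return g
end

# Central finite differences, perturbing one coordinate at a time.
function fd_gradient(f, x::AbstractMatrix; h = 1e-6)
    g = zeros(size(x))
    for k in eachindex(x)
        xp = copy(x); xp[k] += h
        xm = copy(x); xm[k] -= h
        g[k] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x = [0.0 1.0 0.3; 0.0 0.2 1.1; 0.0 0.5 -0.4]
isapprox(fd_gradient(toy_weight_sum, x), toy_weight_grad(x); rtol = 1e-5)  # true
```

The same comparison applies to any differentiable scalar summary of the graph (such as sum(w) above), as long as the perturbation h stays small relative to any cutoff, since a cutoff makes the weights non-smooth.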

A couple of notes:

  1. Xtals was missing a few methods for calculating distances; I added those in https://github.com/SimonEnsemble/Xtals.jl/pull/101 (and also included a copy in this PR).
  2. replicate is written with a lot of mutation, so I wrote a simpler version. How do we want to handle this? We could propose this change to Xtals.jl ~(the performance is roughly the same, but of course this allocates)~ (this implementation is roughly 3x faster)
  3. I refactored some of the indexing utilities and excluded them from AD to save some time there. It also means we can write them more neatly if desired.
  4. How do we want to test this? Do we have some cases in mind?
  5. ~I might have messed up the rebase, so I'll have to move this to a different PR.~
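Item 2 matters for AD because Zygote does not support in-place array mutation. As a hedged sketch (replicate_frac_coords is a hypothetical name; this is neither the PR's replicate nor an Xtals.jl function), the fractional coordinates of an (nx, ny, nz) supercell can be built without any in-place writes by concatenating shifted copies of the cell and rescaling:

```julia
# Hypothetical non-mutating replication of fractional coordinates
# (not the PR's replicate and not an Xtals.jl function).
# xf:   3×n matrix of fractional coordinates in the original unit cell
# reps: number of copies along a, b and c
# Image (i, j, k) shifts every atom by (i, j, k) cells; dividing by reps makes the
# result fractional with respect to the enlarged supercell.
function replicate_frac_coords(xf::AbstractMatrix, reps::NTuple{3,Int})
    scale = collect(reps)
    shifts = [[i, j, k] for i in 0:reps[1]-1, j in 0:reps[2]-1, k in 0:reps[3]-1]
    # reduce(hcat, ...) allocates a fresh matrix rather than filling a preallocated one
    return reduce(hcat, [(xf .+ s) ./ scale for s in vec(shifts)])
end

xf = [0.5 0.0; 0.5 0.0; 0.5 0.0]       # two atoms
replicate_frac_coords(xf, (2, 1, 1))   # columns: (0.25,0.5,0.5), (0,0,0), (0.75,0.5,0.5), (0.5,0,0)
```

The sketch uses only broadcasting, a comprehension and reduce(hcat, ...), which should keep it AD-friendly at the cost of extra allocations. A full replicate would also need to repeat species (and any charges or bonds) and scale the box; those parts are left out here.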

Also, judging by the number of allocations, there is probably some performance left on the table if we optimise this further.

julia> @btime gradient($c) do c
         w, s = build_graph(c)
         sum(w)
       end;
  56.376 ms (572641 allocations: 50.27 MiB)

DhairyaLGandhi commented 2 years ago

Another piece of good news is that very little AD-specific work is involved: for the most part, we have refactored some functions and added a couple of methods that would be good for Xtals.jl to have anyway.

rkurchin commented 2 years ago

It seems like some of the additional tests are failing on this at the moment?

DhairyaLGandhi commented 2 years ago

Yeah, I'm not sure what's changed here. The weights_cutoff function is the same as before.

rkurchin commented 2 years ago

Given the way things are being reorganized, this code will probably end up in AtomGraphs (which should be registered in a couple of days) rather than here, JFYI.