Chemellia / ChemistryFeaturization.jl

Interface package for featurizing atomic structures
https://chemistryfeaturization.chemellia.org/dev/
MIT License

Towards graph AD for Forces #111

Closed: DhairyaLGandhi closed this 2 years ago

DhairyaLGandhi commented 3 years ago

Part of closing #10.

And it's fast. Not that FiniteDifferences is supposed to be "fast", but ForwardDiff fails here, and FiniteDifferences is supposed to be correct, so it serves as the reference:


# Assumes: using FiniteDifferences, Zygote, BenchmarkTools (for @btime)
# FiniteDifferences
julia> @time grad(forward_fdm(2,1),
           (i,j,dist) -> sum(GraphBuilding.weights_cutoff(i,j,dist)),
           collect(1:1000), collect(1:1000), Float64.(collect(1:1000)) );
 98.433196 seconds (942.86 k allocations: 270.869 GiB, 20.58% gc time, 4.76% compilation time)

# Zygote
julia> @btime gradient($(collect(1:1000)), $(collect(1:1000)), $(Float64.(collect(1:1000)))) do i, j, dist
         sum(GraphBuilding.weights_cutoff(i, j, dist))
       end;
  7.407 ms (4164 allocations: 31.16 MiB)
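FiniteDifferences works as a reference because a difference quotient needs nothing but evaluations of the function. That is also why it is slow: a central difference evaluates the function twice per input entry, so the cost grows with the number of inputs, whereas a reverse-mode gradient like Zygote's comes from one forward and one backward pass. A minimal sketch of such a check in plain Julia, assuming a hypothetical inverse-square `total_weight` as a stand-in for `GraphBuilding.weights_cutoff` and a hand-derived gradient in place of Zygote:

```julia
# Hypothetical smooth function of pair distances (inverse-square decay).
# A stand-in only; this is NOT GraphBuilding.weights_cutoff.
total_weight(dist) = sum(1 ./ dist .^ 2)

# Central differences: two evaluations of f per input entry.
function central_fd_grad(f, x::Vector{Float64}; h = 1e-6)
    g = similar(x)
    for k in eachindex(x)
        e = zeros(length(x)); e[k] = h
        g[k] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

# Hand-derived gradient in a single pass: d/dd (1/d^2) = -2/d^3.
total_weight_grad(dist) = -2 ./ dist .^ 3

# Wrap the function to count how often the finite-difference check calls it.
const ncalls = Ref(0)
counted(d) = (ncalls[] += 1; total_weight(d))

dist = collect(1.0:1000.0)
g_fd = central_fd_grad(counted, dist)
println(ncalls[])   # 2000
println(isapprox(g_fd, total_weight_grad(dist); rtol = 1e-5))
```

The call count grows linearly with the number of inputs, which is consistent with the large gap between the two timings above.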
DhairyaLGandhi commented 3 years ago

I have to move some of these to Zygote, since there seem to be a couple of edge cases.

rkurchin commented 3 years ago

@DhairyaLGandhi are you okay with me making a few of these formatting changes and merging this, or were there still things you wanted to move over to Zygote?

DhairyaLGandhi commented 3 years ago

I do want to move some things over to Zygote (things like the generator and dictionary + generator), but that can happen in parallel. What I want to do is add a couple more test cases that might have overlapping indices, to make sure this is correct. We can commit the formatting whenever you want.
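Overlapping indices make a good test because when the same atom index appears in several pairs, the adjoint of the gather step has to accumulate the contributions rather than overwrite them. A minimal sketch of that, assuming a hypothetical `pair_products` function (not part of GraphBuilding) and a hand-written pullback:

```julia
# Hypothetical forward pass: gather per-atom values onto pairs,
# y[p] = x[is[p]] * x[js[p]]. Illustration only, not GraphBuilding code.
pair_products(x, is, js) = x[is] .* x[js]

# Pullback of sum(pair_products(x, is, js)) with respect to x.
# Pair p contributes x[js[p]] to g[is[p]] and x[is[p]] to g[js[p]];
# an atom that appears in several pairs must ACCUMULATE all of them.
function pair_products_grad(x, is, js)
    g = zero(x)
    for p in eachindex(is)
        g[is[p]] += x[js[p]]
        g[js[p]] += x[is[p]]
    end
    return g
end

# Buggy variant: overwriting instead of accumulating drops contributions
# whenever an index repeats.
function pair_products_grad_overwrite(x, is, js)
    g = zero(x)
    for p in eachindex(is)
        g[is[p]] = x[js[p]]
        g[js[p]] = x[is[p]]
    end
    return g
end

x  = [1.0, 2.0, 3.0]
is = [1, 1, 2]   # atom 1 appears in two pairs: overlapping indices
js = [2, 3, 3]
println(pair_products_grad(x, is, js))            # [5.0, 4.0, 3.0]
println(pair_products_grad_overwrite(x, is, js))  # [3.0, 3.0, 2.0]
```

Here the function is x1*x2 + x1*x3 + x2*x3, whose gradient is [5.0, 4.0, 3.0]; without repeated indices both versions would agree, so only test cases with overlap catch the overwrite bug.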

DhairyaLGandhi commented 3 years ago

I've added some more tests. I'm unsure what to do about the round(Int, x) calls peppered around only so that we can test against FiniteDifferences.jl. They make the code dirtier, so I'd like to avoid them. Open to ideas.
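One possible way to drop the round(Int, x) calls (a suggestion, not what this PR does) is to hand the finite-difference check only the continuous argument and close over the integer indices, so they are never perturbed. A minimal sketch, assuming a hypothetical `total_pair_weight` shaped like `weights_cutoff(i, j, dist)` and a hand-rolled central difference in place of FiniteDifferences.jl:

```julia
# Hypothetical function with the same argument shape as weights_cutoff(i, j, dist):
# integer atom indices plus continuous distances. Not the GraphBuilding code.
function total_pair_weight(is::Vector{Int}, js::Vector{Int}, dist::Vector{Float64}, natoms::Int)
    W = zeros(natoms, natoms)
    for p in eachindex(is)
        W[is[p], js[p]] += 1 / dist[p]^2   # indices are only used for lookup
    end
    return sum(W)
end

# Central differences over a Float64 vector only.
function central_fd_grad(f, x::Vector{Float64}; h = 1e-6)
    g = similar(x)
    for k in eachindex(x)
        e = zeros(length(x)); e[k] = h
        g[k] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

is, js = [1, 1, 2], [2, 3, 3]
dist = [1.0, 2.0, 4.0]

# Close over the integer indices: only `dist` is perturbed, so the indices
# stay Ints and no round(Int, ...) is needed inside the function.
g = central_fd_grad(d -> total_pair_weight(is, js, d, 3), dist)
println(isapprox(g, -2 ./ dist .^ 3; rtol = 1e-6))
```

With FiniteDifferences.jl the same closure could be passed to `grad`, e.g. `grad(central_fdm(5, 1), d -> f(is, js, d), dist)`, leaving the production code free of test-only rounding.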

codecov-commenter commented 3 years ago

Codecov Report

Merging #111 (d2ff77c) into main (fb0018d) will decrease coverage by 0.19%. The diff coverage is 82.05%.

Current head d2ff77c differs from pull request most recent head 72b4542. Consider uploading reports for the commit 72b4542 to get more accurate results.

@@            Coverage Diff             @@
##             main     #111      +/-   ##
==========================================
- Coverage   80.81%   80.62%   -0.20%     
==========================================
  Files          12       13       +1     
  Lines         318      351      +33     
==========================================
+ Hits          257      283      +26     
- Misses         61       68       +7     
Impacted Files                 Coverage Δ
src/utils/adjoints.jl          76.66% <76.66%> (ø)
src/utils/graph_building.jl    98.00% <100.00%> (+0.12%) ↑


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update fb0018d...72b4542.

DhairyaLGandhi commented 2 years ago

It seems weights_cutoff still works, with a change in sign, so that's great. We'll need neighbor_list to be differentiable as well, and then we'll be able to differentiate build_graph. neighbor_list touches a lot of Xtals.jl, so I'll have to read into what the function is doing and go from there.

DhairyaLGandhi commented 2 years ago

Spent a little time updating this locally - sitrep:


julia> c = Crystal(joinpath(cr_path, "IRMOF-1.cif"))
Name: /Users/dhairyagandhi/.julia/packages/Xtals/Kf4en/test/data/crystals/IRMOF-1.cif
Bravais unit cell of a crystal.
    Unit cell angles α = 90.000000 deg. β = 90.000000 deg. γ = 90.000000 deg.
    Unit cell dimensions a = 25.832000 Å. b = 25.832000 Å, c = 25.832000 Å
    Volume of unit cell: 17237.492730 Å³

    # atoms = 424
    # charges = 0
    chemical formula: Dict(:Zn => 4, :H => 12, :O => 13, :C => 24)
    space Group: P1
    symmetry Operations:
        'x, y, z'

julia> gradient(c) do c
         i,j,d = neighbor_list(c)
         sum(d)
       end
((name = nothing, box = (a = 1938.7207119336626, b = 1861.6924099523428, c = 1947.283043589588, α = -416.8184235168784, β = 760.7150303623142, γ = -681.7548034041554, Ω = nothing, f_to_c = nothing, c_to_f = nothing, reciprocal_lattice = nothing), atoms = (n = nothing, species = nothing, coords = (xf = [1972.726096133977 -1972.7260961339766 … -441.49871345424566 -1138.6631802079703; -986.3630480669881 1972.7260961339766 … 1922.6874775652561 -1068.3497716550012; -986.3630480669881 -1972.726096133979 … 9.069665394507859e-14 -854.3377059102552],)), charges = nothing, bonds = nothing, symmetry = nothing),)

julia> gradient(c) do c
         w, s = build_graph(c)
         sum(w)
       end
((name = nothing, box = (a = NaN, b = NaN, c = NaN, α = NaN, β = NaN, γ = NaN, Ω = nothing, f_to_c = nothing, c_to_f = nothing, reciprocal_lattice = nothing), atoms = (n = nothing, species = nothing, coords = (xf = [NaN NaN … NaN NaN; NaN NaN … NaN NaN; NaN NaN … NaN NaN],)), charges = nothing, bonds = nothing, symmetry = nothing),)

Everything works, but something becomes NaN at some point, likely a divide by 0. Still, we should be able to do forces with this.
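One common way a divide by 0 shows up in distance-based code (an assumption here, not a diagnosis of build_graph) is differentiating a Euclidean distance at zero separation: the gradient (ri - rj) / d becomes 0/0 = NaN, and a single NaN then spreads through every sum it touches, which would match every entry above being NaN. An inverse-square weight can do the same, since 1/d^2 is Inf at d = 0 and Inf * 0 in the chain rule is NaN. A minimal sketch with hypothetical helper names:

```julia
# Euclidean distance between two positions (hypothetical helper).
pair_dist(ri, rj) = sqrt(sum(abs2, ri .- rj))

# Its gradient with respect to ri is (ri - rj) / d. For coincident points
# (d == 0, e.g. an atom paired with itself) this is 0/0 = NaN.
pair_dist_grad(ri, rj) = (ri .- rj) ./ pair_dist(ri, rj)

# Hypothetical guard: define the gradient as zero when the points coincide.
safe_pair_dist_grad(ri, rj) = iszero(pair_dist(ri, rj)) ? zero(ri) : pair_dist_grad(ri, rj)

r = [1.0, 2.0, 3.0]
println(pair_dist_grad(r, r))        # [NaN, NaN, NaN]
println(safe_pair_dist_grad(r, r))   # [0.0, 0.0, 0.0]
```

Whether zero is the right gradient for coincident points is a modelling choice; filtering out zero-distance pairs before differentiating is an alternative.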

DhairyaLGandhi commented 2 years ago

Okay, so we have the following working.

julia> include("xtals.jl")
build_graph2 (generic function with 1 method)

julia> cr_path = normpath(joinpath(pathof(Xtals), "..", "..", "test", "data", "crystals"))
"/Users/dhairyagandhi/.julia/packages/Xtals/Kf4en/test/data/crystals"

julia> c = Crystal(joinpath(cr_path, "IRMOF-1.cif"))
Name: /Users/dhairyagandhi/.julia/packages/Xtals/Kf4en/test/data/crystals/IRMOF-1.cif
Bravais unit cell of a crystal.
    Unit cell angles α = 90.000000 deg. β = 90.000000 deg. γ = 90.000000 deg.
    Unit cell dimensions a = 25.832000 Å. b = 25.832000 Å, c = 25.832000 Å
    Volume of unit cell: 17237.492730 Å³

    # atoms = 424
    # charges = 0
    chemical formula: Dict(:Zn => 4, :H => 12, :O => 13, :C => 24)
    space Group: P1
    symmetry Operations:
        'x, y, z'

julia> gradient(c) do c
         w, s = build_graph2(c)
         sum(w)
       end
((name = nothing, box = (a = -31.83244099108285, b = -31.832440991082862, c = -31.832440991082862, α = -2.1514963693844834e-14, β = -2.8035019022818323e-14, γ = -3.751262522266638e-14, Ω = nothing, f_to_c = nothing, c_to_f = nothing, reciprocal_lattice = nothing), atoms = (n = nothing, species = nothing, coords = (xf = [-4.78645247969116 4.786452479691137 … 51.735173803203864 51.735173803203864; 4.786452479691156 -4.786452479691122 … 77.23782826208793 -77.23782826208932; 4.786452479691155 4.7864524796911665 … -77.23782826208932 77.23782826208792],)), charges = nothing, bonds = nothing, symmetry = nothing),)

which are the gradients with respect to the crystal, so graph AD works 🎉

Couple of notes:

  1. There were a few missing methods in Xtals for calculating distances; I added those in https://github.com/SimonEnsemble/Xtals.jl/pull/101 (and also included a copy in this PR).
  2. replicate is written with a lot of mutation (which Zygote doesn't support), and I figured we could write a simpler version, so I did. How do we want to handle this? We could propose this change to Xtals.jl (the performance is roughly the same, but of course this version allocates).
  3. I have refactored some of the indexing utilities and kept them out of AD to save some time there. This also means we can write them more neatly if desired.
  4. How do we want to test this? Do we have some cases in mind?
  5. I might have messed up the rebase, so I'll have to move this to a different PR.
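As a rough sketch of the non-mutating style mentioned in item 2 (not the version in this PR, and not the Xtals.jl `replicate` API), assuming fractional coordinates are stored as a 3 × n matrix like the `coords.xf` field in the gradients above, replicating into an na × nb × nc supercell can be written with comprehensions and `reduce(hcat, ...)` instead of in-place writes:

```julia
# Hypothetical non-mutating replication of fractional coordinates.
# xf is 3 × n (one column per atom); reps = (na, nb, nc).
# Only coordinates are handled; the box would also need scaling.
function replicate_xf(xf::AbstractMatrix{<:Real}, reps::NTuple{3,Int})
    na, nb, nc = reps
    # One integer translation per image, built without mutation.
    shifts = [[i, j, k] for i in 0:na-1 for j in 0:nb-1 for k in 0:nc-1]
    # Shift every atom by the image offset, then rescale each row so the
    # result is in fractional coordinates of the supercell.
    blocks = [(xf .+ s) ./ [na, nb, nc] for s in shifts]
    return reduce(hcat, blocks)   # 3 × (n * na * nb * nc)
end

xf = [0.1 0.5;
      0.2 0.5;
      0.3 0.5]                                # two atoms, one column each
println(size(replicate_xf(xf, (2, 2, 1))))    # (3, 8)
```

This allocates one block per image, which is the trade-off described in item 2. It has not been checked against Zygote here.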
DhairyaLGandhi commented 2 years ago

Moving to #119