m3g / CellListMap.jl

Flexible implementation of cell lists to map the calculations of particle-pair dependent functions, such as forces, energies, neighbor lists, etc.
https://m3g.github.io/CellListMap.jl/
MIT License

ReverseDiff gradients #93

Open dforero0896 opened 1 year ago

dforero0896 commented 1 year ago

Hi, I wanted to try getting some gradients from a function involving map_pairwise, as I saw in the docs that automatic differentiation is available. The issue is that my input consists of hundreds of thousands of variables (a 3×~1e5 matrix) and my output is a loss score, so ForwardDiff is quite inefficient. I tried simply replacing it with ReverseDiff, but I got this error:

ERROR: LoadError: ArgumentError: cannot reinterpret `ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}}` as `SVector{3, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}}}`, type `SVector{3, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}}}` is not a bits type

and with Zygote

ERROR: LoadError: Compiling Tuple{CellListMap.var"##set_number_of_batches!#38", Bool, typeof(CellListMap.set_number_of_batches!), CellList{3, Float64}, Tuple{Int64, Int64}}: try/catch is not supported.

My first question is whether it is even possible to use reverse-mode differentiation with CellListMap. If so, could you add some examples to the docs showing how? The type-conversion trick used for ForwardDiff does not work. Thanks in advance for your help.

lmiq commented 1 year ago

I don't remember having tried reverse diff. I'll take a look at it asap.

lmiq commented 1 year ago

Here is the situation. Currently, you cannot easily reverse-differentiate through the whole construction of the cell lists, but you can bypass it and differentiate only the computation of the objective function, if the coordinates are provided (redundantly) through a closure.

For example, consider the following simple function, which sums the squared distance between particle coordinates:

sum_sqr(d2, s) = s += d2

This function can be mapped over all pairs of particles (here stored as a matrix of size (3,N)) with:

coordinates = rand(3,1000)
box = Box([1,1,1], 0.05)
cl = CellList(coordinates,box)
map_pairwise( (_, _, _, _, d2, s) -> sum_sqr(d2, s), 0.0, box, cl)

(the (x,y,i,j...) parameters are omitted because they are not used in the sum_sqr function).
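
For small systems, what this mapped call computes can be cross-checked with a brute-force O(N²) loop. This is a minimal sketch, not part of CellListMap: `brute_force_sum_sqr` is a hypothetical helper, and it assumes that `Box([1,1,1], 0.05)` describes an orthorhombic periodic box, so distances follow the minimum-image convention:

```julia
# Hypothetical reference implementation (not part of CellListMap): sums d2 over
# all pairs i < j whose minimum-image distance is within `cutoff`, assuming an
# orthorhombic periodic box with side lengths `sides`.
function brute_force_sum_sqr(coordinates::AbstractMatrix, sides::AbstractVector, cutoff::Real)
    T = eltype(coordinates)
    n = size(coordinates, 2)
    s = zero(T)
    for i in 1:n-1, j in i+1:n
        d2 = zero(T)
        for k in axes(coordinates, 1)
            dx = coordinates[k, i] - coordinates[k, j]
            dx -= sides[k] * round(dx / sides[k])  # minimum-image wrap
            d2 += dx^2
        end
        if d2 <= cutoff^2
            s += d2  # same accumulation as sum_sqr(d2, s)
        end
    end
    return s
end
```

One consequence worth checking, under the same periodic-box assumption: the index-based closure version further down uses raw differences `coordinates[:,i] - coordinates[:,j]`, which match this minimum-image d2 only for pairs that do not straddle the box boundary.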

This can be forward-differentiated as shown in the manual, but reverse differentiation does not work, essentially because constructing the cell lists requires mutating arrays, which the current reverse-mode tools do not support (maybe Enzyme could handle it, but there is a simpler alternative).

The trick is to define a function that receives only the indices of the particles and computes the property of interest from the coordinates provided through a closure. That is, the above function becomes:

sum_sqr(i, j, s, coordinates) = s += sum(abs2, @views(coordinates[:,i] - coordinates[:,j]))

Note that `coordinates` here is the complete set of coordinates, which will now be closed over in the call to map_pairwise, and does not correspond to any of the inner input parameters. That is:

coordinates = rand(3,1000)
box = Box([1,1,1], 0.05)
cl = CellList(coordinates, box)
map_pairwise( (_, _, i, j, _, s) -> sum_sqr(i, j, s, coordinates), 0.0, box, cl)

Now we use the internal i and j parameters, but close over the coordinates. The call to map_pairwise can now be differentiated with respect to the coordinates, because the construction of the cell lists is not part of the differentiated computation. To do so, we wrap the complete call in a function that also receives box and cl as parameters:
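
The same separation can be illustrated without CellListMap: the neighbor search produces only integer index pairs (nothing to differentiate), and the objective reads the coordinates by index. A minimal sketch; `neighbor_pairs` and `pair_objective` are hypothetical helpers, and the search is a non-periodic brute-force stand-in for the cell list:

```julia
# Hypothetical stand-in for the cell-list construction: returns index pairs only,
# so no coordinate-dependent values flow out of it.
function neighbor_pairs(coordinates::AbstractMatrix, cutoff::Real)
    pairs = Tuple{Int,Int}[]
    n = size(coordinates, 2)
    for i in 1:n-1, j in i+1:n
        d2 = sum(abs2, @views(coordinates[:, i] .- coordinates[:, j]))
        d2 <= cutoff^2 && push!(pairs, (i, j))
    end
    return pairs
end

# Objective that uses only the indices i, j and reads `coordinates` directly,
# mirroring the `(_, _, i, j, _, s) -> ...` pattern above.
function pair_objective(coordinates::AbstractMatrix, pairs)
    s = zero(eltype(coordinates))
    for (i, j) in pairs
        s += sum(abs2, @views(coordinates[:, i] .- coordinates[:, j]))
    end
    return s
end

# Usage: find the pairs once, then differentiate only `objective`.
x0 = rand(3, 200)
pairs = neighbor_pairs(x0, 0.05)
objective = x -> pair_objective(x, pairs)
objective(x0)
```

Here `objective` is the function one would hand to an AD tool. Because the pairs were found once and stay fixed, the derivative treats the neighbor set as constant, which is also what happens with the closure-based map_pairwise call.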

julia> function sum_sqr(coordinates, box, cl)
           sum_sqr = map_pairwise!(
               (_, _, i, j, _, sum_sqr) -> sum_sqr += sum(abs2, @views(coordinates[:,i] - coordinates[:,j])),
               zero(eltype(coordinates)), box, cl,
           )
           return sum_sqr
       end
sum_sqr (generic function with 3 methods)

And this can be both forward- and reverse-differentiated:

julia> using ForwardDiff, ReverseDiff

julia> coordinates = rand(3,1000);

julia> box = Box([1,1,1], 0.05);

julia> cl = CellList(coordinates, box);

julia> gr = ReverseDiff.gradient( (x) -> sum_sqr(x,box,cl), coordinates)
3×1000 Matrix{Float64}:
 -0.0875518    0.0  0.0  -0.0848635  0.0  0.0  0.0  …  0.0  0.0   0.00620111  -0.0335964  0.0   0.0234671
 -0.0442765    0.0  0.0  -0.0575914  0.0  0.0  0.0     0.0  0.0  -0.014573     0.0649681  0.0   0.0607623
  0.000192838  0.0  0.0  -0.0623673  0.0  0.0  0.0     0.0  0.0   0.0316766    0.0451708  0.0  -0.00635728

julia> gr = ForwardDiff.gradient( (x) -> sum_sqr(x,box,cl), coordinates)
3×1000 Matrix{Float64}:
 -0.0875518    0.0  0.0  -0.0848635  0.0  0.0  0.0  …  0.0  0.0   0.00620111  -0.0335964  0.0   0.0234671
 -0.0442765    0.0  0.0  -0.0575914  0.0  0.0  0.0     0.0  0.0  -0.014573     0.0649681  0.0   0.0607623
  0.000192838  0.0  0.0  -0.0623673  0.0  0.0  0.0     0.0  0.0   0.0316766    0.0451708  0.0  -0.00635728
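
Many columns of the gradient are exactly zero, which is expected. For S = Σ over pairs (i,j) of |x_i - x_j|², the gradient for particle i is Σ_j 2(x_i - x_j) over its partners j, so particles with no neighbor within the cutoff get a zero column. With 1000 random points and cutoff 0.05 in a unit box, the expected number of neighbors per particle is only about 999 × (4/3)π(0.05)³ ≈ 0.5. The sketch below computes that closed form for a fixed list of index pairs and checks it against central finite differences; `pair_gradient`, `pair_sum_sqr` and `fd_gradient` are hypothetical helpers:

```julia
# Closed-form gradient of S(x) = sum over (i, j) in pairs of |x_i - x_j|^2,
# with the pair list held fixed (as in the closure trick above).
function pair_gradient(coordinates::AbstractMatrix, pairs)
    g = zero(coordinates)
    for (i, j) in pairs
        for k in axes(coordinates, 1)
            dx = coordinates[k, i] - coordinates[k, j]
            g[k, i] += 2 * dx
            g[k, j] -= 2 * dx
        end
    end
    return g
end

# The objective itself, with the same fixed-pair convention.
pair_sum_sqr(coordinates, pairs) =
    sum(p -> sum(abs2, @views(coordinates[:, p[1]] .- coordinates[:, p[2]])), pairs;
        init = zero(eltype(coordinates)))

# Central finite differences, for checking only (2 objective calls per entry).
function fd_gradient(f, x::AbstractMatrix; h = 1e-6)
    g = zero(x)
    xp = copy(x)
    for k in eachindex(x)
        xp[k] = x[k] + h; fp = f(xp)
        xp[k] = x[k] - h; fm = f(xp)
        xp[k] = x[k]
        g[k] = (fp - fm) / (2h)
    end
    return g
end
```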

As expected for a scalar function of many inputs, reverse differentiation is much faster here (roughly 100×):


julia> using BenchmarkTools

julia> revg(coordinates, box, cl) = ReverseDiff.gradient( (x) -> sum_sqr(x,box,cl), coordinates)
revg (generic function with 1 method)

julia> forg(coordinates, box, cl) = ForwardDiff.gradient( (x) -> sum_sqr(x,box,cl), coordinates)
forg (generic function with 1 method)

julia> @btime revg($coordinates, $box, $cl);
  550.999 μs (11188 allocations: 567.33 KiB)

julia> @btime forg($coordinates, $box, $cl);
  60.827 ms (73754 allocations: 24.44 MiB)
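
A rough way to see why the gap is this large: forward mode propagates a limited number of partial derivatives per evaluation (the chunk size), so a gradient with respect to n inputs needs about cld(n, chunk) passes through the function, while reverse mode needs one forward pass plus one reverse sweep for a scalar output. The sketch assumes a chunk size of 12, which to my knowledge is ForwardDiff's default threshold, and ignores all per-pass overheads, so it is an order-of-magnitude estimate only:

```julia
# Rough pass-count estimate for forward mode, assuming at most `chunk`
# partials per pass; overheads are ignored.
forward_passes(n_inputs::Integer; chunk::Integer = 12) = cld(n_inputs, chunk)

forward_passes(3 * 1_000)    # 3×1000 coordinates, as in the benchmark above
forward_passes(3 * 600_000)  # 3×600_000 coordinates, the size of the real data
```

That gives 250 passes for the benchmark, which is consistent in order of magnitude with the ~110× ratio measured above (60.8 ms vs 0.55 ms), and 150,000 passes for 3×600,000 coordinates.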

dforero0896 commented 1 year ago

Hi, thanks a lot for your answer! I've been able to replicate your example with some similar data. But sometimes I get ERROR: LoadError: UndefRefError: access to undefined reference, and other times Julia segfaults and quits.

...
main at julia (unknown line)
__libc_start_main at /lib64/libc.so.6 (unknown line)
unknown function (ip: 0x401098)
Allocations: 282841011 (Pool: 282818657; Big: 22354); GC: 103
Segmentation fault (core dumped)

My particular application is a histogram (or two-point function) that I would like to differentiate through, but I keep getting the segfault. Could it be because of the size of my data (600k coordinates)? That would still not explain why it segfaults with 1k coordinates, though.

lmiq commented 1 year ago

Segfaults are usually caused by corrupted memory access, not by the size of the data. When the data is too big to fit in memory, you get an OutOfMemoryError instead.

Without further details, I can't speculate on what may be going on there.

One thing that may be related: when CellListMap runs in parallel, there is machinery to avoid race conditions among threads, and I don't know whether the differentiation routines handle it properly. One test is to run the calculations without parallelization (parallel=false).
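
To make the race-condition point concrete: a common way to parallelize a pairwise reduction safely is to give every task a private copy of the output and combine the copies only at the end. The sketch below shows that general pattern in plain Julia; it is not CellListMap's implementation, and `threaded_pair_reduce` is a hypothetical name. It assumes `init` is an additive identity (such as `0.0` or `zeros(n)`), since each task starts from its own copy of it and the partial results are summed:

```julia
# Hypothetical helper: applies acc = f(i, j, acc) over `pairs`,
# with one private accumulator per task.
function threaded_pair_reduce(f, init, pairs; ntasks::Int = Threads.nthreads())
    isempty(pairs) && return init
    chunk_len = cld(length(pairs), max(ntasks, 1))
    tasks = map(Iterators.partition(pairs, chunk_len)) do chunk
        Threads.@spawn begin
            acc = deepcopy(init)  # private copy: no two tasks write to the same array
            for (i, j) in chunk
                acc = f(i, j, acc)
            end
            acc
        end
    end
    # Combine the per-task results serially, after all tasks have finished.
    return reduce(+, fetch.(tasks))
end
```

Because the result should not depend on how the pairs are split among tasks, comparing against a serial run (parallel=false, as suggested above) is a cheap first check.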

I've made a small test here, and in that simple example the serial and parallel results are the same. But in your case you are probably updating a shared histogram array as the output, while the example above is a scalar function, so it is not exactly the same thing:

julia> function sum_sqr(coordinates, box, cl; parallel=true)
           sum_sqr = map_pairwise!(
               (_, _, i, j, _, sum_sqr) -> sum_sqr += sum(abs2, @views(coordinates[:,i] - coordinates[:,j])),
               zero(eltype(coordinates)), box, cl; parallel=parallel
           )
           return sum_sqr
       end
sum_sqr (generic function with 1 method)

julia> coordinates = rand(3,5000);

julia> box = Box([1,1,1], 0.05);

julia> cl = CellList(coordinates, box);

julia> ReverseDiff.gradient(x -> sum_sqr(x, box, cl; parallel=true), coordinates)
3×5000 Matrix{Float64}:
 -0.0344465  0.0  -0.00210207  -0.0551787  -0.0152113   0.0472379   …   0.0680135  -0.0428575   0.0   0.0715126  -2.16583
  0.0527802  0.0  -0.0280015   -0.0101965  -0.0427886  -0.0660361       0.10064     0.00204247  0.0  -0.0447712   0.0216365
  0.0416164  0.0   0.103765    -0.0231282  -0.0654189  -0.00256454     -0.0578568  -3.82477     0.0   0.10979     0.0366878

julia> ReverseDiff.gradient(x -> sum_sqr(x, box, cl; parallel=false), coordinates)
3×5000 Matrix{Float64}:
 -0.0344465  0.0  -0.00210207  -0.0551787  -0.0152113   0.0472379   …   0.0680135  -0.0428575   0.0   0.0715126  -2.16583
  0.0527802  0.0  -0.0280015   -0.0101965  -0.0427886  -0.0660361       0.10064     0.00204247  0.0  -0.0447712   0.0216365
  0.0416164  0.0   0.103765    -0.0231282  -0.0654189  -0.00256454     -0.0578568  -3.82477     0.0   0.10979     0.0366878

julia> ReverseDiff.gradient(x -> sum_sqr(x, box, cl; parallel=false), coordinates) ≈
          ReverseDiff.gradient(x -> sum_sqr(x, box, cl; parallel=true), coordinates)
true

lmiq commented 1 year ago

In fact, with a simple histogram-like function, ReverseDiff fails, even serially:

julia> function hist(coordinates, box, cl; parallel=true)
           h = map_pairwise!(
               (_, _, i, j, d2, h) -> begin
                   if sqrt(d2) < box.cutoff / 2 
                       h[1] += sum(abs2, @views(coordinates[:,i] - coordinates[:,j]))
                   else
                       h[2] += sum(abs2, @views(coordinates[:,i] - coordinates[:,j]))
                   end
                   return h
               end,
               zeros(eltype(coordinates), 2), box, cl; parallel=parallel
           )
           return h
       end
hist (generic function with 1 method)

julia> hist(coordinates, box, cl)
2-element Vector{Float64}:
  406.8648388535888
 3223.091284798098

julia> ReverseDiff.gradient(x -> hist(x, box, cl; parallel=false), coordinates)
ERROR: DimensionMismatch: new dimensions (2, 10000) must be consistent with array size 10000

If you instead compute the histogram by passing the hist variable through the closure, then you can certainly get corrupted memory accesses among threads. And I actually couldn't make it work.

Maybe one alternative is to compute each bin of the histogram independently, as that would give the function scalar return values. I'm not a specialist in autodiff, so I can't be more precise about what to suggest there.
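
A sketch of the per-bin idea: each bin is computed by its own scalar-valued function, and the histogram is assembled from those scalars. For brevity it works on a precomputed vector of pair distances instead of calling map_pairwise; `bin_count` and `histogram_by_bins` are hypothetical helpers. Hard counts like these still have a zero derivative almost everywhere:

```julia
# Hypothetical helper: number of distances in [lo, hi), returned as a scalar.
bin_count(distances, lo, hi) = count(d -> lo <= d < hi, distances)

# Full histogram assembled from independent scalar computations, one per bin.
histogram_by_bins(distances, edges) =
    [bin_count(distances, edges[k], edges[k + 1]) for k in 1:length(edges) - 1]
```

With CellListMap, each `bin_count` would presumably become one map_pairwise! call with a scalar initial value of `0`, as in the single-bin example further down in this thread.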

dforero0896 commented 1 year ago

I see, but it seems you get a different kind of error than I do (you get DimensionMismatch and I get UndefRefError). My complete function actually computes a scalar in the end, since it computes the MSE with respect to a reference 2pt function, but I guess the same undefined access applies.

lmiq commented 1 year ago

Can you share something about your code, so that we can at least localize the issue?

dforero0896 commented 1 year ago

Sure! Here are my "core" functions. I believe the issue can be reproduced with particles randomly distributed in a box, since my dataset is a bit large (though the "before" dataset I shared in a previous issue may work too):

using CellListMap, ReverseDiff, Statistics  # Statistics provides `mean`
bin_edges = 10 .^range(-2, stop = log10(50), length=11)
positions = 2000. .* rand(3, 600000)
box_size = [2e3 for _ = 1:3]
box = Box(box_size, 5.)
cl = CellList(positions, box)
xi_ref = zeros(length(bin_edges) - 1)  # reference 2pt function; zeros are fine for testing
# Minimum-image separation along one coordinate of a periodic box
function coordinate_separation(a, b, box_size)
    delta = abs(a - b)
    return (delta > 0.5*box_size ? delta - box_size : delta)*sign(a-b)
end
function diff_build_histogram!(i, j, hist, coordinates, bin_edges, box_size)
    d2 = sum(abs2, coordinate_separation.(view(coordinates, :, i), view(coordinates, :, j), box_size))
    ibin = searchsortedlast(bin_edges, sqrt(d2))
    if (ibin > 0) && ibin < length(bin_edges)  # hist has length(bin_edges) - 1 bins
        hist[ibin] += 1
    end #if
    return hist
end
function loss(positions, box, cl, bin_edges, box_size)
    hist = zeros(Int,size(bin_edges,1)-1);
    println("Counting pairs...")
    # Run calculation (keep the returned, reduced histogram)
    hist = map_pairwise!(
        (_, _, i, j, _, hist) -> diff_build_histogram!(i, j, hist, positions, bin_edges, box_size),
        hist, box, cl; show_progress = true
    )
    println("Done")
    N = size(positions,2)
    hist = hist  / (N * (N - 1))
    norm = @. (4/3) * π * (bin_edges[2:end]^3 -bin_edges[1:end-1]^3) / (box_size[1] * box_size[2] * box_size[3])
    hist ./= norm
    mean(abs.(hist - xi_ref)) # I think for testing purposes xi_ref can be 0.
end #func
ReverseDiff.gradient((x) -> loss(x, box, cl, bin_edges, box_size), positions)
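
The `coordinate_separation` helper in this snippet implements the minimum-image convention along one axis. For coordinates inside the box it should agree with the branch-free form x - L*round(x/L); a quick numerical check (the helper is copied from the snippet, and `min_image` is a hypothetical name for the comparison):

```julia
# Copied from the snippet above.
function coordinate_separation(a, b, box_size)
    delta = abs(a - b)
    return (delta > 0.5*box_size ? delta - box_size : delta)*sign(a-b)
end

# Branch-free minimum-image separation, for comparison.
min_image(a, b, L) = (a - b) - L * round((a - b) / L)

L = 2000.0
xs = L .* rand(1000)
ys = L .* rand(1000)
maximum(abs.(coordinate_separation.(xs, ys, L) .- min_image.(xs, ys, L)))  # should be ~0
```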

dforero0896 commented 1 year ago

Sorry, in my latest test it actually seems to work (I did deactivate parallelization). All gradients seem to be 0, but that may be because the histogram is just not differentiable.

lmiq commented 1 year ago

I would try to compute a single count (of one bin) in a regular scalar variable to see how that works.

Then, if that works, maybe it is possible to build the histogram with an immutable structure (an SVector, for example).
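
A sketch of the immutable-accumulator idea, using a plain `NTuple` as a stand-in for an `SVector` to avoid the StaticArrays dependency: each update returns a new tuple instead of mutating an array. `add_to_bin` and `tuple_histogram` are hypothetical names, and the binning works on a precomputed vector of distances:

```julia
# Returns a new tuple with `w` added to bin `ibin`; the input tuple is not modified.
add_to_bin(h::Tuple, ibin::Integer, w) = Base.setindex(h, h[ibin] + w, ibin)

# Bins distances with `searchsortedlast`, as in the loss function above, into a tuple.
function tuple_histogram(distances, edges)
    nbins = length(edges) - 1
    h = ntuple(_ -> 0.0, nbins)
    for d in distances
        ibin = searchsortedlast(edges, d)
        if 1 <= ibin <= nbins
            h = add_to_bin(h, ibin, 1.0)  # rebinding, not mutation
        end
    end
    return h
end
```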

lmiq commented 1 year ago

Just to add: if I compute a single bin of the histogram, the differentiation apparently works, but returns all zeros, as you observed. I'm not sure whether this is correct:

julia> using CellListMap, LinearAlgebra, ForwardDiff, ReverseDiff

julia> function hist(coordinates, box, cl; parallel=true)
           h = map_pairwise!(
               (_, _, i, j, _, h) -> begin
                   d = norm(@views(coordinates[:,i] - coordinates[:,j]))
                   if d < box.cutoff / 2 
                       h += 1
                   #else
                   #    h[2] += 1
                   end
                   return h
               end,
               0, box, cl; parallel=parallel
           )
           return h
       end
hist (generic function with 1 method)

julia> coordinates = rand(3,1000);

julia> box = Box([1,1,1], 0.05);

julia> cl = CellList(coordinates, box);

julia> hist(coordinates, box, cl)
24

julia> all(==(0), ReverseDiff.gradient(x -> hist(x, box, cl; parallel=false), coordinates))
true

julia> all(==(0), ForwardDiff.gradient(x -> hist(x, box, cl; parallel=false), coordinates))
true

(example fixed @dforero0896)

dforero0896 commented 1 year ago

Indeed it seems to work. The issue of the zeros is just that this "exact" way of histogramming is not differentiable. An approximate histogram could be built with something like

function diff_build_histogram!(i, j, hist, coordinates, bin_widths, box_size, bin_centers)
    d2 = sum(abs2, coordinate_separation.(view(coordinates, :, i), view(coordinates, :, j), box_size))
    # Gaussian weight of this pair's distance in every bin (note the broadcasted exp.)
    hist .+= exp.(-((sqrt(d2) .- bin_centers) ./ bin_widths) .^ 2)
    return hist
end

so that the end product depends smoothly on the coordinates. Thanks for your help!
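
A self-contained sketch of such a Gaussian-kernel ("soft") histogram over a list of pair distances, with the derivative of one bin weight written out. `soft_histogram` and `soft_weight_derivative` are hypothetical names, and the accumulation rebinds `h` to a new array instead of updating it in place with `.+=`, on the assumption that non-mutating code is easier for reverse-mode tools such as Zygote to handle:

```julia
# Soft histogram: every distance contributes a Gaussian weight to every bin,
# so the result varies smoothly with the distances.
function soft_histogram(distances::AbstractVector, bin_centers::AbstractVector, bin_widths::AbstractVector)
    h = zeros(promote_type(eltype(distances), eltype(bin_centers)), length(bin_centers))
    for d in distances
        h = h .+ exp.(-((d .- bin_centers) ./ bin_widths) .^ 2)  # new array, no mutation
    end
    return h
end

# Derivative of one bin weight with respect to the distance d:
# d/dd exp(-((d - c)/w)^2) = -2 (d - c) / w^2 * exp(-((d - c)/w)^2)
soft_weight_derivative(d, c, w) = -2 * (d - c) / w^2 * exp(-((d - c) / w)^2)
```

Each pair contributes exp(-((d - c)/w)^2) to the bin centered at c, so the bins are no longer exact counts; the choice of `bin_widths` trades resolution for smoothness.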

lmiq commented 1 year ago

Yes, cool, I was thinking about that problem. Exactly: the histogram has a zero gradient because no infinitesimal move of the particles will cause a pair to change from one bin to another. Not only is the count discontinuous, its derivative is zero almost everywhere.
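
The zero-gradient effect is easy to see numerically: nudging a distance by a tiny amount leaves a hard bin count unchanged (unless the distance sits exactly on a bin edge), while a Gaussian-smoothed count moves. A minimal sketch with hypothetical helpers `hard_count`, `soft_count` and `fd_first`; the bin [0.2, 0.5) and the smoothing width are arbitrary choices:

```julia
# Hard count: number of distances in [lo, hi). Piecewise constant in each distance.
hard_count(distances, lo, hi) = count(d -> lo <= d < hi, distances)

# Soft count: Gaussian weight around the bin center; smooth in each distance.
function soft_count(distances, lo, hi)
    c, w = (lo + hi) / 2, (hi - lo) / 2
    return sum(d -> exp(-((d - c) / w)^2), distances)
end

# Central finite-difference derivative with respect to the first distance.
function fd_first(f, distances; h = 1e-6)
    up, down = copy(distances), copy(distances)
    up[1] += h
    down[1] -= h
    return (f(up) - f(down)) / (2h)
end

distances = [0.3, 0.45, 0.9]
fd_first(x -> hard_count(x, 0.2, 0.5), distances)  # the count does not move
fd_first(x -> soft_count(x, 0.2, 0.5), distances)  # the soft count responds
```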

The problem of obtaining a differentiable distribution is indeed interesting. Thanks for posting.

I will update the docs with some examples that came out of this discussion, and will close the issue when I do; thank you very much for the feedback. It will be useful for others to know how to apply ReverseDiff here.

dforero0896 commented 12 months ago

Glad my question was helpful. Some other packages have implemented differentiable histogramming in other ways, which may be useful for anyone looking into this too.