JuliaDiff / ReverseDiff.jl

Reverse Mode Automatic Differentiation for Julia

double free crash with multi-threaded code only when using multiple threads #218

Open cgeoga opened 1 year ago

cgeoga commented 1 year ago

I apologize if I've missed an existing discussion of this in past issues. I have some multi-threaded code that crashes when I try to build a GradientTape. It runs fine when I start Julia with only one thread, but with multiple threads it crashes consistently. Here is an MWE:

using LinearAlgebra, Vecchia, StaticArrays, ReverseDiff

Threads.nthreads() > 1 && BLAS.set_num_threads(1) 

# create cfg, which to be very clear is just random data and is not supposed
# to be something where optimizing the induced likelihood approximation gives
# sensible answers.
demo_matern12(x, y, p) = (p[1]^2)*exp(-norm(x-y)/p[2])
const cfg = Vecchia.kdtreeconfig(randn(1000), rand(SVector{2,Float64}, 1000), 
                                 5, 3, demo_matern12)

# Two different functions that are pretty careful with allocations and are
# designed to do lots of small linear algebra operations with many threads.
smallform_nll(p) = Vecchia.nll(cfg, p)
rcholform_nll(p) = Vecchia.nll_rchol(cfg, p; issue_warning=false)

# If you uncomment either declaration line and run this script, you get:
#
# -- one thread: everything runs, and if you compile the tapes and evaluate you get the right answers.
# -- >1 threads (smallform): double free or corruption (!prev) crash.
# -- >1 threads (rcholform): double free or corruption (out)   crash.
#
#smallform_tape = ReverseDiff.GradientTape(smallform_nll, ones(2))
#rcholform_tape = ReverseDiff.GradientTape(rcholform_nll, ones(2))
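
For readers unfamiliar with the "compile the tapes and evaluate" step mentioned in the comments above, here is a minimal sketch of that workflow. It assumes a stand-in objective, `toy_nll`, which is hypothetical and only replaces the Vecchia.jl likelihoods so that the example runs without Vecchia installed. It uses no threading itself and is not expected to reproduce the crash.

```julia
using ReverseDiff

# Hypothetical stand-in objective (NOT Vecchia.nll); used only so the
# sketch is self-contained.
toy_nll(p) = p[1]^2 + p[2]^2 + p[1]*p[2]

# Record the tape once on a representative input, then compile it.
tape  = ReverseDiff.GradientTape(toy_nll, ones(2))
ctape = ReverseDiff.compile(tape)

# Evaluate the gradient at a new point, writing into preallocated storage.
p    = [2.0, 3.0]
grad = similar(p)
ReverseDiff.gradient!(grad, ctape, p)

println("threads = ", Threads.nthreads(), ", gradient = ", grad)
```

Running the same script once with `julia --threads=1` and once with, say, `julia --threads=4` is the comparison described above; the thread count is fixed at startup and can be confirmed with `Threads.nthreads()`.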

The stacktraces (?) from the crashes are a bit opaque to me. I can find lines in the stacktrace that point to lines in my source for Vecchia.jl, but when I look at them they are comments, or the stacktrace says something like macro expansion at [...] and refers to a line that does not contain a macro. Clearly, interpreting those requires some understanding that I don't have.

I recognize that a natural first thought would be that I am doing something wrong with the threading, and I certainly can't rule that out. Here is my threading pattern, and the comment above links to where I got it from. But I will say that I can use ForwardDiff with multiple threads without any problems: I've been using this code heavily with ForwardDiff for a long time and have never seen a crash or a result that was incorrect or made me raise my eyebrows. That doesn't mean nothing is wrong, of course.
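
For readers who cannot follow the link, a generic chunked task-spawning pattern of the kind commonly used for threaded reductions in Julia might look like the sketch below. This is an assumption for illustration only: the function `threaded_sum` is hypothetical, and it is not the actual code from Vecchia.jl.

```julia
# Hypothetical chunked-reduction sketch; NOT the actual Vecchia.jl code.
# Splits a non-empty collection into roughly one chunk per thread, reduces
# each chunk in its own task, and then combines the partial results.
function threaded_sum(f, items; ntasks=Threads.nthreads())
    chunksize = cld(length(items), ntasks)
    chunks = Iterators.partition(items, chunksize)
    tasks = map(chunks) do chunk
        Threads.@spawn sum(f, chunk)
    end
    return sum(fetch, tasks)
end

result = threaded_sum(x -> x^2, 1:100)
println(result)
```

Each task in this sketch only reads its own chunk and returns a value, so the tasks share no mutable state.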

Anyway, I'm wondering if somebody here could help me sort out what's going on. Any help or thoughts would be appreciated! And if there is any other information I can provide to make giving that help easier, please let me know.

Here is my specific Project.toml for those who are willing to run this script and tinker:

[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Vecchia = "8d73829f-f4b0-474a-9580-cecc8e084068"