flatironinstitute / jax-finufft

JAX bindings to the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library
Apache License 2.0

Question about thread safety #73

Closed: dfm closed this issue 4 months ago

dfm commented 6 months ago

When putting together #72, I discovered that setting the nthreads parameter leads to incorrect results, at least on my machine. Here's a minimal example:

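import numpy as np
from jax_finufft import nufft1, options

# Two option sets that differ only in the number of FINUFFT threads.
opts1 = options.Opts(nthreads=1)
opts2 = options.Opts(nthreads=4)

M = 100000  # number of non-uniform source points
N = 200000  # number of output Fourier modes
random = np.random.default_rng(0)
x = 2 * np.pi * random.uniform(size=M)
c = random.standard_normal(size=M) + 1j * random.standard_normal(size=M)

# Same type-1 transform, single-threaded vs. four threads.
f1 = nufft1(N, c, x, opts=opts1)
f2 = nufft1(N, c, x, opts=opts2)
np.testing.assert_allclose(f1, f2)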

On my machine this sometimes fails with:

Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 100001 / 200000 (50%)
Max absolute difference: 2.9662e+14
Max relative difference: 1.0000005
 x: array([-364.81818 +727.4102j  ,  296.5559  -183.3433j  ,
         88.81011 -329.78464j , ...,  -24.347906 -37.057587j,
       -166.946   +120.03134j ,  438.84576 -536.6191j  ], dtype=complex64)
 y: array([ 3.112425e+09-6.205858e+09j,  2.771406e+09-1.713399e+09j,
        2.949958e+08-1.095428e+09j, ...,  6.126120e+08+9.323974e+08j,
       -5.545357e+08+3.987017e+08j,  4.101148e+09-5.014870e+09j],
      dtype=complex64)

and at other times the values of f2 are simply +/- inf.

@lgarrison — any thoughts?

lgarrison commented 6 months ago

I wonder if JAX is doing its own multi-threading that is not playing nice with finufft's. There's a note in the docs that one is supposed to use nthreads=1 when calling finufft in parallel (i.e. nested parallelism is not allowed): https://finufft.readthedocs.io/en/latest/cex.html#thread-safety-for-single-threaded-transforms-and-global-state . I'm not sure if that note is up to date, though, given the effort in https://github.com/flatironinstitute/finufft/pull/354 to make the FFTW calls thread-safe. And I'm not sure what else in finufft would not be thread safe.

dfm commented 6 months ago

Interesting! Over in https://github.com/flatironinstitute/finufft/pull/354 there is a comment that this won't work for running both the single and double precision libraries simultaneously (which we sort of do...) so perhaps that could be part of the problem? Either way, the problem doesn't seem to be a segfault, just invalid values, so I'm not sure how to debug it!

In the short term perhaps we should just remove nthreads from the public Opts API. It would be interesting to get to the bottom of this though!

dfm commented 6 months ago

Also looping in @blackwer who might have some thoughts.

blackwer commented 6 months ago

> Interesting! Over in flatironinstitute/finufft#354 there is a comment that this won't work for running both the single and double precision libraries simultaneously (which we sort of do...) so perhaps that could be part of the problem?

Single/double mode should be able to work in tandem -- that comment is about using the same mutex to lock both single and double critical sections. Running both at the same time should work, though there might be a (likely negligible) performance hit due to the shared lock.

> There's a note in the docs that one is supposed to use nthreads=1 when calling finufft in parallel (i.e. nested parallelism is not allowed): https://finufft.readthedocs.io/en/latest/cex.html#thread-safety-for-single-threaded-transforms-and-global-state . I'm not sure if that note is up to date, though, given the effort in flatironinstitute/finufft#354 to make the FFTW calls thread-safe. And I'm not sure what else in finufft would not be thread safe.

This should be fixed, but in general we recommend avoiding nested parallelism. Regardless, it's very hard to test, especially with external libraries out of your control (FFTW here), so there very likely could still be an issue.

All that said... I'm unable to reproduce these issues with my setup. I get deterministic results.

Package           Version
----------------- --------------------
annotated-types   0.6.0
jax               0.4.25
jax-finufft       0.0.4.dev58+gf4c97f0
jaxlib            0.4.25
ml-dtypes         0.3.2
numpy             1.26.4
opt-einsum        3.3.0
pip               24.0
pydantic          2.6.4
pydantic_core     2.16.3
scipy             1.12.0
setuptools        65.5.0
typing_extensions 4.10.0
gcc: 11.4.0
fftw 3.3.10
python: 3.10.13

I did notice that during the build, cmake was trying to pull in my system fftw, rather than using a vendored version. I was able to get pip to install the vendored version via...

CC=gcc CXX=g++ pip install --force-reinstall --config-settings=cmake.define.FINUFFT_FFTW_LIBRARIES=DOWNLOAD ./jax-finufft

which also works as intended. Can you check whether this is reproducible with the vendored fftw 3.3.10? Also note that allclose testing will not work when comparing output from finufft run with different numbers of threads (I think there was even an issue in this repo about this before). See https://github.com/flatironinstitute/finufft/issues/363 for details.
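Why allclose is fragile here: floating-point addition is not associative, so accumulating the same terms in a different grouping (as separate threads each building a partial sum would) can change the rounded result. The sketch below uses a contrived float32 input and a hypothetical `chunked_sum` helper as a stand-in for multithreaded accumulation; it does not reproduce FINUFFT's actual spreading code.

```python
import numpy as np


def chunked_sum(values: np.ndarray, nchunks: int):
    """Accumulate each contiguous chunk left to right (as one thread might),
    then add the per-chunk partial sums together in order."""
    # np.cumsum accumulates strictly left to right and keeps the input dtype.
    partials = [np.cumsum(chunk)[-1] for chunk in np.array_split(values, nchunks)]
    total = values.dtype.type(0)
    for partial in partials:
        total = values.dtype.type(total + partial)
    return total


# Contrived float32 input: adjacent float32 values near 1e8 are 8 apart,
# so adding 1.0 to +/-1e8 is lost to rounding.
vals = np.array([1e8, 1.0, -1e8, 1.0], dtype=np.float32)

one_thread = chunked_sum(vals, 1)   # ((1e8 + 1) - 1e8) + 1 -> 1.0
two_threads = chunked_sum(vals, 2)  # (1e8 + 1) + (-1e8 + 1) -> 0.0
print(one_thread, two_threads)
```

Converting `vals` to float64 makes both groupings return 2.0, because 1e8 + 1 is then exactly representable. Real inputs are far less extreme, but an ordering-dependent difference of even one float32 ulp can exceed the default `rtol=1e-7` of `np.testing.assert_allclose`.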

dfm commented 6 months ago

Thanks @blackwer! Some quick responses, with a fresh environment similar to yours:

Local environment:

```
# packages in environment at /opt/homebrew/Caskroom/miniforge/base/envs/jax-finufft-test:
#
# Name                    Version                         Build               Channel
annotated-types           0.6.0                           pypi_0              pypi
bzip2                     1.0.8                           h93a5062_5          conda-forge
ca-certificates           2024.2.2                        hf0a4a13_0          conda-forge
jax                       0.4.25                          pypi_0              pypi
jax-finufft               0.0.4.dev57+gef69daa.d20240318  pypi_0              pypi
jaxlib                    0.4.25                          pypi_0              pypi
libffi                    3.4.2                           h3422bc3_5          conda-forge
libsqlite                 3.45.2                          h091b4b1_0          conda-forge
libzlib                   1.2.13                          h53f4e23_5          conda-forge
ml-dtypes                 0.3.2                           pypi_0              pypi
ncurses                   6.4                             h463b476_2          conda-forge
numpy                     1.26.4                          pypi_0              pypi
openssl                   3.2.1                           h0d3ecfb_0          conda-forge
opt-einsum                3.3.0                           pypi_0              pypi
pip                       24.0                            pyhd8ed1ab_0        conda-forge
pydantic                  2.6.4                           pypi_0              pypi
pydantic-core             2.16.3                          pypi_0              pypi
python                    3.10.13                         h2469fbe_1_cpython  conda-forge
readline                  8.2                             h92ec313_1          conda-forge
scipy                     1.12.0                          pypi_0              pypi
setuptools                69.2.0                          pyhd8ed1ab_0        conda-forge
tk                        8.6.13                          h5083fa2_1          conda-forge
typing-extensions         4.10.0                          pypi_0              pypi
tzdata                    2024a                           h0c530f3_0          conda-forge
wheel                     0.42.0                          pyhd8ed1ab_0        conda-forge
xz                        5.2.6                           h57fd34a_0          conda-forge
```

```
fftw: 3.3.10
```

I still get the same behavior. The problem here isn't really allclose (although that is an interesting issue too!), because I do still mostly get +/- inf as the result.

I can confirm that my system FFTW is being used and unfortunately your suggested installation command fails with an error saying that fftw3.h cannot be found.

It's worth noting that I'm testing on a Mac, and I know that parallelism can be a bit of a nightmare here. What system are you testing on? I'm fine with removing nthreads or adding a warning about its behavior on Mac, if that's the easiest approach!

lgarrison commented 6 months ago

For what it's worth, I was able to reproduce the issue on my Linux workstation. I didn't see infs, but I did see large errors.

(Sent by email on Mon, Mar 18, 2024, in reply to Dan Foreman-Mackey's 10:19 AM comment above.)

blackwer commented 6 months ago

> I didn't see infs, but I did see large errors.

How big? Absolute errors were ~5, and relative errors ~5E-3 here. I didn't check what the default requested accuracy was, but I assume it was ~1E-3.

> I can confirm that my system FFTW is being used and unfortunately your suggested installation command fails with an error saying that fftw3.h cannot be found.

I haven't tried on mac. I assume you're using a brew compiler if you're going for threading? Macs + threading is indeed thorny... we have binary wheels for finufft 2.2 now -- is there a reason you can't use those? I'll see if I can replicate this on the intel mac -- should I try the M1 as well?

lgarrison commented 6 months ago

I tested this on my way out the door last Friday, so I'll have to check again when I'm back late tomorrow or Wednesday. I'll also need to check again whether I was using the CPU or GPU backend...

(Sent by email on Mon, Mar 18, 2024, in reply to Robert Blackwell's 10:35 AM comment above.)

dfm commented 6 months ago

> we have binary wheels for finufft 2.2 now -- is there a reason you can't use those?

Yeah, this library actually needs to compile its own wheels because JAX calls FINUFFT via C++/CUDA, rather than via Python. And, let me emphasize that I don't want to put pressure on you to solve this! This is low priority, and more of a question of interest. Our support of the guru interface in this package is experimental at best, and the package generally shouldn't be thought of as a core part of the FINUFFT ecosystem.

> I'll see if I can replicate this on the intel mac -- should I try the M1 as well?

I'm using an M1. I'd love to hear what you find, but again no stress! Thanks for all your feedback so far. When I get a chance, I'll do some digging into the build system as well to see if I can track anything down.

blackwer commented 6 months ago

I'm able to reproduce this on my intel mac, and I have yet to find a magic incantation that will resolve it. For now, I recommend not using threads on mac. It's likely this bug is in the upstream wheels as well, so I might have to open an issue there. I'll wait until Lehman confirms the order of magnitude of error on his linux machine before I dig deeper.

lgarrison commented 5 months ago

I find similar magnitude errors to Robert:

Mismatched elements: 100001 / 200000 (50%)
Max absolute difference: 6.097244
Max relative difference: 0.00436589

The default requested tolerance is 1e-6, but the fact that the two methods aren't agreeing that well could just be a condition number issue. That is, small differences in how the two NUFFTs are calculated with single versus multiple threads that cause, e.g., differences of 6e-8 (about the float32 unit roundoff) at some early stage in the calculation could be amplified by the condition number, which is about 2e5 (the length of the transform). So fractional differences less than 2e5 * 6e-8 ≈ 1e-2 are probably supposed to be accepted. Alex recently added some notes about the condition number to the docs.

A couple of other features of this disagreement make me think it's a condition number issue: the fractional error scales approximately with the transform length, and going to 64-bit precision resolves the issue.
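The back-of-the-envelope bound above can be checked numerically. This sketch assumes, as suggested in the comment, that the condition number is roughly the transform length N and that the relevant per-operation error is the unit roundoff (half of machine epsilon). `roundoff_floor` is a hypothetical helper, and the result is an order-of-magnitude estimate, not a rigorous error bound.

```python
import numpy as np

# Numbers from the comment above: transform length N (used as the rough
# condition number) and the max relative difference seen in float32.
N = 200_000
observed_rel_diff = 0.00436589


def roundoff_floor(kappa: float, dtype) -> float:
    """Order-of-magnitude size of differences that rounding alone can cause:
    condition number times the unit roundoff (eps / 2) of `dtype`."""
    unit_roundoff = float(np.finfo(dtype).eps) / 2
    return kappa * unit_roundoff


floor32 = roundoff_floor(N, np.float32)  # about 2e5 * 6e-8, i.e. ~1.2e-2
floor64 = roundoff_floor(N, np.float64)  # about 2e5 * 1.1e-16, i.e. ~2.2e-11

print(f"float32 floor ~ {floor32:.1e}, observed {observed_rel_diff:.1e}")
print(f"float64 floor ~ {floor64:.1e}")
```

The observed float32 discrepancy sits below the estimated floor. The float64 floor is far below the default requested tolerance of 1e-6, which is consistent with 64-bit runs agreeing.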

The fact that exactly half (plus 1) of the elements disagree is certainly suspicious, but whatever algorithmic feature that's causing the single-threaded and multi-threaded computations to diverge could just be happening at the midpoint of the output.

dfm commented 5 months ago

Thanks for this summary, @lgarrison! It's interesting that the mac version has the same number of mismatches even if the order of magnitude is so dramatically different. With this in mind, I'm happy to disallow/discourage the use of threading on mac. It might be worth thinking about whether there is a compatibility check we could do programmatically, but that might require tracking down the actual source of the problem, and I'm not sure I'm up to that task!
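One low-tech form of such a check is a platform guard applied before building the options. The sketch below is hypothetical (`resolve_nthreads` is not part of jax_finufft) and keys only on `platform.system()`. It cannot tell whether a particular build has working OpenMP, which is the harder problem.

```python
import platform
import warnings
from typing import Optional


def resolve_nthreads(requested: int, system: Optional[str] = None) -> int:
    """Hypothetical guard: on macOS ("Darwin"), downgrade any request other
    than a single thread to nthreads=1 and warn about it. `system` can be
    passed explicitly for testing; by default the current OS is detected."""
    if system is None:
        system = platform.system()
    if system == "Darwin" and requested != 1:
        warnings.warn(
            "multithreaded transforms have given incorrect results on macOS; "
            "falling back to nthreads=1",
            RuntimeWarning,
            stacklevel=2,
        )
        return 1
    return requested
```

It would be used as, for example, `options.Opts(nthreads=resolve_nthreads(4))`, leaving other platforms untouched.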

lgarrison commented 5 months ago

> It's interesting that the mac version has the same number of mismatches even if the order of magnitude is so dramatically different.

Definitely curious, maybe it points at a common underlying cause. I don't have a Mac to test on, but if the issue also appears in non-JAX finufft, we should probably open an upstream issue.

In the meantime, I agree that disabling threading on Mac is a good workaround.

lu1and10 commented 4 months ago

> With this in mind, I'm happy to disallow/discourage the use of threading on mac.

@dfm It seems that on Mac, if the finufft CPU library is built with OpenMP off, I see the inf/nan values. Could you try enabling OpenMP when installing jax_finufft?

I made the following changes to the jax_finufft root CMakeLists.txt to install jax_finufft on my Mac:

--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -32,7 +32,7 @@ endif()

 if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
     # TODO(dfm): OpenMP segfaults on my system - can we enable this somehow?
-    set(FINUFFT_USE_OPENMP OFF)
+    set(FINUFFT_USE_OPENMP ON)
 else()
     set(FINUFFT_USE_OPENMP ON)
 endif()

I'm on an Intel Mac, where linking FFTW with the omp suffix works; I'm not sure whether on M1 FFTW should be linked with the threads suffix instead.

In the vendor/finufft CMakeLists.txt, with AppleClang (I assumed brew-installed fftw and libomp), I needed a small tweak to find libomp, because find_package(OpenMP REQUIRED) somehow did not find it on my Mac:

--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -66,7 +66,15 @@ if (FINUFFT_BUILD_MATLAB)
 else ()
     # For non-matlab builds, find system OpenMP
     if (FINUFFT_USE_OPENMP)
-        find_package(OpenMP REQUIRED)
+        # AppleClang with brew installed libomp
+        if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
+            find_library(brew_omp_lib NAMES omp HINTS /usr/local/opt/libomp/lib)
+            add_library(OpenMP::OpenMP_CXX SHARED IMPORTED)
+            set_target_properties(OpenMP::OpenMP_CXX PROPERTIES IMPORTED_LOCATION ${brew_omp_lib})
+            target_compile_options(OpenMP::OpenMP_CXX INTERFACE -Xclang -fopenmp)
+        else()
+            find_package(OpenMP REQUIRED)
+        endif()
     endif ()
 endif ()

@@ -123,6 +131,9 @@ function(set_finufft_options target)

     target_include_directories(${target} PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include")
     if (FINUFFT_USE_OPENMP)
+        if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
+            target_include_directories(${target} PUBLIC "/usr/local/opt/libomp/include")
+        endif()
         target_link_libraries(${target} PRIVATE OpenMP::OpenMP_CXX)
         # there are issues on windows with OpenMP and CMake, so we need to manually add the flags
         # otherwise there are link errors

With export JAX_ENABLE_X64=1 the allclose check passes; with export JAX_ENABLE_X64=0 I got errors similar to those @lgarrison saw on Linux, since the condition number is large.

I haven't tried installing jax_finufft with LLVM clang on my Mac, but LLVM clang on macOS seems to work in the finufft GitHub Actions.

dfm commented 4 months ago

@lu1and10 — Thanks for this! I played around with this some more, and it does seem like your PR to FINUFFT (https://github.com/flatironinstitute/finufft/pull/431) does handle this issue gracefully.

There is a bigger question about how to best expose the necessary compiler flags for linking to the appropriate OpenMP, but that's probably more of an upstream question. Perhaps FINUFFT itself should have built-in support for discovering OpenMP installed with brew. On my Mac, I was able to get everything working using conda (these were the required packages: llvm-openmp fftw cxx-compiler).

lu1and10 commented 4 months ago

> There is a bigger question about how to best expose the necessary compiler flags for linking to the appropriate OpenMP, but that's probably more of an upstream question.

Yes, I don't know how to link the appropriate OpenMP without tweaking the CMake files. I was hoping CMake's find_package(OpenMP REQUIRED) would handle all cases across different OSes, but it seems to have trouble finding libomp.