yang-song / score_inverse_problems

Official repo for "Solving Inverse Problems in Medical Imaging with Score-Based Generative Models"

Possibility to run on CPU #4

Open zaccharieramzi opened 2 years ago

zaccharieramzi commented 2 years ago

Hi,

I tried running the code on CPU (Python 3.9, Ubuntu 16.04, 8-core machine) and got a segmentation fault:

Fatal Python error: Segmentation fault

Thread 0x00007f56467fc700 (most recent call first):
  File "/usr/lib/python3.9/threading.py", line 316 in wait
  File "/usr/lib/python3.9/threading.py", line 574 in wait
  File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.9/threading.py", line 954 in _bootstrap_inner
  File "/usr/lib/python3.9/threading.py", line 912 in _bootstrap

Current thread 0x00007f5813035700 (most recent call first):
  File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 1160 in execute_replicated
  File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 637 in xla_pmap_impl
  File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/jax/core.py", line 607 in process_call
  File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/jax/core.py", line 1624 in process
  File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/jax/core.py", line 1552 in call_bind
  File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/jax/core.py", line 1621 in bind
  File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/jax/_src/api.py", line 1632 in f_pmapped
  File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 183 in reraise_with_filtered_traceback
  File "/home/zaccharie/workspace/score_inverse_problems/score_inverse_problems/run_lib.py", line 391 in evaluate
  File "/home/zaccharie/workspace/score_inverse_problems/score_inverse_problems/main.py", line 60 in main
  File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/absl/app.py", line 251 in _run_main
  File "/home/zaccharie/workspace/score_inverse_problems/venv/lib/python3.9/site-packages/absl/app.py", line 303 in run
  File "/home/zaccharie/workspace/score_inverse_problems/score_inverse_problems/main.py", line 68 in <module>
[1]    5511 segmentation fault (core dumped)  python score_inverse_problems/main.py --config  --workdir=./ --mode eval

Have you tried running the code on CPU, or is it a GPU-only code?
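The traceback ends inside `jax.pmap` (`xla_pmap_impl`, reached from `evaluate` in `run_lib.py`), so one quick check is whether JAX sees CPU devices at all and whether a trivial `pmap` runs on them. This is a minimal sketch assuming a reasonably recent JAX release. The `--xla_force_host_platform_device_count=8` flag (8 chosen to match the 8-core machine above) is an illustrative assumption, not something the repository sets. The sketch does not reproduce or fix the segmentation fault.

```python
import os

# Assumption: ask XLA to expose 8 virtual CPU devices so pmap has more than one
# device to map over. This flag only takes effect if set before jax is imported.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
).strip()

import jax
import jax.numpy as jnp

# Explicitly request CPU devices, even if a GPU backend is also present.
cpu_devices = jax.devices("cpu")
print("CPU devices visible to JAX:", len(cpu_devices))

# pmap maps over the leading axis, so its size must equal the number of devices used.
n = len(cpu_devices)
x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)
doubled = jax.pmap(lambda v: v * 2.0, devices=cpu_devices)(x)
print("pmap output shape:", doubled.shape)
```

If this trivial `pmap` works on CPU, the crash is less likely to be a blanket "pmap cannot run on CPU" limitation, though that remains a guess.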

tianzhijiaoziA commented 2 years ago

Hi, which versions are you using in your configuration? Is it based on the information given by the author? I used Python 3.6 with the configuration given by the author, and a JAX error appears.

tianzhijiaoziA commented 2 years ago

Hi, I am a student at SYSU. The GPU cannot be used with this JAX version; it is better to use a TPU. In my tests it also works better with at least 48 GB of GPU memory (e.g. an A100) and jaxlib 0.1.69 to 0.1.73. This was my first time using the JAX framework, and the problem troubled me for a long time. I hope this helps.
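A small helper for comparing an installed jaxlib against the version range suggested above. It assumes that range means jaxlib 0.1.69 through 0.1.73, and it uses `importlib.metadata`, which needs Python 3.8+ (so not the Python 3.6 setups mentioned in this thread). Being inside the range is not evidence that the repository will run. The bounds and helper names are hypothetical, for illustration only.

```python
import re
from importlib import metadata  # Python 3.8+

# Assumption: the suggested range is jaxlib 0.1.69 to 0.1.73 inclusive.
RECOMMENDED_MIN = (0, 1, 69)
RECOMMENDED_MAX = (0, 1, 73)


def parse_version(text):
    """Extract (major, minor, patch) from strings like '0.1.71+cuda111'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", text)
    if match is None:
        raise ValueError(f"unrecognised version string: {text!r}")
    return tuple(int(group) for group in match.groups())


def in_recommended_range(version_text):
    """True if the version falls inside the assumed inclusive range."""
    return RECOMMENDED_MIN <= parse_version(version_text) <= RECOMMENDED_MAX


def installed_version(package):
    """Installed distribution version, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    for pkg in ("jax", "jaxlib"):
        print(pkg, installed_version(pkg))
    jaxlib_version = installed_version("jaxlib")
    if jaxlib_version is not None:
        print("jaxlib in suggested range:", in_recommended_range(jaxlib_version))
```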

dsmagiya commented 2 years ago

Hi @tianzhijiaoziA, I have been having the same issues for weeks. Could you please share more details on running it on a TPU or an A100? I am currently running Python 3.6 on a cluster that only supports up to JAX 0.2.17, which produces these errors:

WARNING:tensorflow:From /home/fs01/dm852/venv/test/lib64/python3.6/site-packages/tensorflow_gan/python/estimator/tpu_gan_estimator.py:42: The name tf.estimator.tpu.TPUEstimator is deprecated. Please use tf.compat.v1.estimator.tpu.TPUEstimator instead.

I1004 19:30:43.270267 22970125400960 tpu_client.py:54] Starting the local TPU driver.
I1004 19:30:43.271347 22970125400960 xla_bridge.py:231] Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
I1004 19:30:43.271621 22970125400960 xla_bridge.py:231] Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
I1004 19:30:43.271768 22970125400960 xla_bridge.py:231] Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
W1004 19:30:43.271850 22970125400960 xla_bridge.py:234] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
/home/fs01/dm852/venv/test/lib64/python3.6/site-packages/jax/lib/xla_bridge.py:374: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  "jax.host_id has been renamed to jax.process_index. This alias "
I1004 19:31:40.582314 22970125400960 checkpoints.py:223] Found no checkpoint files in /home/fs01/dm852/venv/pt_tomography/score_inverse_problems-main/checkpoints-meta
I1004 19:31:40.584645 22970125400960 dataset_info.py:365] Load dataset info from /home/fs01/dm852/tensorflow_datasets/pt3701_512/1.0.0
I1004 19:31:40.587075 22970125400960 dataset_builder.py:351] Reusing dataset pt3701_512 (/home/fs01/dm852/tensorflow_datasets/pt3701_512/1.0.0)
I1004 19:31:40.587217 22970125400960 logging_logger.py:34] Constructing tf.data.Dataset pt3701_512 for split train[:80%], from /home/fs01/dm852/tensorflow_datasets/pt3701_512/1.0.0
I1004 19:31:40.803931 22970125400960 dataset_builder.py:351] Reusing dataset pt3701_512 (/home/fs01/dm852/tensorflow_datasets/pt3701_512/1.0.0)
I1004 19:31:40.804249 22970125400960 logging_logger.py:34] Constructing tf.data.Dataset pt3701_512 for split train[80%:90%], from /home/fs01/dm852/tensorflow_datasets/pt3701_512/1.0.0
I1004 19:31:41.548434 22970125400960 run_lib.py:149] Starting training loop at step 0.
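Most lines in that log are absl-formatted INFO (`I`) entries, plus one WARNING (`W`) and a Python `UserWarning`, so it can help to pull out just the backend-initialization results. This is a minimal stdlib sketch; the regular expressions assume the absl/glog prefix seen in this log (`I1004 19:30:43.271347 <thread> file.py:line]`), and `summarize_backends` is a hypothetical helper name. On these lines it reports that `tpu_driver`, `gpu` and `tpu` failed and that JAX fell back to CPU. The `"cuda"` platform not being registered usually points to a CPU-only jaxlib build.

```python
import re

# absl/glog prefix: severity letter, MMDD, HH:MM:SS.micro, thread id, file:line]
LOG_LINE = re.compile(
    r"^(?P<level>[IWEF])(?P<date>\d{4}) (?P<time>[\d:.]+)\s+(?P<thread>\d+) "
    r"(?P<source>[\w.]+:\d+)\] (?P<message>.*)$"
)
# Message logged by jax's xla_bridge when a backend cannot be created.
BACKEND_FAILURE = re.compile(r"Unable to initialize backend '(?P<backend>\w+)'")


def summarize_backends(lines):
    """Return (failed backend names, whether a CPU fallback was reported)."""
    failed = []
    fell_back_to_cpu = False
    for line in lines:
        match = LOG_LINE.match(line.strip())
        if match is None:
            continue  # skip non-absl lines such as Python UserWarnings
        message = match.group("message")
        failure = BACKEND_FAILURE.search(message)
        if failure is not None:
            failed.append(failure.group("backend"))
        if "falling back to CPU" in message:
            fell_back_to_cpu = True
    return failed, fell_back_to_cpu


SAMPLE = """\
I1004 19:30:43.270267 22970125400960 tpu_client.py:54] Starting the local TPU driver.
I1004 19:30:43.271347 22970125400960 xla_bridge.py:231] Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
I1004 19:30:43.271621 22970125400960 xla_bridge.py:231] Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
I1004 19:30:43.271768 22970125400960 xla_bridge.py:231] Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
W1004 19:30:43.271850 22970125400960 xla_bridge.py:234] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
"""

failed, cpu_fallback = summarize_backends(SAMPLE.splitlines())
print("failed backends:", failed)
print("fell back to CPU:", cpu_fallback)
```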