larslorch / avici

Amortized Inference for Causal Structure Learning, NeurIPS 2022
https://arxiv.org/abs/2205.12934
MIT License

Running into OOM on 256 GB RAM #5

Open catchmoosa opened 6 days ago

catchmoosa commented 6 days ago

Traceback (most recent call last):
  File "/teamspace/studios/this_studio/mcgill_fiam/0X-Causal_discovery/discovery.py", line 30, in <module>
    g_prob = model(x=x)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/avici/pretrain.py", line 109, in __call__
    out = onp.array(out)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jax/_src/array.py", line 429, in __array__
    return np.asarray(self._value, dtype=dtype, **kwds)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jax/_src/array.py", line 628, in _value
    self._npy_value = self._single_device_array_to_np_array()
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: Buffer Definition Event: Error preparing computation: %sOut of memory allocating 332034480032 bytes.

This is on 10,000 rows with 51 variables. Can you help me with this issue?

larslorch commented 5 days ago

It's quite possible that 10,000 rows is simply too large for the forward pass. One idea -- though I've never tried it -- could be to split the rows into smaller chunks and create a bootstrapped estimate of the graph by running several forward passes.
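A rough sketch of the chunking idea, untested against AVICI itself: draw several random row subsets, run one forward pass per subset, and average the resulting edge-probability matrices. The helper name bootstrap_edge_probs, its parameters, and the stand-in predictor are all hypothetical; with the real model, predict would presumably be something like lambda xs: model(x=xs), as in the traceback above.

```python
import numpy as np


def bootstrap_edge_probs(predict, x, chunk_size=1000, n_rounds=10, seed=0):
    """Average [d, d] edge-probability matrices over random row subsets of x.

    predict is any callable mapping an [n, d] array to a [d, d] array
    (hypothetical interface; for AVICI this would wrap model(x=...)).
    """
    rng = np.random.default_rng(seed)
    n, d = x.shape
    size = min(chunk_size, n)
    total = np.zeros((d, d))
    for _ in range(n_rounds):
        # Sample rows without replacement so each pass sees a smaller dataset.
        idx = rng.choice(n, size=size, replace=False)
        total += np.asarray(predict(x[idx]))
    return total / n_rounds


# Stand-in predictor so the sketch runs without the real model:
# absolute correlations play the role of edge probabilities here.
def dummy_predict(xs):
    return np.abs(np.corrcoef(xs, rowvar=False))


x = np.random.default_rng(1).normal(size=(2000, 51))
g_prob = bootstrap_edge_probs(dummy_predict, x, chunk_size=500, n_rounds=4)
print(g_prob.shape)
```

Smaller chunk_size lowers the peak memory of each pass, while more rounds reduce the variance of the averaged estimate. Whether the averaged matrix is a good estimate of the graph is an open question here.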

However, your error seems to occur at out = onp.array(out) in avici/pretrain.py (line 109 in your traceback), after the forward pass is already done. Can you confirm this? You could call jax.block_until_ready on the output before this line to check. If that is the case, I don't currently know what the issue could be and would have to investigate. It would be great if you could provide a minimal example that reproduces this with random synthetic data.
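To illustrate the check: JAX dispatches computations asynchronously, so an error raised on the device can first show up when the result is copied to host memory. Blocking on the result before the conversion separates the two steps. The forward function below is only a hypothetical stand-in for the model's forward pass, and the data is random, so this shows the pattern rather than reproducing the OOM.

```python
import jax
import jax.numpy as jnp
import numpy as onp


def forward(x):
    # Hypothetical stand-in for the model's forward pass: maps [n, d] to [d, d].
    return jnp.tanh(x.T @ x / x.shape[0])


def run_and_convert(x):
    out = forward(x)
    # Wait for the device computation to finish. If the forward pass itself
    # fails (for example with an out-of-memory error), it should surface here.
    out = jax.block_until_ready(out)
    # If an error only appears on this line, the forward pass completed and
    # the problem lies in the device-to-host copy.
    return onp.array(out)


rng = onp.random.default_rng(0)
x = rng.normal(size=(200, 5)).astype(onp.float32)
g = run_and_convert(x)
print(g.shape)
```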