JeffSHF / ColabDock

Code for ColabDock paper

jaxlib.xla_extension.XlaRuntimeError for large systems #9

Open aravinda1879 opened 1 year ago

aravinda1879 commented 1 year ago

Hi. While the code works perfectly, for larger systems I keep getting the following error:

    jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 81.72GiB (87742154752B) on device ordinal 0

I do have access to ~300 GB. I even tried adding the following to main.py to work around this, and still get the same error:

    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".70"
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

I tried setting 'crop_len' to values from 100 to 300, and still get the same error. Any thoughts or solutions? Thanks!
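One thing worth checking with the settings above: the `XLA_PYTHON_CLIENT_*` variables are normally expected to be set before `jax` is imported, so their position in main.py matters. Below is a minimal sketch of that ordering. The helper name `configure_xla_memory` is hypothetical and not part of ColabDock. The values are the ones quoted in this comment.

```python
import os

# Allocator settings quoted in the issue. The helper below is a hypothetical
# illustration, not ColabDock code.
XLA_MEMORY_SETTINGS = {
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    "XLA_PYTHON_CLIENT_MEM_FRACTION": ".70",
    "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",
}


def configure_xla_memory(settings=XLA_MEMORY_SETTINGS):
    """Write the settings into the process environment and return them."""
    for name, value in settings.items():
        os.environ[name] = value
    return {name: os.environ[name] for name in settings}


# Call this at the very top of the entry script, before anything imports jax.
applied = configure_xla_memory()
print(applied)

# import jax  # only after the environment has been configured
```

These settings change how JAX reserves device memory, not how much memory the device has, so they cannot help if a single allocation request is larger than the device itself.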

JeffSHF commented 1 year ago

Hi there. How many residues are there in your complex? Since you are using segment-based optimization, I think the OOM error occurs during the AF2 inference.
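As a rough illustration of why changing 'crop_len' may not help if the failure is in inference: inference would then see the whole complex, so memory grows with the full length, not the crop length. The back-of-envelope sketch below assumes a single float32 pair representation of shape (L, L, 128), where 128 is the AlphaFold2 pair channel width. The function name is hypothetical, and the result is a lower bound for one tensor, not a prediction of peak memory, which also depends on attention intermediates and chunking.

```python
def pair_representation_bytes(num_residues, channels=128, bytes_per_value=4):
    """Size of one (L, L, channels) pair tensor; assumes float32 by default."""
    return num_residues * num_residues * channels * bytes_per_value


# Compare a 300-residue crop with a 2200-residue full complex.
crop_bytes = pair_representation_bytes(300)
full_bytes = pair_representation_bytes(2200)

print(f"crop of 300 residues:  {crop_bytes / 2**20:.1f} MiB")
print(f"full 2200 residues:    {full_bytes / 2**30:.2f} GiB")
print(f"ratio full / crop:     {full_bytes / crop_bytes:.1f}x")
```

Because the cost is at least quadratic in length, reducing the total number of residues has a much larger effect than any setting that only changes the crop used during optimization.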

aravinda1879 commented 1 year ago

Yes, you are correct. The full error is below. My complex has 2200 residues. I was trying a small antibody (Ab) with an antigen to see its predictions. I could trim the antigen down a little more, but I wonder if there is an alternative? BTW, thanks for sharing the code for testing; it seems promising!

flat_sizes = jax.tree_flatten(in_sizes)[0]

0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:11<?, ?it/s]
Traceback (most recent call last):
  File "/software/ColabDock/main.py", line 158, in <module>
    dock_model.dock_rank()
  File "/software/ColabDock/colabdock/model.py", line 109, in dock_rank
    self.inference()
  File "/software/ColabDock/colabdock/docking.py", line 81, in inference
    af_model.gen_infer(save_path)
  File "/software/ColabDock/colabdesign/af/model.py", line 142, in gen_infer
    i_outputs = self._runner.apply(self._model_params[0], self.key(), self._inputs)
  File "/software/.conda/envs/mutGen/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/software/.conda/envs/mutGen/lib/python3.9/site-packages/jax/_src/api.py", line 620, in cache_miss
    execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
  File "/software/.conda/envs/mutGen/lib/python3.9/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl_lazy
    return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
  File "/software/.conda/envs/mutGen/lib/python3.9/site-packages/jax/linear_util.py", line 300, in memoized_fun
    ans = call(fun, *args)
  File "/software/.conda/envs/mutGen/lib/python3.9/site-packages/jax/_src/dispatch.py", line 359, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/software/.conda/envs/mutGen/lib/python3.9/site-packages/jax/_src/dispatch.py", line 996, in compile
    self._executable = XlaCompiledComputation.from_xla_computation(
  File "/software/.conda/envs/mutGen/lib/python3.9/site-packages/jax/_src/dispatch.py", line 1194, in from_xla_computation
    compiled = compile_or_get_cached(backend, xla_computation, options,
  File "/software/.conda/envs/mutGen/lib/python3.9/site-packages/jax/_src/dispatch.py", line 1077, in compile_or_get_cached
    return backend_compile(backend, serialized_computation, compile_options,
  File "/software/.conda/envs/mutGen/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/software/.conda/envs/mutGen/lib/python3.9/site-packages/jax/_src/dispatch.py", line 1012, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 19.02GiB (20422784512B) on device ordinal 0
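To compare the size of a failed request against the memory actually available on the device, the byte count can be pulled out of the RESOURCE_EXHAUSTED message. The sketch below uses only the standard library and the two messages quoted in this thread. The helper names and the regular expression are illustrative assumptions about the message format, not part of JAX or ColabDock.

```python
import re

# Assumed message shape, based only on the two errors quoted in this thread.
_ALLOC_RE = re.compile(r"Failed to allocate request for [\d.]+\s*GiB \((\d+)B\)")


def requested_bytes(message):
    """Return the requested allocation in bytes, or None if not found."""
    match = _ALLOC_RE.search(message)
    return int(match.group(1)) if match else None


def requested_gib(message):
    """Return the requested allocation in GiB, or None if not found."""
    size = requested_bytes(message)
    return None if size is None else size / 2**30


first_error = (
    "RESOURCE_EXHAUSTED: Failed to allocate request for 81.72GiB "
    "(87742154752B) on device ordinal 0"
)
second_error = (
    "RESOURCE_EXHAUSTED: Failed to allocate request for 19.02GiB "
    "(20422784512B) on device ordinal 0"
)

for text in (first_error, second_error):
    print(f"{requested_bytes(text)} B = {requested_gib(text):.2f} GiB")
```

If the parsed request is larger than the total memory of the GPU at ordinal 0, no allocator setting can satisfy it, and the problem size itself has to shrink.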

JeffSHF commented 1 year ago

Currently, I think your proposed solution (trimming the antigen) is the only way. Feel free to post comments if you have further questions.