Dekermanjian opened 2 weeks ago
I don't suppose you can determine how many threads JAX is using here? Is there a simple repro I can try?
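On Linux, one way to answer this is to count the entries under `/proc/self/task`. Every OS thread of the process has a directory there, including threads started by native code such as XLA's thread pools. Python's `threading.active_count()` only reports threads that the `threading` module knows about. A minimal sketch (the function name `native_thread_count` is illustrative):

```python
import os
import threading


def native_thread_count() -> int:
    """Return the number of OS-level threads in this process (Linux only).

    Every thread, including ones started by native libraries such as
    XLA's thread pools, has a directory under /proc/self/task.
    """
    return len(os.listdir("/proc/self/task"))


if __name__ == "__main__":
    # Run this after importing JAX and executing the model to see the
    # native threads it has started; the two numbers can differ a lot.
    print("threads known to Python:", threading.active_count())
    print("OS threads in this process:", native_thread_count())
```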
Hey @hawkinsp, thank you for the very fast response. After tearing my hair out for two days, I finally figured out what was happening: podman limits the number of PIDs a container can create. I was able to override the limit by setting `pids_limit=0` (`0` removes the limit) in the `[containers]` section of `$HOME/.config/containers/containers.conf`. This fixes the problem.
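A minimal sketch for confirming this kind of limit from inside a container. It assumes a cgroup v2 host (the usual setup for rootless podman) and a Python interpreter in the image. The Linux pids controller counts every task, threads included, so a runtime that starts many native threads per process can reach `pids.max` well before the process count looks high. The helper names below are illustrative, not part of podman or JAX.

```python
import os
from pathlib import Path
from typing import List, Optional, Tuple

CGROUP_ROOT = Path("/sys/fs/cgroup")  # assumed cgroup v2 unified mount point


def current_cgroup_dir() -> Optional[Path]:
    """Locate this process's cgroup v2 directory via /proc/self/cgroup."""
    proc_file = Path("/proc/self/cgroup")
    if not proc_file.is_file():
        return None
    for line in proc_file.read_text().splitlines():
        # cgroup v2 entries look like "0::/some/path"
        hierarchy, _controllers, path = line.split(":", 2)
        if hierarchy == "0":
            return CGROUP_ROOT / path.lstrip("/")
    return None


def pids_limits() -> List[Tuple[Path, str]]:
    """Collect pids.max from the current cgroup up to the visible root."""
    d = current_cgroup_dir()
    limits: List[Tuple[Path, str]] = []
    if d is None:
        return limits
    while True:
        max_file = d / "pids.max"
        if max_file.is_file():
            limits.append((d, max_file.read_text().strip()))
        # Stop at the mount root (or the filesystem root, as a safety net).
        if d == CGROUP_ROOT or d == d.parent:
            break
        d = d.parent
    return limits


def effective_pids_limit() -> Optional[int]:
    """Tightest numeric limit along the path; None means no numeric limit found."""
    numeric = [int(value) for _, value in pids_limits() if value != "max"]
    return min(numeric) if numeric else None


if __name__ == "__main__":
    print("pids.max by cgroup level:", pids_limits())
    print("effective PID/thread limit:", effective_pids_limit())
    print("threads in this process:", len(os.listdir("/proc/self/task")))
```

If the reported limit is close to the thread count observed while the model runs in parallel, raising or removing `pids_limit` as described above is a likely fix.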
Description
I am working on a Linux server where I need to run a NumPyro model in parallel. If I run my model directly on the server, everything works fine. However, when I run it inside a rootless podman container, I get the following error message:
I am using the following environment variables to try to limit JAX's threading:
Since it runs perfectly fine outside the container, I am ruling out the following:
I don't know what else to try, so any help would be greatly appreciated.
System info (Python version, jaxlib version, accelerator, etc.)