google-deepmind / functa

Apache License 2.0

Uneven GPU usage #11

Closed Separius closed 1 year ago

Separius commented 1 year ago

Hi, thank you for providing your code. I'm quite new to jax, so I'm not sure whether this is normal for jax or specific to your codebase, but I'm seeing strange GPU RAM usage. I have a single node with 4 GPUs (each with 24GB of RAM); the first GPU is using 23895MiB (according to nvidia-smi), while the rest are using 1487MiB each. Is that expected, or is something wrong with my environment? (BTW, I also set XLA_PYTHON_CLIENT_PREALLOCATE to false and it didn't make a difference.)
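One thing worth double-checking (a hedged note, not something from this thread): XLA memory flags like XLA_PYTHON_CLIENT_PREALLOCATE are read when jax initialises its backend, so setting the variable after `import jax` has already run has no effect. A minimal sketch of setting it early from Python:

```python
import os

# This flag must be in the environment before jax is first imported;
# exporting it in the shell before launching the process also works.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Only now import jax, so the backend sees the flag at initialisation.
# import jax
```

Equivalently, `export XLA_PYTHON_CLIENT_PREALLOCATE=false` in the shell before launching the training script.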

hyunjik11 commented 1 year ago

Hi, it looks like you're only using one of the four GPUs. Could you check whether jax.local_device_count() correctly returns 4, or whether it returns 1? If it's 4, then the pmap in experiment_meta_learning.py should be using all 4 devices equally, so I'm not sure why you'd be seeing such asymmetric behaviour. But if it's 1, then this is an issue with jax rather than with the functa codebase. This link might help you: https://github.com/google/jax/issues/5231
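To make the suggestion above concrete, here is a minimal diagnostic sketch (not from the functa codebase): it prints the device count and runs a trivial pmap, whose inputs must have a leading axis equal to the local device count so that each device gets one shard.

```python
import jax
import jax.numpy as jnp

# How many accelerators does jax see on this host? On a healthy
# 4-GPU node this should print 4; 1 means jax only found one device.
n = jax.local_device_count()
print("local devices:", n, jax.local_devices())

# pmap replicates the function across devices and maps over the
# leading axis, so the batch is split evenly between them.
@jax.pmap
def scale(x):
    return x * 2.0

batch = jnp.ones((n, 8))  # one shard of shape (8,) per device
out = scale(batch)
print(out.shape)          # (n, 8): one replica's output per device
```

If `n` is 1 here, the uneven nvidia-smi readings are expected, since all work lands on the single visible device.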