Description
Consider the following code snippet:
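(The original snippet isn't reproduced here; below is a minimal sketch of the same setup with assumed shapes and names — a float32 weight matrix of roughly 32 GB, initialized from a normal distribution under `jit`. The sharded and all-zeros variants referred to below are shown as separate sketches rather than by their original line numbers.)

```python
import jax
import jax.numpy as jnp

# ~8.6e9 float32 parameters -> roughly 32 GB for the weight matrix alone.
IN_FEATURES, OUT_FEATURES = 65_536, 131_072

def init_dense(key):
    # Random-normal initialization of the dense layer's weight matrix.
    return jax.random.normal(key, (IN_FEATURES, OUT_FEATURES), dtype=jnp.float32)

# Initializing on a single 80 GB device: the weight itself fits, but the
# intermediate RNG buffers push total usage well past device memory.
params = jax.jit(init_dense)(jax.random.PRNGKey(0))
```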
Here I'm trying to initialize a very large dense layer. Despite the layer weights requiring only 32 GB of RAM (and I'm running on an 80 GB H100), this code fails because JAX tries to simultaneously allocate quite a few RNG buffers, so the total memory consumption reaches 112 GB!
Is this intended? Do we really need to store these buffers in memory simultaneously to initialize the layer?
In any case, we can try to fix the problem by sharding the layer over the available devices (8x 80 GB H100s; comment out line 66 and uncomment line 67 in the code above). Interestingly, while this change reduces the size of the parameter tensor as intended, the RNG buffers are still allocated in full!
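Continuing the sketch above, the sharded variant would look roughly like this (mesh and axis names are assumptions), splitting the output-feature dimension of the weight across the 8 devices via `out_shardings`:

```python
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over the 8 local devices; shard the output-feature dimension.
mesh = Mesh(jax.devices(), axis_names=("model",))
sharded = NamedSharding(mesh, P(None, "model"))

# Each device now only needs a ~4 GB slice of the weight matrix,
# yet the RNG buffers are still allocated at full size on every device.
params = jax.jit(init_dense, out_shardings=sharded)(jax.random.PRNGKey(0))
```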
This seems to be a bug: why is JAX trying to materialize the full RNG tensor on each shard if it's not needed in full there?
Finally, if I use all-zeros initialization (uncomment line 22 in the code above), the issue goes away.
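The all-zeros variant, which does not hit the problem, would correspondingly be something like:

```python
def init_dense_zeros(key):
    del key  # unused: no RNG is involved, and the OOM goes away
    return jnp.zeros((IN_FEATURES, OUT_FEATURES), dtype=jnp.float32)

params = jax.jit(init_dense_zeros, out_shardings=sharded)(jax.random.PRNGKey(0))
```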
So, to summarize, I have the following questions:

1. Is it intended that all of these RNG buffers are allocated in memory simultaneously during initialization, and is that actually necessary?
2. When the parameters are sharded, why does each device still materialize the full RNG tensor even though it isn't needed in full there?
The example above, while artificial, is inspired by a real problem that we've encountered while trying to initialize a large model.
System info (python version, jaxlib version, accelerator, etc.)