DavidHiggis opened 2 years ago
I'm trying to solve it too, but on Kaggle (https://github.com/borisdayma/dalle-mini/issues/204). The replicate function is the culprit: I tried loading the mega fp16 model first and then calling replicate (before loading the VQGAN model), but it still OOMs.
If this code works, you could put it in a pull request and have it merged. I, for one, would love lower VRAM usage, especially since my card is a gigabyte short of the twelve required to run mega_full.
Or just mindlessly copypasta?
I tried to solve the OOM problem and found that the replicate(...) / flax.jax_utils.replicate calls are the major cause. That replicate call literally doubles RAM usage and makes mega-fp16 crash on a 12 GB Colab VM.
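To see why replicate doubles memory, here is a minimal sketch (not the actual flax.jax_utils.replicate implementation, just a hypothetical numpy stand-in for what it does conceptually): it stacks a copy of every parameter array along a new leading device axis, so even with a single device the replicated tree is a second full allocation alongside the original.

```python
import numpy as np

def replicate(params, n_devices):
    # Hypothetical stand-in: stack each parameter array n_devices times
    # along a new leading axis, like flax.jax_utils.replicate does for
    # pmap-style per-device parameter trees.
    return {k: np.stack([v] * n_devices) for k, v in params.items()}

params = {"w": np.zeros((1000, 1000), dtype=np.float32)}  # ~4 MB array
replicated = replicate(params, 1)

# Even on a single device, the replicated copy is a full second
# allocation, so peak memory roughly doubles while both trees are live.
assert replicated["w"].shape == (1, 1000, 1000)
assert replicated["w"].nbytes == params["w"].nbytes
```

For a multi-gigabyte model like mega-fp16, holding the original and the replicated copy at the same time is exactly the kind of spike that kills a 12 GB VM.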
What I know so far:
The code above works. Following the 'should be (?,256), not (1,?,256)' rule above, I removed the prefix above p_generate and modified tokenize_prompt(self, prompt: str) to:
and generate_images to:
But no luck so far.
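The '(?,256), not (1,?,256)' rule boils down to a shape change, sketched here with placeholder numpy arrays (the real tokens come from the tokenizer, not shown): pmap expects inputs sharded with a leading device axis, so one device sees (1, batch, 256), while a plain single-device call wants (batch, 256) and the leading axis has to be dropped.

```python
import numpy as np

# Hypothetical pmap-style sharded token batch: (devices, batch, seq_len).
tokens = np.zeros((1, 4, 256), dtype=np.int32)

# For a single-device, non-pmap call, drop the device axis so the model
# sees the plain (batch, seq_len) shape the rule above asks for.
tokens_single = np.squeeze(tokens, axis=0)
assert tokens_single.shape == (4, 256)
```

The same de-sharding has to happen for every pmapped input (tokens, RNG keys, condition scales), which is why removing the prefix from one function alone may still fail.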
If the maker of this actually understands his code (guess not, otherwise the line 'for i in range(max(num_predictions // jax.device_count(), 1))' would have been optimized away, since app.py only ever calls 'dalle_model.generate_images("warm-up", 1)'), please tell us how to remove this pmap/batch machinery, which was made for cluster VMs and not for a standalone-VM environment like the free-tier Google Colab GPU.
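For what it's worth, the loop quoted above degenerates on a single-device VM. A small sketch of just that arithmetic (pure Python, no JAX needed; loop_count is a hypothetical helper name):

```python
def loop_count(num_predictions, device_count):
    # Mirrors the quoted expression:
    # range(max(num_predictions // jax.device_count(), 1))
    return max(num_predictions // device_count, 1)

# Single-device Colab: the per-device batching buys nothing, the loop
# just runs once per prediction.
assert loop_count(1, 1) == 1
assert loop_count(8, 1) == 8

# The cluster case the code was presumably written for: eight devices
# each produce one prediction in a single pmapped step.
assert loop_count(8, 8) == 1
```

So on device_count == 1 hardware the pmap/replicate overhead is pure cost: the sharding axis is always 1 and the loop does all the work anyway.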