**Closed** — danielchalef closed this issue 2 years ago.
The model loads after modifying how JAX allocates memory. See here: https://github.com/borisdayma/dalle-mini/issues/185#issuecomment-1145961653
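For anyone landing here: the usual way to change how JAX's XLA backend allocates GPU memory (what the linked comment describes) is environment variables set *before* `jax` is first imported. A minimal sketch — the fraction value is just an example, tune it for your card:

```python
import os

# XLA reads these at backend initialisation, so they must be set
# BEFORE the first `import jax` anywhere in the process.

# Don't grab ~90% of VRAM up front; allocate on demand instead.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Alternatively, keep preallocation but cap it (pick one approach):
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.75"

# import jax  # only after the variables above are set
```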
The model loader in 🤗 Transformers wastes a lot of VRAM there; passing `_do_init=False` is what prevents that. As I said in #212 just now, that option will work in newer versions of Transformers.
(As you're linking my comment: I should mention I haven't actually tried running this on a 12GB GPU. I have, however, seen VRAM use stay safely within the limits I described (subtracting graphics system overhead) with `_do_init=False`, but peak higher during loading when that argument is removed.)
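For anyone skimming, the loading pattern being discussed looks roughly like this — a sketch based on the inference notebook, not the exact code; the model name is a placeholder, and it assumes dalle-mini plus a recent enough 🤗 Transformers:

```python
def load_dalle(model_name="dalle-mini/mega-1-fp16:latest"):
    """Load DALL·E Mega with reduced peak VRAM during loading.

    model_name is a placeholder for whatever artifact/revision you use.
    """
    import jax.numpy as jnp
    from dalle_mini import DalleBart

    # With _do_init=False the weights come back as a separate `params`
    # pytree instead of being materialised (and duplicated) inside the
    # model object, which keeps peak VRAM during loading much lower.
    model, params = DalleBart.from_pretrained(
        model_name, dtype=jnp.float16, _do_init=False
    )
    return model, params
```

Dropping `_do_init=False` returns a fully initialised model instead, which is where the loading-time VRAM spike comes from.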
Release: 0.1.0 / https://github.com/borisdayma/dalle-mini/commit/00d389bfa5586fde0a51e250f7ec3757bb7e704c
According to this, I should be able to load `mega-1-fp16` on an RTX 2080 Ti with 12GB VRAM. There are no other processes running on the GPU and I have only one device on this workstation.

Running notebook `tools/inference/inference_pipeline.ipynb` with `_do_init` commented out intentionally (see https://github.com/borisdayma/dalle-mini/issues/212) OOMs with the errors below.
I have tried the following prior to importing `jax`, without success: