borisdayma / dalle-mini

DALL·E Mini - Generate images from a text prompt
https://www.craiyon.com
Apache License 2.0

Model requirements for mega #194

Status: Open. melih-unsal opened this issue 2 years ago.

melih-unsal commented 2 years ago

Hello, what is the memory requirement for mega-1:latest? I have 32 GB of RAM and 16 GB of GPU memory, but loading the model fails with the error below. Do I need more CPU RAM or more GPU memory?

Traceback (most recent call last):
  File "dalle_server1.py", line 55, in <module>
    model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/dalle_mini/model/utils.py", line 26, in from_pretrained
    pretrained_model_name_or_path, *model_args, **kwargs
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/transformers/modeling_flax_utils.py", line 621, in from_pretrained
    state = jax.tree_util.tree_map(jnp.array, state)
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/tree_util.py", line 184, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/tree_util.py", line 184, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py", line 1865, in array
    out = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/lax/lax.py", line 522, in _convert_element_type
    weak_type=new_weak_type)
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/core.py", line 323, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/core.py", line 326, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/core.py", line 675, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/dispatch.py", line 100, in apply_primitive
    return compiled_fun(*args)
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/dispatch.py", line 151, in <lambda>
    return lambda *args, **kw: compiled(*args, **kw)[0]
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/dispatch.py", line 656, in _execute_trivial
    else h(device_put(x, device)) for h, x in zip(handlers, outs)]
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/dispatch.py", line 656, in <listcomp>
    else h(device_put(x, device)) for h, x in zip(handlers, outs)]
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/dispatch.py", line 871, in device_put
    return device_put_handlers[type(x)](x, device)
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/dispatch.py", line 882, in _device_put_array
    return (backend.buffer_from_pyval(x, device),)
RuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 402653184 bytes.
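The request that failed is 402653184 bytes, which is exactly 384 MiB. That is a single buffer, far smaller than either 16 GB of VRAM or 32 GB of RAM, which suggests memory was already largely consumed (or fragmented) by earlier allocations rather than one tensor being too large on its own. As a rough sketch of how the weight footprint scales, one copy of the parameters needs (parameter count) × (bytes per element). The helper names and the one-billion parameter count below are hypothetical, chosen for illustration; they are not mega-1's actual size.

```python
# Bytes per element for dtypes commonly used for model weights.
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}


def to_mib(num_bytes: int) -> float:
    """Convert a byte count (e.g. from an XLA OOM message) to MiB."""
    return num_bytes / 2**20


def weights_gib(num_params: int, dtype: str = "float32") -> float:
    """Lower bound on memory for ONE copy of the weights, in GiB.

    Ignores activations, XLA workspace, and any extra host-side copies
    made while a checkpoint is being deserialized and transferred.
    """
    return num_params * BYTES_PER_ELEMENT[dtype] / 2**30


# The allocation that failed in the traceback.
failed_request_mib = to_mib(402653184)  # 384.0 MiB

# Hypothetical parameter count, for illustration only.
example_params = 1_000_000_000
print(f"failed request: {failed_request_mib:.0f} MiB")
print(f"fp32 weights:   {weights_gib(example_params):.2f} GiB")
print(f"fp16 weights:   {weights_gib(example_params, 'float16'):.2f} GiB")
```

Halving the bytes per element (fp16/bf16 instead of fp32) halves this lower bound, but loading can briefly hold several copies of the weights, so the real peak is higher.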

drdaxxy commented 2 years ago

On my system (Windows, GeForce RTX 3090, CUDA 11.3, cuDNN 8.4.0, JAX built from source at d43cb36dae), measuring minimal fp32 mega-1 VRAM consumption with XLA_PYTHON_CLIENT_ALLOCATOR=platform:

(Note: XLA_PYTHON_CLIENT_ALLOCATOR=platform can slow things down, though I haven't encountered that. Also, because of fragmentation, it is counterproductive when the goal is to fit a single application into memory rather than to leave room for other processes.)
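A minimal sketch of applying this setting from inside a script. XLA_PYTHON_CLIENT_ALLOCATOR and XLA_PYTHON_CLIENT_PREALLOCATE are real JAX environment variables, but they only take effect if they are set before JAX initializes its GPU backend, so the safest place is before `import jax`. The helper name `configure_xla_memory` is hypothetical.

```python
import os


def configure_xla_memory(platform_allocator: bool = True, preallocate: bool = False) -> dict:
    """Set JAX's GPU memory environment variables (hypothetical helper).

    Call this before `import jax`: JAX reads these variables when its GPU
    backend is initialized, so changing them afterwards has no effect.
    """
    settings = {
        # "false" stops JAX from reserving a large share of GPU memory up front.
        "XLA_PYTHON_CLIENT_PREALLOCATE": "true" if preallocate else "false",
    }
    if platform_allocator:
        # Allocate and free on demand with no caching pool: this reflects
        # real usage, but may be slower and can fragment GPU memory.
        settings["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
    os.environ.update(settings)
    return settings
```

Usage: call `configure_xla_memory()` at the very top of the entry script (for example dalle_server1.py), and only then `import jax` and load the model. Setting the variables in the shell before launching Python works just as well.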

So unless I missed short-term peaks, your card should have more than enough.

Host memory is the bigger problem, as XLA pins GPU buffers in it. I've seen active usage peak above 20 GiB, with more paged out. I also have 32 GiB of RAM, of which Windows allows half to be shared with the GPU. Sending mega-1's params to the device (params = replicate(params) in the notebook) produces lots of allocation failures but eventually succeeds.

Try increasing system swap, upgrading your libraries, and keeping as much RAM free as possible. Are you loading any other models before mega-1 (perhaps other dalle-mini versions) without restarting?
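Since the original traceback comes from an Ubuntu machine, one quick way to check how much RAM and swap are actually free before loading mega-1 is to read /proc/meminfo. The sketch below is Linux-only and its function names are hypothetical. It reports free RAM and free swap separately, because pinned (page-locked) buffers cannot be swapped out; extra swap only helps by paging out other memory.

```python
def parse_meminfo(text: str) -> dict:
    """Parse Linux /proc/meminfo text into {field_name: value}.

    Fields reported in kB (really KiB, 1024 bytes) are converted to bytes;
    unitless fields such as HugePages_Total are kept as plain counts.
    """
    info = {}
    for line in text.splitlines():
        key, sep, rest = line.partition(":")
        fields = rest.split()
        if not sep or not fields:
            continue  # skip blank or malformed lines
        value = int(fields[0])
        if len(fields) > 1 and fields[1] == "kB":
            value *= 1024
        info[key.strip()] = value
    return info


def host_memory_gib(text: str) -> dict:
    """Report available RAM and free swap in GiB, kept separate on purpose."""
    info = parse_meminfo(text)
    return {
        "MemAvailable": info.get("MemAvailable", 0) / 2**30,
        "SwapFree": info.get("SwapFree", 0) / 2**30,
    }
```

Usage: `with open("/proc/meminfo") as f: print(host_memory_gib(f.read()))`, run right before the `from_pretrained` call, shows whether other processes or previously loaded models are holding RAM that mega-1 would need.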