melih-unsal opened this issue 2 years ago
On my system (Windows, GeForce RTX 3090, CUDA 11.3, cuDNN 8.4.0, JAX built from source at d43cb36dae), measuring minimal fp32 mega-1 VRAM consumption with `XLA_PYTHON_CLIENT_ALLOCATOR=platform`:

(Note: `XLA_PYTHON_CLIENT_ALLOCATOR=platform` can slow things down, though I haven't encountered that, and fragmentation makes it counterproductive for fitting a single application, as opposed to merely saving space for other ones.)
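For reference, a minimal sketch of how that allocator override has to be applied: the XLA backend reads these variables once, at `import jax` time, so they must be set first (the preallocation flag is an extra, optional assumption here, useful while measuring):

```python
import os

# Both variables are read once, when JAX initializes its backend, so set
# them before `import jax` (or export them in the shell beforehand).
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # allocate on demand
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # skip preallocation

# import jax  # import JAX only after the variables are in place
```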
If you tokenize the prompts for all devices at once (`tokenized_prompt = replicate(processor([prompt] * n))`) instead of processing one per device at a time (like the notebook), VRAM use grows by ~450 MiB per input. `vqgan.decode_code` takes 3 seconds for one image in host memory on a Ryzen 3700X, so decoding on the CPU is an option. So unless I missed short-term peaks, your card should have more than enough VRAM.
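To make the per-batch growth concrete: `replicate` (from `flax.jax_utils`) stacks one copy of its argument per device, so feeding all prompts at once multiplies the live input by the device count. A rough NumPy stand-in with hypothetical sizes (the ~450 MiB/input figure above also includes generation-time activations, which this sketch doesn't model):

```python
import numpy as np

n_devices, n_prompts, seq_len = 8, 8, 64  # hypothetical sizes

# processor(...) yields a (batch, seq_len) array of token ids;
# replicate(...) prepends a device axis, copying the batch per device.
tokens = np.zeros((n_prompts, seq_len), dtype=np.int32)
replicated = np.broadcast_to(tokens, (n_devices,) + tokens.shape).copy()

# Replicating the full batch keeps n_devices times as many ids live as
# feeding one batch per device at a time.
print(replicated.nbytes // tokens.nbytes)  # → 8
```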
Host memory is the bigger problem, as XLA pins GPU buffers to it. I've seen active usage peak above 20 GiB, with more paged out. I also have 32 GiB of RAM, of which Windows allows sharing half with the GPU. Sending mega-1's params to the device (`params = replicate(params)` in the notebook) shows me lots of allocation failures but eventually succeeds.
Try increasing system swap, upgrading libraries, and keeping as much RAM free as possible. Are you loading any other models (perhaps other dalle-mini versions, without restarting, before mega-1)?
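If host RAM stays the bottleneck, one further option (my suggestion, not from the notebook) is halving the params before `replicate(params)`; `transformers`' Flax models expose `to_fp16(params)` for this, and the underlying operation is just a tree-wide cast, sketched here in plain NumPy so the memory arithmetic is visible:

```python
import numpy as np

def to_fp16(tree):
    """Recursively cast float32 leaves of a params dict to float16."""
    if isinstance(tree, dict):
        return {k: to_fp16(v) for k, v in tree.items()}
    return tree.astype(np.float16) if tree.dtype == np.float32 else tree

# Toy params tree; a real mega-1 checkpoint holds billions of such values.
params = {"encoder": {"kernel": np.ones((1024, 1024), dtype=np.float32)}}
half = to_fp16(params)
print(half["encoder"]["kernel"].nbytes)  # → 2097152, down from 4194304
```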
Hello, what is the memory requirement for mega-1:latest? I have 32 GB of RAM and 16 GB of GPU RAM, but it fails with the error below. Do I need more CPU or GPU RAM?
```
Traceback (most recent call last):
  File "dalle_server1.py", line 55, in <module>
    model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/dalle_mini/model/utils.py", line 26, in from_pretrained
    pretrained_model_name_or_path, *model_args, **kwargs
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/transformers/modeling_flax_utils.py", line 621, in from_pretrained
    state = jax.tree_util.tree_map(jnp.array, state)
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/tree_util.py", line 184, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/tree_util.py", line 184, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py", line 1865, in array
    out = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/lax/lax.py", line 522, in _convert_element_type
    weak_type=new_weak_type)
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/core.py", line 323, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/core.py", line 326, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/core.py", line 675, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/dispatch.py", line 100, in apply_primitive
    return compiled_fun(*args)
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/dispatch.py", line 151, in <lambda>
    return lambda *args, **kw: compiled(*args, **kw)[0]
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/dispatch.py", line 656, in _execute_trivial
    else h(device_put(x, device)) for h, x in zip(handlers, outs)]
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/dispatch.py", line 656, in <listcomp>
    else h(device_put(x, device)) for h, x in zip(handlers, outs)]
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/dispatch.py", line 871, in device_put
    return device_put_handlers[type(x)](x, device)
  File "/home/ubuntu/anaconda3/envs/dalle/lib/python3.7/site-packages/jax/_src/dispatch.py", line 882, in _device_put_array
    return (backend.buffer_from_pyval(x, device),)
RuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 402653184 bytes.
```
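For scale, the failed allocation in that last line is a single 384 MiB buffer (the traceback shows it happening while one of the model's larger parameter arrays is copied to the device), not the whole model:

```python
# Size of the failed allocation from the traceback, in MiB.
print(402653184 / 2**20)  # → 384.0
```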