borisdayma / dalle-mini

DALL·E Mini - Generate images from a text prompt
https://www.craiyon.com
Apache License 2.0

CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES on a 3090 #285

Open neilbGH opened 2 years ago

neilbGH commented 2 years ago

I'm getting the following error when trying to run dalle-mini in Docker on Windows with an RTX 3090 and 32 GB of RAM. nvidia-smi shows the GPU is detected, and I've run other GPU Docker images without issue.

wandb: Downloading large artifact mini-1:v0, 1673.43MB. 7 files... Done. 0:4:29.0
2022-07-10 13:27:22.264211: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2129] Execution of replica 0 failed: INTERNAL: Failed to launch CUDA kernel: shift_right_logical_3 with block dimensions: 1x1x1 and grid dimensions: 1x1x1: CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES: too many resources requested for launch
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Input In [5], in <cell line: 7>()
      4 from transformers import CLIPProcessor, FlaxCLIPModel
      6 # Load dalle-mini
----> 7 model, params = DalleBart.from_pretrained(
      8     DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False
      9 )
     11 # Load VQGAN
     12 vqgan, vqgan_params = VQModel.from_pretrained(
     13     VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False
     14 )

File /usr/local/lib/python3.8/dist-packages/dalle_mini/model/utils.py:25, in PretrainedFromWandbMixin.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
     22         artifact = wandb.Api().artifact(pretrained_model_name_or_path)
     23     pretrained_model_name_or_path = artifact.download(tmp_dir)
---> 25 return super(PretrainedFromWandbMixin, cls).from_pretrained(
     26     pretrained_model_name_or_path, *model_args, **kwargs
     27 )

File /usr/local/lib/python3.8/dist-packages/transformers/modeling_flax_utils.py:596, in FlaxPreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, dtype, *model_args, **kwargs)
    593     resolved_archive_file = None
    595 # init random models
--> 596 model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
    598 if from_pt:
    599     state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)

File /usr/local/lib/python3.8/dist-packages/transformers/models/bart/modeling_flax_bart.py:920, in FlaxBartPreTrainedModel.__init__(self, config, input_shape, seed, dtype, _do_init, **kwargs)
    910 def __init__(
    911     self,
    912     config: BartConfig,
   (...)
    917     **kwargs
    918 ):
    919     module = self.module_class(config=config, dtype=dtype, **kwargs)
--> 920     super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

File /usr/local/lib/python3.8/dist-packages/transformers/modeling_flax_utils.py:115, in FlaxPreTrainedModel.__init__(self, config, module, input_shape, seed, dtype, _do_init)
    112 self._module = module
    114 # Those are public as their type is generic to every derived classes.
--> 115 self.key = PRNGKey(seed)
    116 self.dtype = dtype
    117 self.input_shape = input_shape

File /usr/local/lib/python3.8/dist-packages/jax/_src/random.py:125, in PRNGKey(seed)
    111 """Create a pseudo-random number generator (PRNG) key given an integer seed.
    112 
    113 The resulting key carries the default PRNG implementation, as
   (...)
    122 
    123 """
    124 impl = default_prng_impl()
--> 125 key = prng.seed_with_impl(impl, seed)
    126 return _return_prng_keys(True, key)

File /usr/local/lib/python3.8/dist-packages/jax/_src/prng.py:237, in seed_with_impl(impl, seed)
    236 def seed_with_impl(impl: PRNGImpl, seed: int) -> PRNGKeyArray:
--> 237   return PRNGKeyArray(impl, impl.seed(seed))

File /usr/local/lib/python3.8/dist-packages/jax/_src/prng.py:276, in threefry_seed(seed)
    273   raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")
    274 convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
    275 k1 = convert(
--> 276     lax.shift_right_logical(seed_arr, lax_internal._const(seed_arr, 32)))
    277 with jax.numpy_dtype_promotion('standard'):
    278   # TODO(jakevdp): in X64 mode, this can generate 64-bit computations for 32-bit
    279   # inputs. We should avoid this.
    280   k2 = convert(jnp.bitwise_and(seed_arr, np.uint32(0xFFFFFFFF)))

File /usr/local/lib/python3.8/dist-packages/jax/_src/lax/lax.py:444, in shift_right_logical(x, y)
    442 def shift_right_logical(x: Array, y: Array) -> Array:
    443   r"""Elementwise logical right shift: :math:`x \gg y`."""
--> 444   return shift_right_logical_p.bind(x, y)

File /usr/local/lib/python3.8/dist-packages/jax/core.py:327, in Primitive.bind(self, *args, **params)
    324 def bind(self, *args, **params):
    325   assert (not config.jax_enable_checks or
    326           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 327   return self.bind_with_trace(find_top_trace(args), args, params)

File /usr/local/lib/python3.8/dist-packages/jax/core.py:330, in Primitive.bind_with_trace(self, trace, args, params)
    329 def bind_with_trace(self, trace, args, params):
--> 330   out = trace.process_primitive(self, map(trace.full_raise, args), params)
    331   return map(full_lower, out) if self.multiple_results else full_lower(out)

File /usr/local/lib/python3.8/dist-packages/jax/core.py:680, in EvalTrace.process_primitive(self, primitive, tracers, params)
    679 def process_primitive(self, primitive, tracers, params):
--> 680   return primitive.impl(*tracers, **params)

File /usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py:101, in apply_primitive(prim, *args, **params)
     98 """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
     99 compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
    100                                       **params)
--> 101 return compiled_fun(*args)

File /usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py:167, in xla_primitive_callable.<locals>.<lambda>(*args, **kw)
    164 compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
    165                                   prim.name, donated_invars, False, *arg_specs)
    166 if not prim.multiple_results:
--> 167   return lambda *args, **kw: compiled(*args, **kw)[0]
    168 else:
    169   return compiled

File /usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py:717, in _execute_compiled(name, compiled, input_handler, output_buffer_counts, result_handler, has_unordered_effects, ordered_effects, kept_var_idx, *args)
    714 if has_unordered_effects or ordered_effects:
    715   in_flat, token_handler = _add_tokens(has_unordered_effects, ordered_effects,
    716                                        device, in_flat)
--> 717 out_flat = compiled.execute(in_flat)
    718 check_special(name, out_flat)
    719 out_bufs = unflatten(out_flat, output_buffer_counts)

XlaRuntimeError: INTERNAL: Failed to launch CUDA kernel: shift_right_logical_3 with block dimensions: 1x1x1 and grid dimensions: 1x1x1: CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES: too many resources requested for launch
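The traceback shows the failure inside `jax.random.PRNGKey(seed)`, called from the model's `__init__` before any weights are loaded, on a tiny 1x1x1 `shift_right_logical` kernel. The following minimal sketch reproduces that same first GPU operation outside dalle-mini, assuming only that `jax` is importable; the helper name `reproduce_prng_launch` is hypothetical. If it also fails with CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES, the problem most likely lies in the driver, CUDA, or container setup rather than in dalle-mini itself.

```python
def reproduce_prng_launch(seed=0):
    """Hypothetical helper: run the first GPU op seen in the traceback,
    jax.random.PRNGKey(seed), and return (ok, message) instead of raising."""
    try:
        import jax
    except ImportError:
        return False, "jax is not installed"
    try:
        # In the traceback above, PRNGKey's threefry seeding calls
        # lax.shift_right_logical, which is the kernel that failed to launch.
        key = jax.random.PRNGKey(seed)
        key.block_until_ready()  # force execution so any launch error surfaces here
        return True, f"PRNGKey ran on backend {jax.default_backend()!r}"
    except Exception as exc:  # XlaRuntimeError's import path differs across jax versions
        return False, f"{type(exc).__name__}: {exc}"


if __name__ == "__main__":
    ok, message = reproduce_prng_launch()
    print("OK:" if ok else "FAILED:", message)
```

Running this in the same container as the notebook separates "JAX cannot launch any CUDA kernel here" from "something specific to loading the model".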
neilbGH commented 2 years ago

I reinstalled everything, including the video card drivers and Docker, and it's working now.

gordicaleksa commented 2 years ago

@neilbGH are you able to step through the code with your setup?

cxhermagic commented 1 year ago

I'm running this project on Linux. How do I get it to run on the GPU? I get the following message:

WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
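This warning means JAX selected its CPU backend, most likely because the installed jaxlib is a CPU-only build or CUDA could not be initialized. A minimal sketch to confirm which backend JAX picked and which devices it sees, using `jax.default_backend()` and `jax.devices()`; the helper name `describe_jax_backend` is hypothetical.

```python
def describe_jax_backend():
    """Hypothetical helper: return (backend, devices) as seen by JAX,
    or (None, []) if jax is not installed."""
    try:
        import jax
    except ImportError:
        return None, []
    # "cpu" here corresponds to the "No GPU/TPU found, falling back to CPU" warning
    backend = jax.default_backend()
    devices = [f"{d.platform}:{d.id}" for d in jax.devices()]
    return backend, devices


if __name__ == "__main__":
    backend, devices = describe_jax_backend()
    if backend is None:
        print("jax is not installed")
    elif backend == "cpu":
        print("JAX is using the CPU only; jaxlib may lack CUDA support:", devices)
    else:
        print("JAX backend:", backend, devices)
```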

cxhermagic commented 1 year ago

Fortunately, I resolved it with the commands below:

pip install --upgrade pip

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

zhaochunhui-0723 commented 11 months ago

What are the corresponding versions of jax, jaxlib, and CUDA?
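The thread does not give a version matrix, but a first step is to check exactly which jax and jaxlib are installed, since the two must match. A minimal sketch using only the standard library's `importlib.metadata`; the helper name `installed_versions` is hypothetical. CUDA-enabled jaxlib wheels from the release index used above typically carry the CUDA/cuDNN build in a local version suffix (for example `+cuda11.cudnn82`), so the printed jaxlib version may also show which CUDA build is installed.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib")):
    """Hypothetical helper: map each distribution name to its installed
    version string, or None if it is not installed."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```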