kuprel / min-dalle

min(DALL·E) is a fast, minimal port of DALL·E Mini to PyTorch
MIT License

Incorrect dtypes error with the Mega model #2

Closed BenEaston closed 2 years ago

BenEaston commented 2 years ago

Hey, I'm seeing the following error when passing the '--mega' option to use the mega model:

(min-dalle) ➜  min-dalle git:(main) ✗ python image_from_text.py --text="a comfy chair that looks like an avocado" --mega

/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lib/__init__.py:34: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
Namespace(image_path='generated', image_token_count=256, mega=True, seed=0, text='a comfy chair that looks like an avocado', torch=False)
parsing metadata from ./pretrained/dalle_bart_mega
tokenizing text
['Ġa']
['Ġcomfy']
['Ġchair']
['Ġthat']
['Ġlooks']
['Ġlike']
['Ġan']
['Ġavocado']
text tokens [0, 58, 29872, 2408, 766, 4126, 1572, 101, 16632, 2]
loading flax encoder
encoding text tokens
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
loading flax decoder
sampling image tokens
Traceback (most recent call last):
  File "image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
    image_tokens[...] = generate_image_tokens_flax(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
    image_tokens = decode_flax(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
    image_tokens = decoder.sample_image_tokens(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 254, in sample_image_tokens
    _, image_tokens = lax.scan(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 1498, in scan
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 1484, in _create_jaxpr
    jaxpr, consts, out_tree = _initial_style_jaxpr(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/util.py", line 219, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/util.py", line 212, in cached
    return f(*args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 82, in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/util.py", line 219, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/util.py", line 212, in cached
    return f(*args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 76, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1828, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 213, in sample_next_image_token
    logits, keys_state, values_state = self.apply(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 1159, in apply
    return apply(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/scope.py", line 831, in wrapper
    y = fn(root, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 1535, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
    decoder_state, (keys_state, values_state) = self.layers(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 314, in wrapped_fn
    ret = trafo_fn(module_scopes, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/lift.py", line 218, in wrapper
    y, out_variable_groups_xs_t = fn(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/lift.py", line 770, in inner
    broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 616, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/axes_scan.py", line 114, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/lift.py", line 754, in scanned
    c, y = fn(scope, c, *args)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 307, in core_fn
    res = fn(cloned, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
    decoder_state, keys_values_state = self.self_attn(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
    keys_state = lax.dynamic_update_slice(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/slicing.py", line 147, in dynamic_update_slice
    return dynamic_update_slice_p.bind(operand, update, *start_indices)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/core.py", line 323, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/core.py", line 326, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 166, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 177, in default_process_primitive
    out_aval, effects = primitive.abstract_eval(*avals, **params)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/core.py", line 359, in abstract_eval_
    return abstract_eval(*args, **kwargs), no_effects
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/utils.py", line 67, in standard_abstract_eval
    dtype_rule(*avals, **kwargs), weak_type=weak_type,
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/slicing.py", line 933, in _dynamic_update_slice_dtype_rule
    lax._check_same_dtypes("dynamic_update_slice", False, operand.dtype,
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 4373, in _check_same_dtypes
    raise TypeError(msg.format(name, ", ".join(map(str, types))))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
    image_tokens[...] = generate_image_tokens_flax(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
    image_tokens = decode_flax(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
    image_tokens = decoder.sample_image_tokens(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 254, in sample_image_tokens
    _, image_tokens = lax.scan(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 213, in sample_next_image_token
    logits, keys_state, values_state = self.apply(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
    decoder_state, (keys_state, values_state) = self.layers(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/axes_scan.py", line 114, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
    decoder_state, keys_values_state = self.self_attn(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
    keys_state = lax.dynamic_update_slice(
TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.

toabi commented 2 years ago

Changing the requirements.txt to have flax==0.4.2 seems to fix this issue.

Edit: Running on an M1 Pro in a Python 3.9.13 venv.
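Since the fix above depends on which flax version is actually active in the virtualenv, a quick check can help before rerunning. This is a minimal sketch using only the standard library; the helper name `installed_version` is made up for illustration, and the package list is an assumption about what is relevant here.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Report the versions visible to the current interpreter (i.e. the active venv).
for name in ("flax", "jax", "jaxlib"):
    print(name, installed_version(name) or "not installed")
```

If flax reports something other than 0.4.2 after editing requirements.txt, the environment probably was not reinstalled from the updated file.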

adamryman commented 2 years ago

Just a note: I had the same issue, and I resolved it by doing the following

Then generate the image using torch

 python image_from_text.py --text="a comfy chair that looks like an avocado" --seed=4 --mega --torch

--

Edit: toabi's solution above also works for me without the --torch flag, though it is very slow for me, as it does not seem to be able to find my GPU without the --torch flag.

Ahhh, I was able to get the GPU working with jax by following their README: https://github.com/google/jax#pip-installation-gpu-cuda Turns out I did not have cuDNN installed. I then ran into this error, though: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 201326592 bytes. I'm guessing my computer is not beefy enough.
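For scale, the failed allocation in that error is fairly small, which can be checked with a little arithmetic. The sketch below also sets `XLA_PYTHON_CLIENT_PREALLOCATE`, an environment variable that JAX documents for turning off GPU memory preallocation; whether that helps in this particular case is an assumption, not something confirmed in this thread. The helper `bytes_to_mib` is made up for illustration.

```python
import os


def bytes_to_mib(n_bytes: int) -> float:
    """Convert a byte count to mebibytes (1 MiB = 1024 * 1024 bytes)."""
    return n_bytes / (1024 ** 2)


failed_allocation = 201326592  # byte count quoted in the RESOURCE_EXHAUSTED error
print(f"{bytes_to_mib(failed_allocation):.0f} MiB")  # prints "192 MiB"

# Assumption: disabling preallocation may leave more headroom on a small GPU.
# The variable has to be set before `import jax` for it to take effect.
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
```

A 192 MiB request failing suggests the GPU memory was already nearly full when the request was made, rather than that this single buffer was too large.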

Cybergate9 commented 2 years ago

Once flax_model.msgpack was fixed, --torch --mega together worked, whereas --mega by itself didn't. Tried installing flax 0.4.2, and that worked.

so for me:

(Macbook i7, OSX 12.4)

joes commented 2 years ago

Installed 0.4.2 flax

Tried

$ python3 image_from_text.py --text='a comfy chair' --mega --seed=4

but got this error:

msgpack.exceptions.ExtraData: unpack(b) received extra data.

Maybe the problem is that I am running an Intel iMac?

rklasen commented 2 years ago

Installed 0.4.2 flax

Tried

$ python3 image_from_text.py --text='a comfy chair' --mega --seed=4

but got this error:

msgpack.exceptions.ExtraData: unpack(b) received extra data.

Maybe the problem is that I am running an Intel iMac?

This can be fixed with https://github.com/kuprel/min-dalle/issues/1#issuecomment-1168242079

kuprel commented 2 years ago

Thanks for catching this. Yeah, it was an issue with inconsistent dtypes that the latest flax version picked up on but the older version didn't. The latest commit should work properly with flax 0.5.2 now.
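For anyone hitting the same TypeError elsewhere, here is a minimal sketch of the kind of mismatch involved and one way to avoid it. It is not the actual change made in the repository; the array names, shapes, and the choice to cast the update to the operand's dtype are all assumptions for illustration.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-ins for a float32 cache and a float16 update (shapes made up).
keys_state = jnp.zeros((4, 8), dtype=jnp.float32)
new_keys = jnp.ones((1, 8), dtype=jnp.float16)

# Passing mismatched dtypes is what the traceback above complains about.
try:
    lax.dynamic_update_slice(keys_state, new_keys, (2, 0))
    mismatch_rejected = False
except TypeError as err:
    mismatch_rejected = True
    print(err)

# Casting the update to the operand's dtype makes both arguments agree.
updated = lax.dynamic_update_slice(
    keys_state, new_keys.astype(keys_state.dtype), (2, 0)
)
print(updated.dtype)
```

Casting in the other direction (the cache to float16) would also satisfy the dtype check; which one is appropriate depends on the precision the model is meant to run in.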