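Running `python image_from_text.py --text='a comfy chair' --seed=7` fails while sampling image tokens with a dtype-mismatch TypeError (float16 vs. float32) raised by `lax.dynamic_update_slice`. Full output: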
$ python image_from_text.py --text='a comfy chair' --seed=7
Namespace(mega=False, torch=False, text='a comfy chair', seed=7, image_path='generated', image_token_count=256)
parsing metadata from ./pretrained/dalle_bart_mini
tokenizing text
['Ġa']
['Ġcomfy']
['Ġchair']
text tokens [0, 58, 29872, 2408, 2]
loading flax encoder
encoding text tokens
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
loading flax decoder
sampling image tokens
Traceback (most recent call last):
File "/home/papul/min-dalle/image_from_text.py", line 44, in <module>
image = generate_image_from_text(
File "/home/papul/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
image_tokens[...] = generate_image_tokens_flax(
File "/home/papul/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
image_tokens = decode_flax(
File "/home/papul/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
image_tokens = decoder.sample_image_tokens(
File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 255, in sample_image_tokens
_, image_tokens = lax.scan(
File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/control_flow.py", line 1498, in scan
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/control_flow.py", line 1484, in _create_jaxpr
jaxpr, consts, out_tree = _initial_style_jaxpr(
File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/util.py", line 219, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/util.py", line 212, in cached
return f(*args, **kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/control_flow.py", line 82, in _initial_style_jaxpr
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/util.py", line 219, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/util.py", line 212, in cached
return f(*args, **kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/control_flow.py", line 76, in _initial_style_open_jaxpr
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1828, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/papul/.local/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/papul/.local/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 214, in sample_next_image_token
logits, keys_state, values_state = self.apply(
File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1159, in apply
return apply(
File "/home/papul/.local/lib/python3.10/site-packages/flax/core/scope.py", line 831, in wrapper
y = fn(root, *args, **kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1535, in scope_fn
return fn(module.clone(parent=scope), *args, **kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
decoder_state, (keys_state, values_state) = self.layers(
File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 314, in wrapped_fn
ret = trafo_fn(module_scopes, *args, **kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/flax/core/lift.py", line 218, in wrapper
y, out_variable_groups_xs_t = fn(
File "/home/papul/.local/lib/python3.10/site-packages/flax/core/lift.py", line 770, in inner
broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
File "/home/papul/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
_, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 616, in trace_to_jaxpr_nounits
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/papul/.local/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/papul/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 114, in body_fn
broadcast_out, c, ys = fn(broadcast_in, c, *xs)
File "/home/papul/.local/lib/python3.10/site-packages/flax/core/lift.py", line 754, in scanned
c, y = fn(scope, c, *args)
File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 307, in core_fn
res = fn(cloned, *args, **kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
decoder_state, keys_values_state = self.self_attn(
File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
keys_state = lax.dynamic_update_slice(
File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/slicing.py", line 147, in dynamic_update_slice
return dynamic_update_slice_p.bind(operand, update, *start_indices)
File "/home/papul/.local/lib/python3.10/site-packages/jax/core.py", line 323, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/papul/.local/lib/python3.10/site-packages/jax/core.py", line 326, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/papul/.local/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 166, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File "/home/papul/.local/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 177, in default_process_primitive
out_aval, effects = primitive.abstract_eval(*avals, **params)
File "/home/papul/.local/lib/python3.10/site-packages/jax/core.py", line 359, in abstract_eval_
return abstract_eval(*args, **kwargs), no_effects
File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/utils.py", line 67, in standard_abstract_eval
dtype_rule(*avals, **kwargs), weak_type=weak_type,
File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/slicing.py", line 933, in _dynamic_update_slice_dtype_rule
lax._check_same_dtypes("dynamic_update_slice", False, operand.dtype,
File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 4373, in _check_same_dtypes
raise TypeError(msg.format(name, ", ".join(map(str, types))))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float16, float32.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/papul/min-dalle/image_from_text.py", line 44, in <module>
image = generate_image_from_text(
File "/home/papul/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
image_tokens[...] = generate_image_tokens_flax(
File "/home/papul/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
image_tokens = decode_flax(
File "/home/papul/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
image_tokens = decoder.sample_image_tokens(
File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 255, in sample_image_tokens
_, image_tokens = lax.scan(
File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 214, in sample_next_image_token
logits, keys_state, values_state = self.apply(
File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
decoder_state, (keys_state, values_state) = self.layers(
File "/home/papul/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
_, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
File "/home/papul/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 114, in body_fn
broadcast_out, c, ys = fn(broadcast_in, c, *xs)
File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
decoder_state, keys_values_state = self.self_attn(
File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
keys_state = lax.dynamic_update_slice(
TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float16, float32.
To reproduce, run
python image_from_text.py --text='a comfy chair' --seed=7
which fails with the error shown above.