borisdayma / dalle-mini

DALL·E Mini - Generate images from a text prompt
https://www.craiyon.com
Apache License 2.0

There seems to be an internal error when running the Colab "Inference Pipeline" example. #110

Closed: doyeonyeah closed this issue 2 years ago

doyeonyeah commented 2 years ago

This is the full output I got when running the model-loading code on Colab, after installing all of the prerequisites (transformer version=0.3.5).

AttributeError: 'BartConfig' object has no attribute 'image_vocab_size'. I wasn't able to find any other reports of this error online.


UnfilteredStackTrace                      Traceback (most recent call last)

in ()
      2 tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
----> 3 model = CustomFlaxBartForConditionalGeneration.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)

/usr/local/lib/python3.7/dist-packages/transformers/modeling_flax_utils.py in from_pretrained(cls, pretrained_model_name_or_path, dtype, *model_args, **kwargs)
    350         # init random models
--> 351         model = cls(config, *model_args, **model_kwargs)
    352

/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_flax_bart.py in __init__(self, config, input_shape, seed, dtype, **kwargs)
    928         module = self.module_class(config=config, dtype=dtype, **kwargs)
--> 929         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
    930

/usr/local/lib/python3.7/dist-packages/transformers/modeling_flax_utils.py in __init__(self, config, module, input_shape, seed, dtype)
    105         # randomly initialized parameters
--> 106         random_params = self.init_weights(self.key, input_shape)
    107

/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_flax_bart.py in init_weights(self, rng, input_shape)
    953             position_ids,
--> 954             decoder_position_ids,
    955         )["params"]

/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in init(self, rngs, method, mutable, *args, **kwargs)
   1122         rngs, *args,
-> 1123         method=method, mutable=mutable, **kwargs)
   1124     return v_out

/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in init_with_output(self, rngs, method, mutable, *args, **kwargs)
   1090     return self.apply(
-> 1091         {}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
   1092

/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in apply(self, variables, rngs, method, mutable, capture_intermediates, *args, **kwargs)
   1059         mutable=mutable, capture_intermediates=capture_intermediates
-> 1060     )(variables, *args, **kwargs, rngs=rngs)
   1061

/usr/local/lib/python3.7/dist-packages/flax/core/scope.py in wrapper(variables, rngs, *args, **kwargs)
    690     with bind(variables, rngs=rngs, mutable=mutable).temporary() as root:
--> 691       y = fn(root, *args, **kwargs)
    692     if mutable is not False:

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in scope_fn(scope, *args, **kwargs)
   1311     try:
-> 1312       return fn(module.clone(parent=scope), *args, **kwargs)
   1313     finally:

/usr/local/lib/python3.7/dist-packages/flax/linen/transforms.py in wrapped_fn(self, *args, **kwargs)
    601     if not force and not linen_module._use_named_call:
--> 602       return prewrapped_fn(self, *args, **kwargs)
    603     fn_name = class_fn.__name__

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    317       self, args = args[0], args[1:]
--> 318       return self._call_wrapped_method(fun, args, kwargs)
    319     else:

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in _call_wrapped_method(self, fun, args, kwargs)
    592     else:
--> 593       self._try_setup()
    594

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in _try_setup(self, shallow)
    788         if not shallow:
--> 789           self.setup()
    790     finally:

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    317       self, args = args[0], args[1:]
--> 318       return self._call_wrapped_method(fun, args, kwargs)
    319     else:

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in _call_wrapped_method(self, fun, args, kwargs)
    601     try:
--> 602       y = fun(self, *args, **kwargs)
    603       if _context.capture_stack:

/usr/local/lib/python3.7/dist-packages/dalle_mini/model.py in setup(self)
     50         self.lm_head = nn.Dense(
---> 51             self.config.image_vocab_size + 1,  # encoded image token space + 1 for bos
     52             use_bias=False,

/usr/local/lib/python3.7/dist-packages/transformers/configuration_utils.py in __getattribute__(self, key)
    236             key = super().__getattribute__("attribute_map")[key]
--> 237         return super().__getattribute__(key)
    238

UnfilteredStackTrace: AttributeError: 'BartConfig' object has no attribute 'image_vocab_size'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

AttributeError                            Traceback (most recent call last)

in ()
      1 # set up tokenizer and model
      2 tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
----> 3 model = CustomFlaxBartForConditionalGeneration.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)

/usr/local/lib/python3.7/dist-packages/transformers/modeling_flax_utils.py in from_pretrained(cls, pretrained_model_name_or_path, dtype, *model_args, **kwargs)
    349
    350         # init random models
--> 351         model = cls(config, *model_args, **model_kwargs)
    352
    353         if from_pt:

/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_flax_bart.py in __init__(self, config, input_shape, seed, dtype, **kwargs)
    927     ):
    928         module = self.module_class(config=config, dtype=dtype, **kwargs)
--> 929         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
    930
    931     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:

/usr/local/lib/python3.7/dist-packages/transformers/modeling_flax_utils.py in __init__(self, config, module, input_shape, seed, dtype)
    104
    105         # randomly initialized parameters
--> 106         random_params = self.init_weights(self.key, input_shape)
    107
    108         # save required_params as set

/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_flax_bart.py in init_weights(self, rng, input_shape)
    952             decoder_attention_mask,
    953             position_ids,
--> 954             decoder_position_ids,
    955         )["params"]
    956

/usr/local/lib/python3.7/dist-packages/dalle_mini/model.py in setup(self)
     49         self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
     50         self.lm_head = nn.Dense(
---> 51             self.config.image_vocab_size + 1,  # encoded image token space + 1 for bos
     52             use_bias=False,
     53             kernel_init=jax.nn.initializers.normal(self.config.init_std),

/usr/local/lib/python3.7/dist-packages/transformers/configuration_utils.py in __getattribute__(self, key)
    235         if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
    236             key = super().__getattribute__("attribute_map")[key]
--> 237         return super().__getattribute__(key)
    238
    239     def __init__(self, **kwargs):

AttributeError: 'BartConfig' object has no attribute 'image_vocab_size'

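The failing frame is `setup()` in `dalle_mini/model.py`, which reads `self.config.image_vocab_size` while the configuration that was actually loaded is a plain `BartConfig` with no such field. A minimal sketch of a pre-flight check, assuming only that the config behaves like an ordinary Python object; `missing_config_fields`, `REQUIRED_FIELDS` and the stand-in `plain_bart_config` are hypothetical illustrations, not part of transformers or dalle-mini:

```python
from types import SimpleNamespace


def missing_config_fields(config, required):
    """Return the names in `required` that `config` does not define."""
    return [name for name in required if not hasattr(config, name)]


# Hypothetical stand-in for a plain BART config: it carries ordinary
# text-model fields but none of the image-specific ones.
plain_bart_config = SimpleNamespace(vocab_size=50265, init_std=0.02)

# Fields read by the custom model's setup() in the traceback above.
REQUIRED_FIELDS = ["image_vocab_size", "init_std"]

missing = missing_config_fields(plain_bart_config, REQUIRED_FIELDS)
if missing:
    # Failing early here gives a clearer message than the AttributeError
    # raised deep inside Flax module initialization.
    print(f"config is missing {missing}; the installed code and the "
          "model revision probably do not match")
```

With a real checkpoint you would pass the loaded config object instead of the stand-in, for example the result of `BartConfig.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)`, before constructing the model.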
borisdayma commented 2 years ago

Oh, sorry about that. I'm updating the model. You can either install from an older commit (about a month old) or wait roughly a week for the new model.
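Installing from an older commit usually means pointing pip at a specific Git revision. A minimal sketch, assuming pip's standard `git+https://...@<rev>` VCS requirement syntax; the helper names `pinned_requirement` and `install_pinned` are hypothetical, and the commit is left as a placeholder rather than filled in:

```python
import subprocess
import sys

REPO_URL = "https://github.com/borisdayma/dalle-mini.git"


def pinned_requirement(commit: str) -> str:
    """Build a pip requirement that installs dalle-mini at one exact commit."""
    return f"git+{REPO_URL}@{commit}"


def install_pinned(commit: str) -> None:
    # Use the running interpreter's pip so the package lands in the same
    # environment (e.g. the Colab runtime) that will later import it.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-q", pinned_requirement(commit)]
    )


# Usage (substitute a real commit hash for the placeholder):
# install_pinned("<commit-sha>")
```

Pinning the code this way does not pin the model weights: the `revision=DALLE_COMMIT_ID` argument to `from_pretrained` selects the checkpoint revision separately, and presumably the two need to be compatible with each other.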

doyeonyeah commented 2 years ago

@borisdayma Thanks, it works perfectly with the commit 31b52fa675175a828e4a5d180e8016fa630d67a from 21/11/17.

I'm looking forward to the newest update. Love your work, please keep it up!

borisdayma commented 2 years ago

This has now been fixed. Note: we should have a new inference notebook soon that is much faster.

zifken commented 2 years ago

When running the inference pipeline notebook on Google Colab (free version), an error is raised while loading the DALLE-mini model:

# Load dalle-mini
model = DalleBart.from_pretrained(
    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True
)

Error:

wandb: Downloading large artifact model-3e2l7fxk:latest, 1625.33MB. 7 files... Done. 0:0:0

---------------------------------------------------------------------------

AssertionError                            Traceback (most recent call last)

/usr/local/lib/python3.7/dist-packages/dalle_mini/model/configuration.py in __init__(self, normalize_text, encoder_vocab_size, image_vocab_size, image_length, max_text_length, encoder_layers, encoder_ffn_dim, encoder_attention_heads, decoder_layers, decoder_ffn_dim, decoder_attention_heads, activation_function, d_model, dropout, attention_dropout, activation_dropout, init_std, scale_embedding, gradient_checkpointing, use_cache, is_encoder_decoder, forced_eos_token_id, tie_word_embeddings, do_sample, use_bias, ln_type, ln_positions, use_head_scale, use_cosine_attention, tau_init, use_deepnet_scaling, use_glu, use_alibi, sinkhorn_iters, use_final_ln_encoder, use_final_ln_decoder, force_ln_scale, **kwargs)
    106             assert (
    107                 use_final_ln_encoder
--> 108             ), "use_final_ln_encoder must be True when ln_positions is 'postln'"
    109             assert (
    110                 use_final_ln_decoder

AssertionError: use_final_ln_encoder must be True when ln_positions is 'postln'
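This assertion is raised by validation inside the configuration's `__init__`: a downloaded config that pairs `ln_positions="postln"` with `use_final_ln_encoder=False` is rejected before any weights are loaded. A minimal sketch of that kind of construction-time check, assuming a simplified, hypothetical `TinyConfig` class; only the error message is taken from the traceback above (the parallel decoder check is cut off there, so it is not reproduced), and a `ValueError` is used instead of `assert` because assertions are skipped under `python -O`:

```python
from dataclasses import dataclass


@dataclass
class TinyConfig:
    """Hypothetical, simplified config holding only the fields involved here."""

    ln_positions: str = "postln"
    use_final_ln_encoder: bool = True

    def __post_init__(self):
        # Validate at construction time so an inconsistent checkpoint config
        # fails immediately instead of producing a mis-built model later.
        if self.ln_positions == "postln" and not self.use_final_ln_encoder:
            raise ValueError(
                "use_final_ln_encoder must be True when ln_positions is 'postln'"
            )


ok = TinyConfig()  # consistent combination, passes validation
try:
    TinyConfig(use_final_ln_encoder=False)
except ValueError as err:
    print(err)  # prints: use_final_ln_encoder must be True when ln_positions is 'postln'
```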
borisdayma commented 2 years ago

Thanks for the feedback. It should now work.