patil-suraj / vqgan-jax

JAX implementation of VQGAN

Error VQModel.from_pretrained("valhalla/vqgan-imagenet-f16-1024") #12

Open kechan opened 2 years ago

kechan commented 2 years ago

Running the notebook cell that instantiates the model on Colab fails with the following error:

ValueError                                Traceback (most recent call last)
<ipython-input-5-122a0a52f4ec> in <module>
----> 1 model = VQModel.from_pretrained("valhalla/vqgan-imagenet-f16-1024")

/usr/local/lib/python3.7/dist-packages/transformers/modeling_flax_utils.py in from_pretrained(cls, pretrained_model_name_or_path, dtype, *model_args, **kwargs)
    853                 else:
    854                     raise ValueError(
--> 855                         f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
    856                         f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
    857                         "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "

ValueError: Trying to load the pretrained weight for ('decoder', 'mid', 'attn_1', 'norm', 'bias') failed: checkpoint has shape (1, 1, 1, 512) which is incompatible with the model shape (512,). Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this model.
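The failing parameter is a normalization bias: the checkpoint stores it with shape `(1, 1, 1, 512)`, but the model expects `(512,)`. Both shapes hold the same 512 numbers, and on NHWC activations they broadcast identically. One plausible reading is that the checkpoint was saved by an earlier version of the model code that kept normalization parameters in a 4-D broadcastable shape, though this thread does not confirm it. The `ignore_mismatched_sizes=True` option named in the error does not reshape anything. In `transformers` it skips the mismatched weights and leaves them randomly initialized, which is rarely what you want for a pretrained VQGAN. The sketch below uses NumPy arrays as a stand-in for JAX arrays. It checks the broadcasting claim and defines a hypothetical helper, `reshape_to_reference`, that is not part of `transformers` or vqgan-jax. The helper reshapes checkpoint leaves whose element count matches the model's expected shape.

```python
import numpy as np


def bias_broadcast_matches(channels: int = 512, seed: int = 0) -> bool:
    """Check that a (1, 1, 1, C) bias and a (C,) bias add identically to NHWC activations."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((2, 4, 4, channels))    # NHWC activations
    bias_vec = rng.standard_normal(channels)        # shape (C,): what the model expects
    bias_4d = bias_vec.reshape(1, 1, 1, channels)   # shape (1, 1, 1, C): what the checkpoint stores
    return bool(np.array_equal(x + bias_vec, x + bias_4d))


def reshape_to_reference(checkpoint: dict, reference: dict) -> dict:
    """Hypothetical helper: reshape checkpoint leaves to the reference shapes when sizes agree.

    Both arguments are nested dicts of arrays with the same keys (like Flax params).
    Raises ValueError when a leaf has a different element count, since a reshape
    cannot repair that kind of mismatch.
    """
    out = {}
    for key, ref in reference.items():
        value = checkpoint[key]
        if isinstance(ref, dict):
            # Recurse into sub-modules such as decoder -> mid -> attn_1 -> norm.
            out[key] = reshape_to_reference(value, ref)
            continue
        value = np.asarray(value)
        if value.shape != ref.shape:
            if value.size != np.size(ref):
                raise ValueError(f"{key}: cannot reshape {value.shape} to {ref.shape}")
            value = value.reshape(ref.shape)
        out[key] = value
    return out


if __name__ == "__main__":
    # Mirror the key path from the error message: ('decoder', 'mid', 'attn_1', 'norm', 'bias').
    ckpt = {"decoder": {"mid": {"attn_1": {"norm": {"bias": np.zeros((1, 1, 1, 512))}}}}}
    ref = {"decoder": {"mid": {"attn_1": {"norm": {"bias": np.zeros(512)}}}}}
    fixed = reshape_to_reference(ckpt, ref)
    print(bias_broadcast_matches())                              # True
    print(fixed["decoder"]["mid"]["attn_1"]["norm"]["bias"].shape)  # (512,)
```

A reshape like this only makes the shapes agree. It does not guarantee that the rest of the old checkpoint matches the current model code. The simpler fix is the one given in the reply below: load a checkpoint saved for the current code.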

vdidon commented 2 years ago

Duplicate of https://github.com/borisdayma/dalle-mini/issues/99. You can use `dalle-mini/vqgan_imagenet_f16_16384` instead.
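Switching to the suggested checkpoint is a one-line change. The sketch below assumes the `VQModel` import path `vqgan_jax.modeling_flax_vqgan`, which is the path used in the dalle-mini inference examples; check it against your installed version. The import sits inside the function so the snippet can be loaded without vqgan-jax installed. The `load_vqgan` wrapper name is illustrative only. Going by the repository names, the replacement is a different model rather than a re-saved copy of the original. It keeps the f16 downsampling factor but uses a 16384-entry codebook instead of 1024, so token indices from the two are not interchangeable.

```python
# Checkpoint suggested in the reply above; "valhalla/vqgan-imagenet-f16-1024" fails to load.
WORKING_CHECKPOINT = "dalle-mini/vqgan_imagenet_f16_16384"


def load_vqgan(repo_id: str = WORKING_CHECKPOINT):
    """Hypothetical convenience wrapper around VQModel.from_pretrained."""
    # Lazy import (assumed module path) so this file imports even without vqgan-jax.
    from vqgan_jax.modeling_flax_vqgan import VQModel

    # Downloads the weights from the Hugging Face Hub on first use.
    return VQModel.from_pretrained(repo_id)


# Usage (needs vqgan-jax, JAX/Flax and network access):
# model = load_vqgan()
```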