chao-ji / ldm_tf2

TensorFlow implementation of Latent Diffusion Model

Mismatch between pre-trained model and model from conversion script #1

Open kkedich opened 5 months ago

kkedich commented 5 months ago

Hi @chao-ji ,

I was trying to convert the pre-trained txt2img model, but some of the weight shapes don't match. I was able to identify the mismatches for the Transformer model and for part of the Unet model, but not all of them.

Example from the transformer model:

  File "/home/kbogdan/ldm_tf2/convert_ckpt_pytorch_to_tf2.py", line 516, in main
    save_checkpoint(sd)
  File "/home/kbogdan/ldm_tf2/convert_ckpt_pytorch_to_tf2.py", line 424, in save_checkpoint
    transformer.set_weights(weights)
  File "/home/kbogdan/anaconda3/envs/ldm_tf2/lib/python3.10/site-packages/keras/src/engine/base_layer.py", line 1832, in set_weights
    raise ValueError(
ValueError: Layer transformer_model weight shape (1280, 8, 64) is not compatible with provided weight shape (640, 8, 64).

The fix for the transformer model was to set hidden_size to 640 and filter_size to 4 * 640 = 2560:

  # Match the checkpoint's transformer width (640 instead of the default 1280)
  hidden_size = 640
  transformer = TransformerModel(vocab_size,
                                 encoder_stack_size=32,
                                 hidden_size=hidden_size,
                                 num_heads=8,
                                 filter_size=hidden_size * 4,  # 2560
                                 dropout_rate=0.1)
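In both shapes from the traceback, the trailing (8, 64) is num_heads times the per-head size; only the leading input width (1280 vs. 640) differs. The minimal sketch below shows how a kernel of that shape could be derived from a PyTorch nn.Linear weight. It assumes the conversion transposes the (out_features, in_features) matrix and splits the output axis into heads. That is an assumption about convert_ckpt_pytorch_to_tf2.py, not something taken from it, and both helper names are hypothetical.

```python
import numpy as np


def torch_linear_to_tf_attention_kernel(weight, num_heads):
    """Hypothetical helper: convert a PyTorch nn.Linear weight of shape
    (num_heads * head_size, hidden_size) into an einsum-style kernel of
    shape (hidden_size, num_heads, head_size)."""
    out_features, hidden_size = weight.shape
    if out_features % num_heads != 0:
        raise ValueError(
            f"{out_features} output features not divisible by {num_heads} heads")
    head_size = out_features // num_heads
    # Transpose to (in, out), then split the output axis into (heads, head_size).
    return weight.T.reshape(hidden_size, num_heads, head_size)


def infer_hidden_size(weight):
    """Hypothetical helper: the projection's input width is the hidden size."""
    return weight.shape[1]


# Toy query projection from a 640-wide checkpoint: 8 heads x 64 dims = 512 outputs.
w_q = np.zeros((512, 640), dtype=np.float32)
kernel = torch_linear_to_tf_attention_kernel(w_q, num_heads=8)
print(kernel.shape)
print(infer_hidden_size(w_q))
```

Under this assumption, reading weight.shape[1] from any attention projection in the checkpoint tells you which hidden_size to pass to TransformerModel before calling set_weights.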

However, the Unet model also doesn't match the pre-trained checkpoint. I adjusted some parts, but I'm having difficulty matching the checkpoint weights exactly to the Unet model as defined. The initial weights match, but later on (at weight 17, for example) the mismatches start: sd_weight: (192,), unet_current_model_weight: (192, 192)
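Because set_weights consumes a flat, ordered list, a single missing, extra, or reordered variable shifts every later entry, so the first mismatch is usually the informative one. The minimal sketch below locates it. It assumes you have collected the model's shapes (for example [w.shape for w in unet.get_weights()]) and the converted checkpoint arrays' shapes in the same order. report_shape_mismatches is a hypothetical helper, not part of the repository.

```python
from typing import List, Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def report_shape_mismatches(
    model_shapes: Sequence[Shape], ckpt_shapes: Sequence[Shape]
) -> List[Tuple[int, Optional[Shape], Optional[Shape]]]:
    """Hypothetical helper: return (index, model_shape, ckpt_shape) for every
    position where the two ordered shape lists differ. A position that exists
    on only one side is reported with None on the other side."""
    mismatches = []
    for i in range(max(len(model_shapes), len(ckpt_shapes))):
        model_shape = tuple(model_shapes[i]) if i < len(model_shapes) else None
        ckpt_shape = tuple(ckpt_shapes[i]) if i < len(ckpt_shapes) else None
        if model_shape != ckpt_shape:
            mismatches.append((i, model_shape, ckpt_shape))
    return mismatches


# Toy data: the checkpoint has one extra (192,) vector before the next kernel,
# so every later entry is shifted by one position.
model = [(3, 3, 4, 192), (192,), (192, 192), (192,)]
ckpt = [(3, 3, 4, 192), (192,), (192,), (192, 192), (192,)]
for index, expected, provided in report_shape_mismatches(model, ckpt):
    print(f"weight {index}: model {expected} vs checkpoint {provided}")
```

In this toy run the first reported entry (index 2) pairs a (192, 192) kernel with a (192,) vector, the same pattern as weight 17 above; the remaining entries are only the knock-on shift. Which variable is actually extra or missing in the real Unet is not determined by this check.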

I was wondering whether you have already seen this pattern in the pre-trained model, or whether I need to define something else to match the shapes in convert_ckpt_pytorch_to_tf2.py exactly. I double-checked the pre-trained model, and its weights really do have these other shapes. Thanks!

chao-ji commented 4 months ago

Hi kkedich,

I think the link to the pre-trained PyTorch checkpoint is wrong. The correct one is https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt. It comes from the official README at https://github.com/CompVis/latent-diffusion/tree/main?tab=readme-ov-file#text-to-image, just below "Download the pre-trained weights (5.7GB)".
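Before rerunning the conversion, you can check which checkpoint you have by reading the width of the text encoder's token embedding. It should equal the 1280 that the script's TransformerModel expects, according to the traceback above. This sketch makes two assumptions to verify against your file. First, it assumes the Lightning-style layout used by the CompVis scripts, with arrays stored under a "state_dict" key. Second, it assumes the embedding key starts with "cond_stage_model." and ends with "token_emb.weight", stored as (vocab_size, hidden_size). The helper name text_encoder_width is hypothetical. It accepts any mapping of names to arrays, so it is exercised here with small NumPy stand-ins.

```python
import numpy as np


def text_encoder_width(state_dict, prefix="cond_stage_model.", suffix="token_emb.weight"):
    """Hypothetical helper: return the text encoder's hidden size, assuming the
    token embedding is stored as (vocab_size, hidden_size) under a key that
    starts with `prefix` and ends with `suffix` (both are assumptions)."""
    for name, array in state_dict.items():
        if name.startswith(prefix) and name.endswith(suffix):
            return array.shape[1]
    raise KeyError(f"no key starting with {prefix!r} and ending with {suffix!r}")


# Stand-in state dict with a toy vocabulary. With a real checkpoint you would use
#   state_dict = torch.load("model.ckpt", map_location="cpu")["state_dict"]
# (recent PyTorch versions may additionally need weights_only=False).
fake_sd = {"cond_stage_model.transformer.token_emb.weight": np.zeros((100, 1280), dtype=np.float32)}
width = text_encoder_width(fake_sd)
print(width)
if width != 1280:
    print("checkpoint width differs from the 1280 expected by the conversion script")
```

If the real checkpoint reports 640 rather than 1280, that points to the wrong download rather than to a bug in the model definition.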