mehdidc / feed_forward_vqgan_clip

Feed forward VQGAN-CLIP model, where the goal is to eliminate the need for optimizing the latent space of VQGAN for each input prompt
MIT License

Models are broken in the new `torch` version #25

Closed · neverix closed this issue 2 years ago

neverix commented 2 years ago

PyTorch 1.12 introduced an `approximate` attribute on `nn.GELU`. This breaks the MLP-Mixer models, because they were pickled with an older PyTorch whose `GELU` modules lack the new attribute. The proper fix is to save pre-trained models as state dicts (weights only) rather than as complete pickled objects.
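
A minimal sketch of that fix in plain PyTorch (the `build_model` constructor and file name are placeholders, not the repo's actual API): checkpoints store only the `state_dict()` weights, so loading no longer depends on how `nn.GELU` was defined at pickling time.

```python
import torch

# Saving: keep only the weights, not the pickled module tree.
torch.save(net.state_dict(), "mixer_weights.th")

# Loading: rebuild the architecture with the *current* torch, then restore
# the weights; the freshly constructed nn.GELU modules already carry the
# new `approximate` attribute, so nothing breaks.
net = build_model()  # placeholder for the repo's model constructor
net.load_state_dict(torch.load("mixer_weights.th", map_location="cpu"))
```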

neverix commented 2 years ago

For now, a quick fix would be something along the lines of:

```python
import torch

# Patch a model pickled with an older torch: the new nn.GELU forward pass
# reads an `approximate` attribute that the old pickled modules lack.
for i in net.mixer:
    if isinstance(i, torch.nn.Sequential):
        for k in i:
            k.fn[1].approximate = "none"
```

neverix commented 2 years ago

Guess it can be closed now

mehdidc commented 2 years ago

Thanks a lot @neverix for pointing out the issue and providing the solution! Will merge to master soon.

mehdidc commented 2 years ago

Merged now in master @neverix

PaulScotti commented 2 years ago

I ran into this same issue just now despite using the latest master. The original solution from Nev does work though.

mehdidc commented 2 years ago

@PaulScotti I see, could you please show the exact exception you get, along with the model and PyTorch version you used?

PaulScotti commented 2 years ago

PyTorch 1.12.0+cu113, using `cc12m_32x1024_mlp_mixer_clip_ViTB32_pixelrecons_256x256_v0.4.th` (v0.3 also produces the same error):

`AttributeError: 'GELU' object has no attribute 'approximate'`

I checked `main.py` to make sure I am indeed using the latest script that has the `_fix_mlp_mixer_gelu_issue` edits.

Error does not occur if I use PyTorch 1.8

mehdidc commented 2 years ago

@PaulScotti Thanks for the details. I couldn't reproduce it so far, so I'm not sure what I am missing. Could you also please provide the line where it happens so that I can investigate more closely?

PaulScotti commented 2 years ago

Sure, attached are screenshots of the full error message: https://i.imgur.com/FbLimQa.png https://i.imgur.com/WnDY9CO.png https://i.imgur.com/FBcmzNR.png

mehdidc commented 2 years ago

Thanks @PaulScotti for the screenshots! I see now: the reason is that the `_fix_mlp_mixer_gelu_issue` function is not called (the older model files were not modified/overridden, so the fix function needs to be called before using them). If you would like to load the models manually, as in the screenshots above, you can use the `load_model` function from `main.py`, which calls the fix function internally, e.g. `from main import load_model; net = load_model(<PATH>); net(emb)` should work.

The function `load_model` will also be compatible with new models, where only the state dict is saved rather than the full class instance, as was done before.
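
For reference, a rough end-to-end sketch of that suggestion; the CLIP text-encoding step, the prompt, and the all-CPU device handling below are assumptions for illustration, not something `main.py` prescribes:

```python
# Rough usage sketch; only the load_model one-liner above is from the repo,
# the CLIP encoding of `emb` is an assumed way to produce the input embedding.
import clip
import torch
from main import load_model

# load_model applies the GELU fix internally before returning the network
net = load_model("cc12m_32x1024_mlp_mixer_clip_ViTB32_pixelrecons_256x256_v0.4.th")

perceptor, _ = clip.load("ViT-B/32", device="cpu", jit=False)
with torch.no_grad():
    tokens = clip.tokenize(["a watercolor painting of a fox"])
    emb = perceptor.encode_text(tokens).float()
    out = net(emb)  # as in the one-liner above
```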

PaulScotti commented 2 years ago

Ah, okay! Sorry, I didn't realize that's how `load_model` worked, thank you for the explanation.

mehdidc commented 2 years ago

Great, closing the issue now since it is solved