lucidrains / meshgpt-pytorch

Implementation of MeshGPT, SOTA Mesh generation using Attention, in Pytorch
MIT License
700 stars, 57 forks

Flash attention misspelling: flash_attn -> attn_flash #45

Closed: MarcusLoppe closed this 8 months ago

MarcusLoppe commented 8 months ago

It seems flash attention was never actually enabled because of a misspelled keyword argument; hopefully fixing it will improve performance.

x-transformers, x_transformers.py -> AttentionLayers:

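```python
# 'attn_'-prefixed kwargs are collected and the prefix is trimmed,
# so only attn_flash=True ends up as attn_kwargs['flash'] here.
attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
flash_attn = attn_kwargs.get('flash', False)
```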
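The effect of the typo can be seen with a simplified re-implementation of that prefix-grouping step. This is a sketch written for illustration, not the x-transformers source; it only assumes, as the excerpt shows, that keys starting with attn_ are collected and the prefix is stripped before flash is looked up.

```python
def groupby_prefix_and_trim(prefix, kwargs):
    # Simplified sketch (not the library's code): split kwargs into the ones
    # starting with `prefix` (prefix removed) and everything else (unchanged).
    with_prefix = {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
    rest = {k: v for k, v in kwargs.items() if not k.startswith(prefix)}
    return with_prefix, rest


# Misspelled keyword: 'flash_attn' does not start with 'attn_', so in this
# sketch it is never routed into the attention kwargs.
attn_kwargs, _ = groupby_prefix_and_trim('attn_', {'flash_attn': True})
print(attn_kwargs.get('flash', False))  # False

# Correct keyword: 'attn_flash' is trimmed to 'flash' and picked up.
attn_kwargs, _ = groupby_prefix_and_trim('attn_', {'attn_flash': True})
print(attn_kwargs.get('flash', False))  # True
```

Because the lookup uses .get with a default of False, a misspelled flag fails silently instead of raising, which is how a bug like this can go unnoticed.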
lucidrains commented 8 months ago

oh crap, yes indeed, thanks Marcus! flash attention is a necessity for long context.
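A rough back-of-the-envelope sketch of why: standard attention materializes an n x n score matrix for every head, so memory grows quadratically with sequence length, whereas flash attention computes the same exact result in tiles without storing the full matrix. The numbers below assume fp16 scores and ignore every other activation; they are an estimate, not measurements from this thread.

```python
def naive_score_matrix_bytes(seq_len, bytes_per_element=2):
    # Standard (non-flash) attention stores a seq_len x seq_len score matrix
    # per head; fp16/bf16 uses 2 bytes per element.
    return seq_len * seq_len * bytes_per_element


n = 19_200  # roughly the 19k-token sequence discussed in this thread
print(naive_score_matrix_bytes(n))          # 737280000
print(naive_score_matrix_bytes(n) / 2**30)  # about 0.69 (GiB), per head, per layer
```

Multiplied across heads and layers during training, this one buffer alone quickly exceeds a consumer GPU's memory.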

MarcusLoppe commented 8 months ago

> oh crap, yes indeed, thanks Marcus! flash attention is a necessity for long context.

Image: https://file.io/j9FGKTio1agv

@lucidrains Success! A 3,200-face mesh; the generated one is on the left. I overfitted the models on just one mesh, but it shows that meshes with a large number of faces are possible.

Autoencoder loss: 0.29, transformer loss: 0.0007

| Metric | Before | After |
|---|---|---|
| Inference time (19k tokens) | 03:49 | 02:37 |
| Training VRAM | 17 GB | 10 GB |
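The 19k-token figure lines up with the 3,200-face mesh if each face becomes 6 tokens. That per-face count is an assumption here (3 vertices per face, each encoded with 2 residual-quantizer codes), chosen because it reproduces the number above; check your own autoencoder's quantizer settings before relying on it.

```python
def mesh_token_count(num_faces, vertices_per_face=3, codes_per_vertex=2):
    # Assumed tokenization: each face is split into its 3 vertices and each
    # vertex is represented by `codes_per_vertex` quantizer codes.
    return num_faces * vertices_per_face * codes_per_vertex


print(mesh_token_count(3_200))  # 19200, i.e. the "19k tokens" in the table
```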
lucidrains commented 8 months ago

hurray! go train the holodeck!

lucidrains commented 8 months ago

I'll be back to chip away at this later this week

MarcusLoppe commented 8 months ago

> hurray! go train the holodeck!

@lucidrains

Ofc see you in 4 months, brb going to train on my 3060 :)

It seems a couple of people have access to A100s, and one even has access to 4x A100. So I think someone will be able to train on a full dataset using ShapeNet and more.

The good news is that we have gone past the 800-face limit that 3D mesh generators all seem to have restricted themselves to (PolyGen, MeshGPT, and PolyDiff, which is affiliated with MeshGPT). I'd imagine that was due to VRAM usage.
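If that ceiling is indeed about memory, quadratic attention explains why it bites so quickly: with a fixed number of tokens per face, 4x the faces means 4x the sequence length and 16x the size of a naive attention score matrix. A small sketch, assuming plain (non-flash) attention and a hypothetical constant tokens_per_face:

```python
def naive_attention_cost_ratio(faces_a, faces_b, tokens_per_face=6):
    # Score-matrix size is proportional to seq_len**2 and seq_len is
    # proportional to the face count, so the ratio is (faces_b / faces_a)**2.
    # tokens_per_face is a hypothetical constant and cancels out.
    len_a = faces_a * tokens_per_face
    len_b = faces_b * tokens_per_face
    return (len_b * len_b) / (len_a * len_a)


print(naive_attention_cost_ratio(800, 3_200))  # 16.0
```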

Btw, I implemented bge in a fork of classifier-free-guidance-pytorch. bge is in 7th place on the SOTA embedding leaderboard, ahead of the OpenAI ada embedder (23rd place). I'd imagine it would be a good model to support and would perform better than T5. I implemented the base version, which has 109M parameters.

I tried using CLIP to train the transformer, but it's quite bad. I tested the cosine similarity of "chair" and "table" using CLIP and got 99% similarity, while T5 and bge gave 70% and 71% respectively. When the similarity is that high, the transformer won't be able to tell a chair from a table.
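For reference, cosine similarity is the dot product of two embeddings divided by the product of their norms, so a value near 1 means the text encoder points both prompts in almost the same direction and the transformer gets little signal to separate them. A minimal sketch with made-up 3-dimensional vectors (real CLIP, T5 or bge embeddings have hundreds of dimensions and would come from the respective model):

```python
import math


def cosine_similarity(a, b):
    # dot(a, b) / (|a| * |b|); ranges from -1 to 1.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Hypothetical embeddings, purely for illustration.
chair = [1.0, 0.0, 0.0]
table_close = [0.99, 0.1, 0.0]  # nearly parallel to chair: hard to tell apart
table_apart = [0.7, 0.7, 0.0]   # clearly different direction

print(round(cosine_similarity(chair, table_close), 3))  # 0.995
print(round(cosine_similarity(chair, table_apart), 3))  # 0.707
```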

Would it be alright if I opened a pull request? Let me know if there's anything I need to fix beforehand. I'd like to have some contributions on my GitHub page; it's a small chance, but it might come in handy if a recruiter visits.

https://github.com/lucidrains/classifier-free-guidance-pytorch/commit/5c189f32dbc20cd5882be4ef2a132e2aabcb8df5