lucidrains / meshgpt-pytorch

Implementation of MeshGPT, SOTA Mesh generation using Attention, in Pytorch
MIT License
700 stars 57 forks

Simple training script for toy data? #46

Open xiao-xian opened 8 months ago

xiao-xian commented 8 months ago

Hi there, I wonder if it's possible to have a script reproducing the same toy example from an older paper. I tried to run the training, but the best I came up with is shown in the attached image. I also constantly run into NaNs, as reported here. Thanks for any help!

lucidrains commented 8 months ago

could you retry? there was an issue uncovered by Farid that has since been resolved

MarcusLoppe commented 8 months ago

Hi there, I wonder if it's possible to have a script reproducing the same toy example from an older paper. I tried to run the training, but the best I came up with is this:

Hi @xiao-xian

I uploaded a demo notebook at my fork: https://github.com/MarcusLoppe/meshgpt-pytorch/tree/main

It's tricky using such a small dataset; I've had more luck using 30-40 models per label, since the model can generalize better. But try to see if you can get the transformer loss down to 0.0001 or something very low.

Also: if you don't generate the codes before you train the transformer, the autoencoder will regenerate them at each training step, wasting 4-96 GB of VRAM, since the codes it generates are thrown away by the dataloader. Pre-generating them and storing them in the dataset gives much better speed and VRAM usage.
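The pre-generation idea above can be sketched generically: run the expensive encoding once per mesh and cache the result in the dataset, so the dataloader hands back stored codes instead of re-running the autoencoder every step. This is a pure-Python sketch; the `CodeCachingDataset` class and the `encode` stand-in are hypothetical, and in meshgpt-pytorch the actual encoding call would be the autoencoder's tokenization as shown in the demo notebook.

```python
# Sketch: cache precomputed codes so the encoder runs once per mesh,
# not once per training step. `encode` stands in for the (expensive)
# autoencoder tokenization.

class CodeCachingDataset:
    def __init__(self, meshes, encode):
        self.meshes = meshes
        self.encode = encode
        self.codes = None  # filled in by precompute_codes

    def precompute_codes(self):
        # run the expensive encoder once per mesh and keep the results
        self.codes = [self.encode(mesh) for mesh in self.meshes]

    def __len__(self):
        return len(self.meshes)

    def __getitem__(self, idx):
        if self.codes is not None:
            return self.codes[idx]            # cheap cached lookup
        return self.encode(self.meshes[idx])  # fallback: re-encode

calls = []
def fake_encode(mesh):
    # stand-in encoder that records how often it is invoked
    calls.append(mesh)
    return [v * 2 for v in mesh]

ds = CodeCachingDataset([[1, 2], [3, 4]], fake_encode)
ds.precompute_codes()
first = ds[0]  # after precomputation, __getitem__ never re-invokes the encoder
```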

xiao-xian commented 7 months ago

Many thanks @MarcusLoppe!! I pulled your branch and used the above notebook to run the training. The training loss for the encoder is around 0.28 (see attached image). However, the transformer loss never reached anything lower than 0.01, and when generating the demo meshes, they all kind of collapsed to a single cone (see attached images). Not sure if I missed anything obvious. What kind of meshes did you get from the above script? Many thanks!

MarcusLoppe commented 7 months ago

@xiao-xian

Ah yes, the transformer seems to have trouble with the first token sometimes. It's because the text doesn't guide it very well when there are no tokens yet. This issue resolves itself when using many more mesh models, but it is a problem when dealing with a small number of meshes. If you provide just 1 token as a prompt to the generate function, it manages to generate the meshes without a problem, since the text does a better job of guiding the transformer once it has some data to work with.
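The one-token-prompt workaround can be illustrated with a stand-in generator. Everything below is hypothetical scaffolding, not the meshgpt-pytorch API; in the library itself the analogous move would be passing a short tensor of known codes as the prompt argument to the transformer's generate function, as described above.

```python
# Stand-in for autoregressive generation: extend a prompt token by token.
def generate(prompt, next_token, max_len=8):
    seq = list(prompt)
    while len(seq) < max_len:
        seq.append(next_token(seq))
    return seq

# Toy "model": with an empty prompt it has no context and always starts
# the same way (the collapse-to-one-shape failure mode); seeded with one
# real token it continues coherently from that token.
def next_token(seq):
    if not seq:
        return 0
    return seq[-1] + 1

unseeded = generate([], next_token)   # starts from nothing
seeded = generate([42], next_token)   # seeded with 1 known token
```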

It also seems to have some trouble with meshes with a very small number of triangles: the box only has 12 triangles and it always struggled, while the 112-triangle meshes were fine.

Try some meshes that are a bit more 'complex'; here are 4 tables and 4 chairs that work pretty well. Apply 50 augmentations per model for a bit more robust generalization: https://easyupload.io/hosas0
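The "50 augmentations per model" step can be sketched like this. It is a pure-Python sketch; the specific transforms and their ranges are assumptions (random uniform scaling plus small per-vertex jitter is one common choice), not what the fork's notebook necessarily does.

```python
import random

def augment_mesh(vertices, scale_range=(0.75, 1.25), jitter=0.01):
    # one augmented copy: uniform scale plus small per-vertex jitter,
    # which leaves the mesh topology (the faces) unchanged
    s = random.uniform(*scale_range)
    return [[c * s + random.uniform(-jitter, jitter) for c in v]
            for v in vertices]

# a toy triangle standing in for a real mesh's vertex list
mesh = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]

# 50 augmentations per model, as suggested above
augmented = [augment_mesh(mesh) for _ in range(50)]
```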

lucidrains commented 7 months ago

@MarcusLoppe oh interesting, maybe i could have an option to make the first token generated unconditionally? what do you think?

MarcusLoppe commented 7 months ago

@MarcusLoppe oh interesting, maybe i could have an option to make the first token generated unconditionally? what do you think?

It kind of seems like it's already doing that; I'm guessing it's because the cross-attention's impact isn't very high when there are no data/tokens.

I've tried setting only_cross to True, but it doesn't have a noticeable impact on the problem.