lucidrains / meshgpt-pytorch

Implementation of MeshGPT, SOTA Mesh generation using Attention, in Pytorch
MIT License

Fine gateloop doesn't use the param gateloop_use_heinsen #87

Closed: MarcusLoppe closed this issue 2 months ago

lucidrains commented 2 months ago

@MarcusLoppe if you think gateloop helps, let's add it!

MarcusLoppe commented 2 months ago

The gateloop actually works, and the Heinsen version works even better. Here are some notes from my experiments. I'll check again, but the best results came from using Heinsen for just the fine gateloop; however, that might just be random variance in training.

float32, heinsen eps: 1e-20
fine gateloop_use_heinsen = True, coarse gateloop_use_heinsen = False

coarse_pre_gateloop_depth = 2, fine_pre_gateloop_depth = 2:
Epoch 15 average loss: 0.6862348544406381
Epoch 3 average loss: 0.37924969778022666

coarse_pre_gateloop_depth = 0, fine_pre_gateloop_depth = 2:
Epoch 15 average loss: 0.8267558821063629
Epoch 15 average loss: 0.04371300131520804

coarse_pre_gateloop_depth = 2, fine_pre_gateloop_depth = 0:
Epoch 15 average loss: 1.742983629359281

fine gateloop_use_heinsen = True, coarse gateloop_use_heinsen = True

Epoch 15 average loss: 0.84913022887898

float16, heinsen eps: 1e-7
fine gateloop_use_heinsen = True, coarse gateloop_use_heinsen = True

coarse_pre_gateloop_depth = 2, fine_pre_gateloop_depth = 2:
Epoch 15 average loss: 0.733044471651475

fine gateloop_use_heinsen = False, coarse gateloop_use_heinsen = False

coarse_pre_gateloop_depth = 2, fine_pre_gateloop_depth = 2:
Epoch 15 average loss: 1.1578134642565314
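For context on the `gateloop_use_heinsen` option: as far as I can tell it selects Heinsen's log-space parallel scan, which evaluates the gated linear recurrence h_t = a_t * h_{t-1} + b_t with cumulative sums in log space instead of a step-by-step loop. The sketch below assumes strictly positive a_t and b_t (the gateloop code uses complex logs to handle signs, which this sketch skips), and `sequential_scan` / `heinsen_scan` are hypothetical names, not the library's functions. The `eps` clamp plays the same role as the heinsen eps values in the notes above.

```python
import torch


def sequential_scan(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Reference loop along the last dim: h_t = a_t * h_{t-1} + b_t, starting from h = 0."""
    h = torch.zeros_like(b[..., 0])
    out = []
    for t in range(b.shape[-1]):
        h = a[..., t] * h + b[..., t]
        out.append(h)
    return torch.stack(out, dim=-1)


def heinsen_scan(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
    """Same recurrence computed in log space, assuming a > 0 and b > 0.

    log h_t = A_t + logcumsumexp(log b_k - A_k)_t, where A_t = sum_{j <= t} log a_j.
    """
    log_a = a.clamp(min=eps).log()  # eps keeps log() finite for tiny values
    log_b = b.clamp(min=eps).log()
    a_star = torch.cumsum(log_a, dim=-1)
    log_h = a_star + torch.logcumsumexp(log_b - a_star, dim=-1)
    return log_h.exp()


torch.manual_seed(0)
a = torch.rand(2, 16, dtype=torch.float64) * 0.9 + 0.05  # gates in (0.05, 0.95)
b = torch.rand(2, 16, dtype=torch.float64) + 0.1         # positive inputs
print(torch.allclose(sequential_scan(a, b), heinsen_scan(a, b)))
```

Replacing the loop with two cumulative ops is what makes the parallel path fast; the price is that every value passes through log and exp, so the dtype and the eps value become sensitive choices.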

lucidrains commented 2 months ago

@MarcusLoppe wow, thanks! that's surprising because the tokens for the fine transformer aren't very long at all

MarcusLoppe commented 2 months ago

@lucidrains Oh yeah, I was thinking of making a pull request or issue about it.

I actually get much better training progress if I rearrange b nf n d into b (nf n) d instead of (b nf) n d. Any idea why, and what is the benefit of using (b nf) n d? :)

lucidrains commented 2 months ago

@MarcusLoppe that was mostly for the hierarchical transformer (the fine stage attends to each token within each face), but for gateloop you are right, it could make sense to just have it act on the whole sequence

MarcusLoppe commented 2 months ago

I replaced line L1657 with: fine_vertex_codes = rearrange(fine_vertex_codes, 'b nf n d -> b (nf n) d')

The fine gateloop already processes the whole sequence, since it rearranges to b (nf n) d and then back to b nf n d for the fine_decoder's sake, so the change should only affect the fine_decoder.
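A small shape sketch of the two layouts, using plain torch reshapes that match the einops patterns '(b nf) n d' and 'b (nf n) d'. With (b nf) n d every face is its own length-n sequence, so attention in the fine stage only spans the tokens of one face; with b (nf n) d each mesh is a single sequence of nf * n tokens. The sizes are arbitrary illustrative values.

```python
import torch

b, nf, n, d = 2, 4, 3, 8  # batch, faces, tokens per face, model dim (illustrative)
fine_vertex_codes = torch.randn(b, nf, n, d)

# '(b nf) n d': every face becomes a separate, short sequence of n tokens
per_face = fine_vertex_codes.reshape(b * nf, n, d)

# 'b (nf n) d': every mesh becomes one long sequence of nf * n tokens
per_mesh = fine_vertex_codes.reshape(b, nf * n, d)

print(per_face.shape)  # torch.Size([8, 3, 8])
print(per_mesh.shape)  # torch.Size([2, 12, 8])

# Same data, different grouping: token i of face f in mesh m
m, f, i = 1, 2, 0
assert torch.equal(per_face[m * nf + f, i], per_mesh[m, f * n + i])

# Both layouts reshape back to 'b nf n d' without loss
assert torch.equal(per_mesh.reshape(b, nf, n, d), fine_vertex_codes)
```

The data is identical in both cases; what changes is how long the sequence is that each attention or gateloop layer gets to see.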

A little off topic, but you were right that MeshGPT would be a stepping stone for 3D mesh generation; there have been many autoregressive 3D projects/papers since, thanks to you ❤️

Have you looked into the MeshGPT clone repos MeshAnything and pivotmesh to see if they have any interesting changes?

lucidrains commented 2 months ago

@MarcusLoppe haha, well the real OG would be Yawar Siddiqui for all the clever tricks he discovered for encoding the mesh. But also you for helping with all the training, feedback, debugging! I don't think this repository would have ever taken off and spawned off downstream work without your help

No not yet, but I will gather up a few and read them probably in a month or two

MarcusLoppe commented 2 months ago

Very true :) Although they said they haven't released the source code because they are still getting legal approval. I'm very hopeful that we can get a big enough model published; I think scaling up the model will resolve many issues. I've collected over 200k 3D models with captions: around 100k models with at most 1k triangles and 150k with at most 2k triangles. I'm hopeful that we'll be able to do 3 tokens per triangle. All the MeshGPT clones use 9 or 10 tokens per triangle, so MeshGPT is currently the winner with 6 tokens, but hopefully they'll get inspired to improve further.

Using 3k triangles would result in a 9k token sequence length; hopefully the PPL (perplexity) score won't get too high :)
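The token counts in this thread are simple arithmetic. A minimal sketch, assuming 3 vertices per triangle and a fixed number of tokens per vertex: 2 residual codes per vertex gives MeshGPT's 6 tokens per triangle, 1 quantizer gives 3, and the 9-token case is assumed here to mean one token per x/y/z coordinate. `mesh_sequence_length` is a hypothetical helper, and special tokens (SOS, EOS, padding) are ignored.

```python
def mesh_sequence_length(num_triangles: int, tokens_per_vertex: int,
                         vertices_per_face: int = 3) -> int:
    """Token count of a flattened triangle mesh, ignoring special tokens."""
    return num_triangles * vertices_per_face * tokens_per_vertex


print(mesh_sequence_length(800, tokens_per_vertex=3))   # 7200: 9 tokens per triangle
print(mesh_sequence_length(800, tokens_per_vertex=2))   # 4800: 6 tokens per triangle
print(mesh_sequence_length(800, tokens_per_vertex=1))   # 2400: 3 tokens per triangle
print(mesh_sequence_length(3000, tokens_per_vertex=1))  # 9000: the 3k triangle case
```

Since attention cost grows quadratically with sequence length, going from 9 to 3 tokens per triangle cuts the attention cost for the same mesh by roughly 9x.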

lucidrains commented 2 months ago

pretty sure I'll see your name on some paper within a year 🤣

MarcusLoppe commented 2 months ago

@lucidrains

Hey,

What would you say to removing the rearrange, so that the fine_decoder takes b (nf n) d as input instead of (b nf) n d? I've attached a benchmark below using 1k labels with 3 unique models per label (only using the first 60 tokens).

I could make a pull request, but I'm not quite sure how to fix the cache code. It looks like I can just remove ck, cv = [rearrange(t, '(b nf) ... -> b nf ...', b = batch) for t in (ck, cv)]?

if exists(fine_cache):
    for attn_intermediate in fine_cache.attn_intermediates:
        ck, cv = attn_intermediate.cached_kv
        # un-merge the face axis: ((b nf), ...) -> (b, nf, ...)
        ck, cv = [rearrange(t, '(b nf) ... -> b nf ...', b = batch) for t in (ck, cv)]

        # when operating on the cached key / values, treat self attention and cross attention differently

        layer_type = attn_intermediate.layer_type

        if layer_type == 'a':
            # self attention: last face only, positions up to the current vertex
            ck, cv = [t[:, -1, :, :curr_vertex_pos] for t in (ck, cv)]
        elif layer_type == 'c':
            # cross attention: last face only, all cached positions
            ck, cv = [t[:, -1, ...] for t in (ck, cv)]

        attn_intermediate.cached_kv = (ck, cv)
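For reference, here is the shape bookkeeping that the cache code above performs, reproduced with random tensors. This assumes the cached keys/values use the common (batch, heads, seq_len, dim_head) layout, so under (b nf) n d the batch axis is really b * nf; the sizes and the `curr_vertex_pos` value are illustrative, and only keys are shown since values are handled identically.

```python
import torch

b, nf, heads, n, dim_head = 2, 5, 4, 3, 16  # illustrative sizes
curr_vertex_pos = 2

# Self-attention key cache as produced under the '(b nf) n d' layout
ck = torch.randn(b * nf, heads, n, dim_head)

# '(b nf) ... -> b nf ...': split the merged batch axis (b is the outer factor)
ck = ck.reshape(b, nf, heads, n, dim_head)

# layer_type 'a': keep the last face, positions up to the current vertex
ck_self = ck[:, -1, :, :curr_vertex_pos]
print(ck_self.shape)  # torch.Size([2, 4, 2, 16])

# layer_type 'c': keep the last face, all cached context positions
ctx_len = 7
cross_k = torch.randn(b * nf, heads, ctx_len, dim_head).reshape(b, nf, heads, ctx_len, dim_head)
ck_cross = cross_k[:, -1, ...]
print(ck_cross.shape)  # torch.Size([2, 4, 7, 16])
```

Under b (nf n) d the cached tensors would already have a batch axis of b and a sequence axis of nf * n, so there is no nf axis to split out, and "the last face" becomes a range of positions in the flattened sequence. That suggests removing the rearrange line alone is probably not enough.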

Using b (nf n) d

Epoch 1 average loss: 7.510405244674275
Epoch 2 average loss: 7.32208492156656
Epoch 3 average loss: 7.142376790072191
Epoch 4 average loss: 6.821965100293491  
Epoch 5 average loss: 6.167413803345379
Epoch 6 average loss: 5.379516076276647 
Epoch 7 average loss: 4.52265319594725 
Epoch 8 average loss: 3.6745032562929043   
Epoch 9 average loss: 2.8920027202463405   
Epoch 10 average loss: 2.2208851180612084  
Epoch 11 average loss: 1.7724002178977518
Epoch 12 average loss: 1.342862301013049  
Epoch 13 average loss: 1.064506171540143 
Epoch 14 average loss: 0.8560295670746482  
Epoch 15 average loss: 0.7038816100454586 

Epoch 1 average loss: 0.8444444752313236
Epoch 2 average loss: 0.5591754216083231
Epoch 3 average loss: 0.27806240949719985 
Epoch 4 average loss: 0.15000288555010116  
Epoch 5 average loss: 0.10225904246861922  
Epoch 6 average loss: 0.07386283096384237  
Epoch 7 average loss: 0.059430753583098475 
Epoch 8 average loss: 0.053663207108483595 
Epoch 9 average loss: 0.048432830123499755  
Epoch 10 average loss: 0.046043179082599556  
Epoch 11 average loss: 0.04352023722454507   
Epoch 12 average loss: 0.04212538892014779   
Epoch 13 average loss: 0.04111959553139414   
Epoch 14 average loss: 0.03909969734237156   

Using (b nf) n d

Epoch 1 average loss: 7.518533711764902
Epoch 2 average loss: 7.324829603898971
Epoch 3 average loss: 7.217989939419343
Epoch 4 average loss: 7.028948023994976  
Epoch 5 average loss: 6.554357107947855 
Epoch 6 average loss: 5.851589524172207 
Epoch 7 average loss: 5.136828820335674  
Epoch 8 average loss: 4.449877207291955 
Epoch 9 average loss: 3.75566021898851   
Epoch 10 average loss: 3.1004658625087638 
Epoch 11 average loss: 2.532497930016747 
Epoch 12 average loss: 2.022928041570327  
Epoch 13 average loss: 1.6278170920948294 
Epoch 14 average loss: 1.3055970888724302   
Epoch 15 average loss: 1.0916560069124965        

Epoch 1 average loss: 1.343507925456858
Epoch 2 average loss: 1.0067555765735912
Epoch 3 average loss: 0.6992494832066929
Epoch 4 average loss: 0.5151613274997568
Epoch 5 average loss: 0.3513614073156673 
Epoch 6 average loss: 0.23698255204580684  
Epoch 7 average loss: 0.17033607433027126 
Epoch 8 average loss: 0.13577893997896165   
Epoch 9 average loss: 0.11693950838105564   
Epoch 10 average loss: 0.10296207325742207   
Epoch 11 average loss: 0.09114892499889919  
Epoch 13 average loss: 0.07443246070076437     
Epoch 14 average loss: 0.06872947679404269  
Epoch 15 average loss: 0.0655904230985412

lucidrains commented 2 months ago

@MarcusLoppe hey Marcus, i think it is difficult to get the caching right

what about just increasing the number of layers of gateloop in the coarse stage?

MarcusLoppe commented 2 months ago

Is there any particular reason why you want to use (b nf)? I had some issues training on 9k labels: the loss got stuck around 0.2 (moving only about 0.001 per epoch), but after I implemented (nf n) the progress got a lot faster and it was able to get down to 0.001 loss.

I'd rather not, since the gateloop layers have some compute cost: using 2 fine & 2 coarse layers increases the epoch time by around 25%. I think the trade-off is optimal at around 1 or 2 layers; beyond that there are diminishing returns.

lucidrains commented 2 months ago

ah, yeah, I can revisit this at a later date

added ability for full sequence gateloops just preceding the fine transformer, easiest thing that can be done for now

MarcusLoppe commented 2 months ago

If I do some testing and get the cache to work, would you implement it, or do you doubt that it's better? :)

I'll give it a go. However, you wrote the wrong layer name: you are calling the same coarse gateloop twice.

lucidrains commented 2 months ago

@MarcusLoppe haha, no i have no doubts! if you can get it working, would be happy to merge it in

oops, let me quick fix

lucidrains commented 2 months ago

@MarcusLoppe set you up with a test btw. as long as this passes you are golden

MarcusLoppe commented 2 months ago

That was a bit more complicated than I thought; I'll need to have a think about it.

I also tested using 2 post-coarse gateloop layers, but that made training slower and the results worse :/

MarcusLoppe commented 2 months ago

@lucidrains

Could I get some advice before raising an issue? All the new AI GPUs have loads of fp16 compute. For example, if I train using float32 on 8 H100s it takes 6 hrs per epoch (2.5B dataset), but if I use fp16 it takes 3 hrs, since they have 2x the compute in fp16.

One lesson I've learned is that the coarse adaptive RMSNorm and the gateloop (very fast when using Heinsen) make the loss go NaN very quickly. Without the gateloop it takes a little longer to go NaN, but it still happens within an epoch.

The gateloop probably makes it go NaN because of its high-precision math expressions and dtypes (complex64), e.g. log_kv = abs_clamp_eps(kv, eps = eps).to(dtype = torch.complex64).log(). I haven't seen much, if any, improvement from adaptive RMSNorm when testing it (although that was with shorter sequences), so I haven't looked into debugging it.

I only have a little more than a week left on the H100s, so instead of dragging my ass through debugging the code for a week I figured I might ask you :) Do you have any advice for me? I was thinking it might also have to do with the sequence length: I'm using the transformer with 550M parameters and a 6k sequence length. Is there something I should enable in the transformer decoder layers that might help with the NaN loss or the sequence length?

MarcusLoppe commented 2 months ago

@lucidrains

Hey, congrats on the success of alphafold3 :) I was quite surprised when they mentioned you on the "Last Week in AI" podcast.

So I've got some news of my own: I managed to train the autoencoder & transformers using only 1 quantizer on meshes with 700-1000 triangles. That means each triangle only needs 3 tokens, which blows away the other MeshGPT-like projects, since they use at least 9 tokens per triangle. I've also trained a 2k triangle autoencoder and it reached around 0.35 reconstruction loss, which is very good.

For instance, for an 800 triangle mesh MeshXL would require a sequence length of 7200, but this needs only 2400 tokens! <3 Inference is relatively fast as well: I was able to generate meshes with 1000 triangles within 20s on an H100 (40s on a 4090).

I found that the persistent memory (num_mem_kv) & 64 SOS tokens were quite useful when dealing with more detailed meshes.

There is an issue regarding the text conditioning. I'd imagine that you implemented text_condition_cond_drop_prob to make it more robust: for example, if I train using the label 'a chair', it has a chance to drop the dim that represents "a", making it possible to train the model to output that item when given the text "chair". However, the embeddings aren't quite like that. If I encode "a chair" (2, 768) and "chair" (1, 768) and then compute the cosine similarity on the dim belonging to the "chair" token (e.g. [0] and [1]), the similarity is high (0.9) but not 1.0.

This then causes issues when generating the mesh: if you provide the text "chair" and the model was never trained on that exact embedding, only on the label "a chair", it will fail even if you trained with text_condition_cond_drop_prob. Ideally the transformer would understand the characteristics of the embedding and find the pattern for "chair", but that would probably require huge amounts of text, since you'd basically be training a text semantic model.

Please let me know if you have any ideas :) The only viable idea I have is to input the tokenized text instead of text_embeddings, drop tokens using the probability mask, and then encode again; however, this would result in a large amount of extra compute. Alternatively, the dataset could store all the variations of the text embedding and pick one at random.
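The second idea (store several text variations per item and pick one at random) can be sketched without touching the model. Below, hypothetical helpers build word-dropout variants of each caption once, encode them with whatever text encoder you already use (passed in as `encode`, an assumption), and sample one cached result per lookup; this trades memory for not re-encoding on every step.

```python
import random
from typing import Callable, Dict, List, Optional, Sequence


def caption_variants(caption: str, max_variants: int = 8,
                     drop_prob: float = 0.3, seed: int = 0) -> List[str]:
    """Word-dropout variants of a caption; the original caption is always included."""
    rng = random.Random(seed)
    words = caption.split()
    variants = {caption}
    for _ in range(max_variants * 4):  # bounded number of attempts
        if len(variants) >= max_variants:
            break
        kept = [w for w in words if rng.random() >= drop_prob]
        if kept:  # never emit an empty caption
            variants.add(" ".join(kept))
    return sorted(variants)


class CachedCaptionEmbeddings:
    """Encodes every caption variant once up front, then samples one per lookup."""

    def __init__(self, captions: Sequence[str], encode: Callable[[str], object], **variant_kwargs):
        self.cache: Dict[str, List[object]] = {
            caption: [encode(v) for v in caption_variants(caption, **variant_kwargs)]
            for caption in set(captions)
        }

    def sample(self, caption: str, rng: Optional[random.Random] = None) -> object:
        return (rng or random).choice(self.cache[caption])


# Usage with a stand-in encoder; swap in the real text embedder
store = CachedCaptionEmbeddings(["a chair", "a round table"], encode=str.upper)
print(caption_variants("a chair"))
print(store.sample("a chair"))
```

Plain word dropout can also produce variants like "a" on its own, so in practice you may want to drop only filler words or always keep the head noun; that filtering is left out of this sketch.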

Here are some tests. I've tried many different models, and paraphrase-multilingual-MiniLM-L12-v2 seems to do best with small syntax differences, but not as well with semantic words.

sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
==========
"chair" vs "a chair": 0.9800534248352051
"table" vs "chair": 0.3927270174026489
"curved table" vs "curved chair": 0.580420970916748
"flat table" vs "flat chair": 0.5569348335266113
"flat table" vs "curved chair": 0.42730042338371277
==========

clip (classifier_free_guidance_pytorch)
==========
"chair" vs "a chair": 0.6290445327758789
"table" vs "chair": 0.9999410510063171
"curved table" vs "curved chair": 1.0000001192092896
"flat table" vs "flat chair": 1.000000238418579
"flat table" vs "curved chair": 0.9999337196350098
==========

openai/clip-vit-large-patch14
==========
"chair" vs "a chair": 0.9482842683792114
"table" vs "chair": 0.7854955196380615
"curved table" vs "curved chair": 0.823255181312561
"flat table" vs "flat chair": 0.7940264940261841
"flat table" vs "curved chair": 0.6317252516746521
==========

bge
==========
"chair" vs "a chair": 0.9138880372047424
"table" vs "chair": 0.7104504108428955
"curved table" vs "curved chair": 0.8448339700698853
"flat table" vs "flat chair": 0.8288297057151794
"flat table" vs "curved chair": 0.6892191171646118
==========