lucidrains / meshgpt-pytorch

Implementation of MeshGPT, SOTA Mesh generation using Attention, in Pytorch
MIT License
700 stars, 57 forks

[Critical] Very high loss rate at first few tokens (classifier free guidance not working) #80

Closed: MarcusLoppe closed this 2 months ago

MarcusLoppe commented 4 months ago

@lucidrains This is an issue I've been having for a while: the cross-attention is very weak at the start of the sequence. When the transformer starts with no tokens it has to rely on the cross-attention, but unfortunately the cross-attention doesn't work for the first token(s).

Proof

To prove this, I trained on a dataset of 500 models that have unique text embeddings and no augmentations, using only the first 6 tokens of each mesh. After training for 8 hrs, the loss is still stuck at 1.03.

Without fixing this issue, autoregressive generation without a prompt of tokens will never work.

This problem has been ongoing for a while, but I thought it was a training issue and that using a model trained on the first few tokens would resolve it. However, that isn't the case.

Real-life example

To highlight the issue, I trained a model on the 13k dataset, then removed all the augmentation copies and the models with duplicate labels. If I provide it with the first 2 tokens as a prompt, it autocompletes without any problems or visual issues; however, if I provide it with 1 or 0 tokens it fails completely.

Checked the logits

I investigated this further and checked the logits when it generated the first token: the correct token was only the 9th most probable. I tried implementing beam search with a beam width of 5, but since the first token has such a low probability it would require a lot of beams. That would probably work, but it seems like a brute-force solution. It may work to do a beam search of 20 and then kill off the candidates that seem to have a low probability/entropy, but this seems like a band-aid solution that might not work when scaling up MeshGPT.
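A quick way to reproduce the kind of check described above is to look at the rank of the ground-truth token in the first-step distribution. This is a minimal sketch with toy logits; `token_rank` is a hypothetical helper, and in practice the logits would come from the transformer's first decoding step.

```python
import torch
import torch.nn.functional as F


def token_rank(logits: torch.Tensor, target: int) -> tuple[int, float]:
    """Return the 1-based rank and the probability of `target` in a (vocab,) logit vector."""
    probs = F.softmax(logits, dim=-1)
    # rank = number of tokens strictly more probable than the target, plus one
    rank = int((probs > probs[target]).sum().item()) + 1
    return rank, float(probs[target].item())


# toy 10-token vocabulary in which token 3 is the "correct" first token
logits = torch.tensor([2.0, 1.5, 1.0, 0.2, 0.9, 0.8, 0.7, 0.6, 0.5, 0.1])
rank, p = token_rank(logits, target=3)
print(f"rank={rank}, p={p:.3f}")
```

Under plain beam search starting from a single hypothesis, the correct first token only survives the first step if the beam width is at least its rank, which is why a width of 5 cannot recover a token ranked 9th.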

Why is this a problem?

The first tokens are very important for generation because of a domino effect: if it gets an incorrect token at the start, the generation will fail, since it relies too much on the preceding sequence to be able to auto-correct. It's like if the sentence is "Dog can be a happy animal" and it predicts "Human" as the first token: the sentence is already messed up, and auto-correcting to "Human got a dog which can be a happy animal" is extremely hard.

Possible solution

Since the cross-attention is only used in the "big" decoder, could it also be implemented for the fine decoder?

Attempts to fix:

This has been a problem for a long time and I've mentioned it in other issue threads as a side note, so I'm creating an issue for it, since it really prevents me from releasing fine-tuned models.

I have a model ready to go that can predict 13k models, but since the weak first tokens make autoregressive generation impossible, I haven't released it yet.

Here are some images of the loss: [image]

pathquester commented 4 months ago

This sounds critical indeed. Hopefully it's an easy fix.

MarcusLoppe commented 4 months ago

@lucidrains

I think I've resolved this issue by tokenizing the text, inserting it at the start of the codes, and adding a special token to indicate the start of the mesh tokens. The downside is that the transformer needs a larger vocab. Any idea whether it's possible to reduce the vocab size it's predicting over?

I tested it on a smaller dataset but it seems to be working! I think this will also guide the transformer much better.

[image]
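A minimal sketch of the sequence layout described above ("chair XXXXXXXX"), assuming text token ids from some tokenizer are shifted past the mesh codebook and one extra id marks the start of the mesh tokens. The offsets, the `build_prefixed_sequence` name and the toy sizes are illustrative assumptions, not the actual code used here. The returned vocab size shows where the larger output layer comes from.

```python
import torch


def build_prefixed_sequence(text_ids, mesh_codes, codebook_size, text_vocab_size):
    """Concatenate [shifted text ids, start-of-mesh id, mesh codes] into one id sequence.

    Hypothetical id layout:
      0 .. codebook_size - 1                                -> mesh codes
      codebook_size .. codebook_size + text_vocab_size - 1  -> text tokens
      codebook_size + text_vocab_size                       -> start-of-mesh token
    """
    start_of_mesh = codebook_size + text_vocab_size
    shifted_text = (text_ids + codebook_size).to(mesh_codes.dtype)
    som = torch.tensor([start_of_mesh], dtype=mesh_codes.dtype)
    seq = torch.cat([shifted_text, som, mesh_codes])
    total_vocab = codebook_size + text_vocab_size + 1
    return seq, total_vocab


text_ids = torch.tensor([5, 17])          # toy tokenizer output for e.g. "chair"
mesh_codes = torch.tensor([3, 9, 9, 1])   # toy mesh codes
seq, vocab = build_prefixed_sequence(text_ids, mesh_codes, codebook_size=16, text_vocab_size=50)
print(seq.tolist(), vocab)
```

With a 16k mesh codebook and a 50k-entry text tokenizer, this layout needs roughly a 66k-way output layer, which matches the vocab cost mentioned elsewhere in the thread.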

pathquester commented 4 months ago

@MarcusLoppe That is fantastic! Have you posted the fix somewhere?

MarcusLoppe commented 4 months ago

Not yet. My current way is a bit hacky and requires a bit of a rewrite to implement properly.

I'm currently verifying the solution on a somewhat bigger dataset and will hammer out any bugs.

lucidrains commented 4 months ago

@MarcusLoppe hey Marcus, thanks for identifying this issue

have you tried turning off CFG? if you haven't, one thing to try is simply turning off CFG for the first few tokens. i think i've come across papers where they studied at which steps CFG is even effective

also try turning off CFG and do greedy sampling and see what happens. if that doesn't work, there is some big issue

MarcusLoppe commented 4 months ago

With CFG you mean classifier-free guidance?

Not sure how I would go about that. Do you mean setting cond_drop_prob to 0.0? I've tried that, and as far as I can tell the CFG just returns the embedding without any modifications (with cond_drop_prob set to 0 it never masks the text embedding).

The issue lies in the case where the transformer has an empty sequence and only the text embedding to go on. The text embedding doesn't seem to help very much, so it doesn't know which token to pick, hence the huge loss at the start.

lucidrains commented 4 months ago

@MarcusLoppe oh, maybe it is already turned off

so CFG is turned on by setting cond_scale > 1 when invoking .generate

if you haven't been using cond_scale, then perhaps it was never turned on
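For reference, the two knobs discussed here play different roles in the usual classifier-free guidance recipe: a dropout probability applied to the condition during training, and a guidance scale applied at sampling time. The sketch below is the generic recipe with hypothetical function names, not meshgpt-pytorch's own code. It shows why cond_drop_prob = 0 leaves the text embedding untouched and why cond_scale = 1 means no guidance; it also assumes, as CFG generally requires, that the model was trained with some cond_drop_prob > 0 so the unconditional branch is meaningful.

```python
import torch


def drop_condition(text_embeds: torch.Tensor, null_embed: torch.Tensor, cond_drop_prob: float) -> torch.Tensor:
    """Training side: swap each sample's text embedding for a null embedding with probability cond_drop_prob."""
    batch = text_embeds.shape[0]
    drop = torch.rand(batch, device=text_embeds.device) < cond_drop_prob
    # reshape the per-sample mask so it broadcasts over the embedding dims
    mask = drop.view(batch, *([1] * (text_embeds.ndim - 1)))
    return torch.where(mask, null_embed.expand_as(text_embeds), text_embeds)


def guided_logits(cond_logits: torch.Tensor, uncond_logits: torch.Tensor, cond_scale: float) -> torch.Tensor:
    """Inference side: classifier-free guidance; cond_scale = 1 gives back the conditional logits."""
    return uncond_logits + (cond_logits - uncond_logits) * cond_scale


emb = torch.randn(4, 8)
null = torch.zeros(8)
kept = drop_condition(emb, null, cond_drop_prob=0.0)   # nothing is masked
cond, uncond = torch.tensor([1.0, 2.0]), torch.tensor([0.0, 0.0])
print(guided_logits(cond, uncond, 3.0))                 # pushed further along (cond - uncond)
```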

lucidrains commented 4 months ago

@MarcusLoppe oh crap, do i only append the start token for the fine transformer?? 🤦 yes you are correct, it is never conditioned then for the first set of fine tokens

lucidrains commented 4 months ago

thank you, this is a real issue then. i'll add cross attention to the fine transformer later today

edit: another cheap solution would be to project text embeddings, pool it, and just use that as the initial sos token

MarcusLoppe commented 4 months ago

Awesome. My 'fix' seems to be working, though. By providing the text as tokens in the sequence, the fine decoder gets the text context, and it also helps create a stronger relationship with the tokens and speeds up training. So the sequence it trains on looks like: "chair XXXXXXXX" (where X is the mesh tokens).

The downside is that it needs a bigger vocab, which slows training a bit, but the stronger relationship between the mesh tokens and the text seems to be working :)

I had some issues with providing the context to the fine decoder, since the tensor changes shape, but you might be able to solve it.

However, I also tried removing the gateloop and fine decoder so the main decoder is the last layer, but unfortunately it had the same issue.

lucidrains commented 4 months ago

@MarcusLoppe yup, your way is also legit 😄

you have a bright future Marcus! finding this issue, the analysis, coming up with your own solution; none of that is trivial

MarcusLoppe commented 4 months ago

Thank you very much 😄 Although it took a while, I think I've learned a thing or two along the way 😄

I don't think the cross-attention will be enough: as per my last reply, I removed the fine decoder and gateloop and had the same issue.

If you think about multimodal generative models, they never start from token 0. For example, vision models get a prompt with a specific request from the user. So the model has the first few tokens and some sort of goal or idea of what to generate, and then the cross-attention does its job and provides the additional context. That gives the generation a more 'probabilistic path' to the correct answer right from the start.

I think projecting the text embeddings might be the better way in this case.

lucidrains commented 4 months ago

@MarcusLoppe yup i went with the pooled text embedding summed to the sos token for now

let me know if that fixes things (or not) 🤞

lucidrains commented 4 months ago

this was really my fault for designing the initial architecture incorrectly

the sos token should be on the coarse transformer

MarcusLoppe commented 4 months ago

Awesome! I'll check it out 🚀

However, with the latest x-transformers update I'm getting the error below. The num_mem_kv kwarg doesn't seem to be picked up or trimmed by: "attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)"

And the dim_head in meshgpt isn't being passed correctly; it should be "attn_dim_head".

-> 1057     assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'
   1059     dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
   1061     self.dim = dim

AssertionError: unrecognized kwargs passed in dict_keys(['dim_head', 'num_mem_kv'])

lucidrains commented 4 months ago

@MarcusLoppe ah yes, those should have attn_ prepended, should be fixed in the latest version
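For readers unfamiliar with the prefix convention: keyword arguments meant for the attention layers are expected to carry an `attn_` prefix, which is stripped before they are passed on. The helper below is a hypothetical stand-in that mimics that routing (it is not the x-transformers source) to show why `dim_head` and `num_mem_kv` were rejected while `attn_dim_head` and `attn_num_mem_kv` are accepted.

```python
def group_by_prefix_and_trim(prefix: str, kwargs: dict) -> tuple[dict, dict]:
    """Split kwargs into (prefixed keys with the prefix removed, everything else)."""
    matched = {key[len(prefix):]: value for key, value in kwargs.items() if key.startswith(prefix)}
    rest = {key: value for key, value in kwargs.items() if not key.startswith(prefix)}
    return matched, rest


# unprefixed names are not routed to attention, so they end up in `rest`
# and would trip an "unrecognized kwargs" assertion like the one above
attn_kwargs, rest = group_by_prefix_and_trim('attn_', dict(dim_head=64, num_mem_kv=4))
print(attn_kwargs, rest)

# with the attn_ prefix they are routed to the attention layers and trimmed
attn_kwargs, rest = group_by_prefix_and_trim('attn_', dict(attn_dim_head=64, attn_num_mem_kv=4))
print(attn_kwargs, rest)
```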

MarcusLoppe commented 3 months ago

@lucidrains

Alright, here are some results, using the CLIP embedding model (higher distances in the embedding space) with a GPT-small size transformer:

I first trained on a small set of 350 models, each with x5 augmentations. It only contains 39 unique labels, so there is some overlap in the texts. The previous test just produced a blob of triangles; this time it output all tents and a blob. [image]

I then took the same model and removed all augmentations, so it's x1 of each model and every model has a unique text. This output somewhat better results, but it's still not following the text guidance. I checked the logits: the first generated token belonged to a bench model, while the correct token was ranked 19th with a value of 0.013. [image]

And as you can see, the loss at the start didn't show any improvement :/ [image]

As a sanity check, I trained a fresh model on the x1 set down to 0.004 loss, but as you can see it didn't help. It might even have made it worse. [image]

I previously did the same test using my tokenized-text method and got all perfect results on the x1 set (I did not test x5), which indicates that the cross-attention relationship isn't strong enough when there are no tokens yet.

Btw, I also tested adding fake tokens at the start by increasing the codebook, e.g. using codebook_size + 1 (with eos at + 2), but that didn't change anything.

lucidrains commented 3 months ago

@MarcusLoppe ok, i threw in cross attention conditioning for fine transformer in 1.2.3

if that doesn't work, i'll just spend some time refactoring so the start token is in the coarse transformer. that would work for sure, as it is equivalent to your solution, except the text tokens do not undergo self attention

lucidrains commented 3 months ago

@MarcusLoppe thanks for running the experiments!

MarcusLoppe commented 3 months ago

@lucidrains

The loss improved much faster over the epochs, but there is a downside: generation went from 100 tokens/s down to 80 tokens/s. I still think I prefer this version, since it should cut down the overall training time. Since each forward pass got slower, so did each epoch: on a 2k dataset it went from 02:28 to 02:42, but I saw better loss improvements.

Unfortunately it did not work :( However, something to note is that it did work before on the demo mesh dataset, which consists of 9 meshes.

cond_scale 1: [image]
cond_scale 3: [image]

[image]

lucidrains commented 3 months ago

@MarcusLoppe ah, thank you

ok, final try

will have to save this for late next week if it doesn't work

MarcusLoppe commented 3 months ago

It worked better. Here is the result of training it on 39 models with unique labels; however, you can still see a spike at the start of the sequence, meaning the issue might not be resolved.

[image] [image]

Using my method I got the results below; it manages to generate quite complex objects. However, the start is still a bit weak. It would help if you managed to move the sos token into the coarse transformer; this would help training time a lot, since it could reduce the vocab size from 32k to 2k :)

I've also experimented with using 3 tokens per triangle, and the autoencoder seems to be working; however, it makes the training progression for the transformer slower. But considering that the VRAM requirement for training on 800-triangle meshes would go from 22GB to 9GB, and the generation time would be halved, I think it's worth exploring.

However, I think the autoencoder could also benefit from getting the text embeddings. I tried to pass them as the context in the linear attention layer, but since it requires the same shape as the quantized input it won't accept them, nor do I think it would be very VRAM-friendly to duplicate the text embedding for every face. Do you think there is an easy fix for this? I think it would reduce the codebook size a lot and help create codes with closer relationships to the text, which would benefit the transformer a lot.

[image]
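One common way to condition per-face features on a single text embedding without duplicating it per face is feature-wise linear modulation (FiLM): project the text embedding to a per-channel scale and shift and let broadcasting apply it to every face. FiLM is swapped in here as an illustration; it is not something proposed in the thread, and I'm not claiming this is how meshgpt-pytorch or SoundStream implement conditioning. `FiLMTextConditioning` and the shapes are assumptions.

```python
import torch
from torch import nn


class FiLMTextConditioning(nn.Module):
    """Hypothetical FiLM layer: per-mesh text embedding -> per-channel scale and shift for every face."""

    def __init__(self, text_dim: int, dim: int):
        super().__init__()
        self.to_scale_shift = nn.Linear(text_dim, dim * 2)

    def forward(self, face_features: torch.Tensor, text_embed: torch.Tensor) -> torch.Tensor:
        # face_features: (batch, num_faces, dim); text_embed: (batch, text_dim)
        scale, shift = self.to_scale_shift(text_embed).chunk(2, dim=-1)
        # (batch, 1, dim) broadcasts across num_faces, so the text embedding is never copied per face
        return face_features * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


film = FiLMTextConditioning(text_dim=512, dim=64)
faces = torch.randn(2, 800, 64)
text = torch.randn(2, 512)
out = film(faces, text)

# if a layer insists on a (batch, num_faces, text_dim) context, expand() gives that shape
# as a view with stride 0 along the face axis, so the expand itself allocates no new memory
context = text.unsqueeze(1).expand(-1, faces.shape[1], -1)
print(out.shape, context.shape, context.stride())
```

Note that any layer which then projects that expanded context will still produce a full-size output, so the savings from `expand` only go as far as the first op that materializes it.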

lucidrains commented 3 months ago

@MarcusLoppe that is much better (an order of magnitude)! thank you for the experiments Marcus!

i know how to improve it (can add multiple sos to give the attention more surface area)

lucidrains commented 3 months ago

@MarcusLoppe i'll get back to this later this week 🙏

lucidrains commented 3 months ago

@MarcusLoppe oh, the sos token has already been moved to the coarse transformer in the latest commit. that's where the improvement you are seeing is coming from

MarcusLoppe commented 3 months ago

Oh awesome. However, the loss got very low (0.006) for these results; on bigger datasets the loss gets to about 0.01, and then it needs something like 1000 epochs to reach a similar loss.

So some further improvements would be nice! :smile:
Any thoughts about the text embedding aware auto-encoder?

lucidrains commented 3 months ago

@MarcusLoppe yup, we can try multiple sos tokens, then if that doesn't work, i'll build in the option to use prepended text embeddings (so like the solution you came up with, additional sos excised or pooled before fine transformer)

and yes, text embedding aware autoencoder is achievable! in fact, the original soundstream paper did this

MarcusLoppe commented 3 months ago

Alright, I've tested the latest patch.

I tested 1, 2, 4, 8 and 16 sos tokens, but I was unable to get any usable meshes. As a sanity check I reverted to the previous commit and was able to generate a valid mesh. For generation I tried setting cond_scale to 3 and turning off the kv_cache, but it just output a blob.

However, the loss is definitely smoothed out: as you can see, it doesn't spike up as it did before (I might have had too small a sample size, so ignore the loss values).

I think the issue might be that when they get averaged together they lose their meaning, or they get too complex to understand.

Also, the code below prevented me from testing with 1 sos token, since the sos token gets packed but never unpacked:

        if exists(cache):
            cached_face_codes_len = cached_attended_face_codes.shape[-2]
            cached_face_codes_len_without_sos = cached_face_codes_len - 1

            need_call_first_transformer = face_codes_len > cached_face_codes_len_without_sos
        else:
            # auto prepend sos token
            sos = repeat(self.sos_token, 'n d -> b n d', b = batch)
            face_codes, packed_sos_shape = pack([sos, face_codes], 'b * d')

            # if no kv cache, always call first transformer
            need_call_first_transformer = True

        if need_call_first_transformer:
            if exists(self.coarse_gateloop_block):
                face_codes, coarse_gateloop_cache = self.coarse_gateloop_block(face_codes, cache = coarse_gateloop_cache)
            # ... (omitted)

        # the sos tokens are only unpacked when num_sos_tokens > 1
        if not exists(cache) and self.num_sos_tokens > 1:
            sos_tokens, attended_face_codes = unpack(attended_face_codes, packed_sos_shape, 'b * d')
            pooled_sos_token = reduce(sos_tokens, 'b n d -> b 1 d', 'mean')
            attended_face_codes = torch.cat((pooled_sos_token, attended_face_codes), dim = 1)
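One way to avoid the asymmetric branch in the excerpt above is to make the split unconditional, so one sos token and several sos tokens go through the same code path. This torch-only sketch uses plain slicing where the excerpt uses einops `pack`/`unpack`; the function names are hypothetical and the actual fix in meshgpt-pytorch may look different.

```python
import torch


def prepend_sos(sos_tokens: torch.Tensor, face_codes: torch.Tensor) -> torch.Tensor:
    """Prepend (num_sos, dim) learned sos tokens to every (n, dim) sequence in the batch."""
    batch = face_codes.shape[0]
    sos = sos_tokens.unsqueeze(0).expand(batch, -1, -1)
    return torch.cat((sos, face_codes), dim=1)


def split_sos(attended: torch.Tensor, num_sos: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Inverse of prepend_sos along the sequence axis; num_sos = 1 takes the same path as num_sos > 1."""
    return attended[:, :num_sos], attended[:, num_sos:]


for num_sos in (1, 4):
    sos_param = torch.randn(num_sos, 8)
    face_codes = torch.randn(2, 6, 8)
    packed = prepend_sos(sos_param, face_codes)
    # ... the coarse transformer would run on `packed` here ...
    sos_out, faces_out = split_sos(packed, num_sos)
    print(num_sos, tuple(packed.shape), tuple(sos_out.shape), tuple(faces_out.shape))
```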

Previous commit (https://github.com/lucidrains/meshgpt-pytorch/commit/34f2806cccc759b9d00198858cd6019a61985fc3): [image]

New commit:

num_sos_tokens = 3: [image] [image]

num_sos_tokens = 4: [image]
num_sos_tokens = 8: [image]
num_sos_tokens = 16: [image]

lucidrains commented 3 months ago

@MarcusLoppe ok, let's go with your intuition and just grab the last sos token

MarcusLoppe commented 3 months ago

Don't listen to me :) I think you are onto something. I don't think all the nuances of a text can be contained in a single token. As you can see, the loss is smoother and not spiking, so it did something right.

Do you have a particular reason for using mean pooling? Otherwise I'll do some testing replacing it with an attention layer.

lucidrains commented 3 months ago

oh I actually kept the multiple sos tokens, but listened to your suggestion not to use mean pooling, and instead grab the last sos token to forward to fine transformer
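The difference between the two reductions is small in code. A sketch, assuming the coarse transformer's outputs at the sos positions have shape (batch, num_sos, dim) and that its attention is causal, so the last sos position has attended to all earlier ones; `summarize_sos` is a hypothetical name.

```python
import torch


def summarize_sos(sos_outputs: torch.Tensor, mode: str = "last") -> torch.Tensor:
    """Reduce (batch, num_sos, dim) sos outputs to a single (batch, 1, dim) token."""
    if mode == "last":
        # with causal attention, the last sos position has attended to every earlier sos position
        return sos_outputs[:, -1:]
    if mode == "mean":
        return sos_outputs.mean(dim=1, keepdim=True)
    raise ValueError(f"unknown mode: {mode!r}")


sos_outputs = torch.randn(2, 4, 8)
print(summarize_sos(sos_outputs, "last").shape, summarize_sos(sos_outputs, "mean").shape)
```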

lucidrains commented 3 months ago

was just reading a paper claiming that turning off CFG for earlier tokens leads to better results https://arxiv.org/html/2404.13040v1 should get this into the CFG repo at some point 🤔
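A minimal sketch of "no CFG for the earlier tokens" as a step schedule: guidance scale 1.0 (plain conditional sampling) for the first few tokens, then the requested cond_scale. The function name and the hard cutoff are assumptions; the paper linked above may use a different schedule shape.

```python
def cond_scale_for_step(step: int, cond_scale: float, num_unguided_steps: int) -> float:
    """Step schedule: scale 1.0 (no guidance) for the first num_unguided_steps tokens, then cond_scale."""
    return 1.0 if step < num_unguided_steps else cond_scale


# the per-step scale would be plugged into the usual combination:
#   logits = uncond + (cond - uncond) * scale
schedule = [cond_scale_for_step(t, cond_scale=3.0, num_unguided_steps=3) for t in range(6)]
print(schedule)  # [1.0, 1.0, 1.0, 3.0, 3.0, 3.0]
```

Since a scale of 1.0 makes the guided logits equal to the conditional ones, the unconditional forward pass can be skipped on those steps.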

MarcusLoppe commented 3 months ago

Oh alright. I haven't checked out that patch yet, but the attention works really well! 😄 🚀
I trained on 350 models and it seems to have very strong text conditioning.

The only change I made was to add a simple linear layer! :) This might have been the last step of meshgpt! (not counting the last 10 "last steps") I'll get back to you with the results of just using the last sos token.

Using the texts 'bed', 'sofa', 'monitor', 'bench', 'chair', 'table', I generated each one 3 times: the first row uses temperature 0.0 and the others 0.5. As you can see, it has a very strong text relationship :)

[image]
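For context on the temperatures above: a common convention is that temperature 0.0 means greedy decoding (argmax) and higher temperatures flatten the distribution before sampling. A generic sketch; `sample_token` is a hypothetical helper, and I'm assuming, not asserting, that meshgpt-pytorch's generate treats 0.0 the same way.

```python
import torch


def sample_token(logits: torch.Tensor, temperature: float) -> int:
    """Pick one token id from (vocab,) logits; temperature <= 0 falls back to greedy argmax."""
    if temperature <= 0.0:
        return int(logits.argmax(dim=-1).item())
    # dividing by a temperature below 1 sharpens the distribution, above 1 flattens it
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())


logits = torch.tensor([0.1, 3.0, 0.2])
print(sample_token(logits, 0.0))  # 1 (greedy)
print(sample_token(logits, 0.5))  # random draw, most often 1
```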

lucidrains commented 3 months ago

@MarcusLoppe awesome! did you mean that you used an extra linear on the mean pooled sos tokens?

MarcusLoppe commented 3 months ago

Hey, so I added a Linear and pooled the tokens as below:

self.attention_weights = nn.Linear(dim, 1, bias=False)

attention_scores = F.softmax(self.attention_weights(sos_tokens), dim=1)
pooled_sos_token = torch.sum(attention_scores * sos_tokens, dim=1, keepdim=True)
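The two lines above amount to attention pooling over the sos axis. Wrapped as a module (the class name `AttentionPoolSOS` is hypothetical), the behaviour is easier to test: with the scoring layer's weights at zero every sos token gets the same weight, so plain mean pooling is the special case this layer can learn away from.

```python
import torch
import torch.nn.functional as F
from torch import nn


class AttentionPoolSOS(nn.Module):
    """Pool (batch, num_sos, dim) sos outputs into one (batch, 1, dim) token with learned weights."""

    def __init__(self, dim: int):
        super().__init__()
        self.attention_weights = nn.Linear(dim, 1, bias=False)

    def forward(self, sos_tokens: torch.Tensor) -> torch.Tensor:
        # one scalar score per sos token, normalized with a softmax across the sos axis (dim=1)
        attention_scores = F.softmax(self.attention_weights(sos_tokens), dim=1)
        # weighted sum over the sos axis
        return torch.sum(attention_scores * sos_tokens, dim=1, keepdim=True)


pool = AttentionPoolSOS(dim=8)
sos_tokens = torch.randn(2, 4, 8)
print(pool(sos_tokens).shape)  # torch.Size([2, 1, 8])
```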

It worked very well with 350 objects, but when I scaled up to 2k objects it didn't work as well 😭 I tried more than one layer, but it got worse when I made it too complex. Also, I'm not quite sure about taking the mean over the tokens, since that might promote similar tokens.

Do you have any other ideas besides prepending the tokenized text? 😄 It really hurts performance if it needs to predict over 66k (50k+16k) tokens.

Here are some results. I took 355 samples, and for each sample I found other items in the dataset with the same type of label, e.g. "chair" and "tall chair". This gives a good idea of how well it keeps to the text condition.

As you can see, the attention version had the best results with 2 or 4 tokens, though the commit you made before also had some good results.

Latest commit using 4 sos tokens ("try not pooling all the sos tokens and instead using the last one"):
Token 1: Correct = 169, Incorrect = 186
Token 2: Correct = 234, Incorrect = 121
Token 3: Correct = 353, Incorrect = 2
Token 4: Correct = 354, Incorrect = 1

The commit before the last one ("move the concatenation of the sos token so it is always"):
Token 1: Correct = 258, Incorrect = 97
Token 2: Correct = 215, Incorrect = 140
Token 3: Correct = 354, Incorrect = 1
Token 4: Correct = 354, Incorrect = 1

All below use the attention linear layer:

1 num_sos_tokens:
Token 1: Correct = 74, Incorrect = 281
Token 2: Correct = 205, Incorrect = 150
Token 3: Correct = 352, Incorrect = 3
Token 4: Correct = 354, Incorrect = 1

2 num_sos_tokens:
Token 1: Correct = 237, Incorrect = 118
Token 2: Correct = 220, Incorrect = 135
Token 3: Correct = 353, Incorrect = 2
Token 4: Correct = 354, Incorrect = 1

4 num_sos_tokens:
Token 1: Correct = 237, Incorrect = 118
Token 2: Correct = 205, Incorrect = 150
Token 3: Correct = 351, Incorrect = 4
Token 4: Correct = 354, Incorrect = 1

8 num_sos_tokens:
Token 1: Correct = 126, Incorrect = 229
Token 2: Correct = 220, Incorrect = 135
Token 3: Correct = 351, Incorrect = 4
Token 4: Correct = 354, Incorrect = 1

4 num_sos_tokens using mean:

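pooled_sos_token = reduce(sos_tokens, 'b n d -> b 1 d', 'mean')
# note: pooled_sos_token has a single position, so this softmax over dim=1 is always 1.0
attention_scores = F.softmax(self.attention_weights(pooled_sos_token), dim=1)
# ...which makes this line an unweighted sum over the sos tokens
pooled_sos_token = torch.sum(attention_scores * sos_tokens, dim=1, keepdim=True)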

Token 1: Correct = 158, Incorrect = 197
Token 2: Correct = 214, Incorrect = 141
Token 3: Correct = 344, Incorrect = 11
Token 4: Correct = 350, Incorrect = 5
Token 5: Correct = 355, Incorrect = 0
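The thread does not include the test script, so the following is only one plausible way to produce per-position counts like those above: a generated token counts as correct at position k if any reference mesh with a matching label has the same token at position k. The matching rule and the `per_position_counts` name are assumptions.

```python
def per_position_counts(generated, references, num_positions):
    """For each position, count samples whose token matches that position in any same-label reference.

    generated:  one token list per sample
    references: for each sample, a list of reference token lists from items with a matching label
    Returns a list of (correct, incorrect) pairs, one per position.
    """
    correct = [0] * num_positions
    for gen, refs in zip(generated, references):
        for pos in range(num_positions):
            valid = {ref[pos] for ref in refs if len(ref) > pos}
            if pos < len(gen) and gen[pos] in valid:
                correct[pos] += 1
    total = len(generated)
    return [(c, total - c) for c in correct]


generated = [[1, 2, 3], [4, 2, 3]]
references = [[[1, 2, 3], [1, 5, 3]], [[7, 2, 3]]]
for pos, (ok, bad) in enumerate(per_position_counts(generated, references, 3), start=1):
    print(f"Token {pos}: Correct = {ok}, Incorrect = {bad}")
```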

lucidrains commented 3 months ago

you successfully applied the attention pooling from enformer! :clap: :clap:

thank you for the breakdown, going to default the number of sos tokens to 4 :pray:

MarcusLoppe commented 3 months ago

I think you misunderstood me: the results are bad, since in most cases the first tokens were not even remotely connected. When I used the exact same label to find valid codes, it failed very badly. As you can see, from the 3rd token on it gets 0-2 incorrect, and the longer the sequence gets, the better its accuracy; after the 10th token it maybe gets 2 incorrect every 10 or so. The takeaway from the tests is that it had 2000% better accuracy at the 3rd token than at the 1st or 2nd token. I'm afraid the issue is still alive; it basically just throws out a guess for the first tokens.

I only trained on the first 36 tokens to speed up testing, but I'm currently training it on the full sequence so I can show you the resulting generations. I'll post the results later on.

lucidrains commented 3 months ago

@MarcusLoppe ah, you aren't referring to the number of sos tokens, but to the token number in the main sequence, my bad

try with a much larger number of sos tokens, say 16 or 32

MarcusLoppe commented 3 months ago

I don't have the figures for those, but I tried 16 and got bad results. As you can see, using 8 tokens gave the worst results. I'll fire up the test script and get you some hard numbers.

I know that setting up the sos tokens before the decoder and then inserting them after the cross-attention creates some sort of learnable relationship, and I assume the tokens change with the loss. I don't have any data to back this up, but isn't it better to have the tokens be a representation of the text embeddings? If the sequence is 48 tokens, the majority of the loss comes after the first few tokens, so the tokens will be shaped/optimized to minimize that loss, meaning they adapt to work for 98% of the sequence. Sort of like sacrificing the first wave of soldiers in a war to be in a better position so that no other soldiers need to die.

So is it possible to project (with any nn layer) the text embeddings to the model dim, insert them at the start of the sequence, and then add a special token?
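A minimal sketch of this proposal, assuming per-token text embeddings of shape (batch, text_len, text_dim): project them to the model dim with a linear layer, append one learned separator embedding, then the face-code embeddings. `TextPrefix` and its arguments are hypothetical. Because the text enters as continuous embeddings rather than token ids, nothing extra has to be predicted for it, so the output layer can stay at the mesh codebook size, provided the loss is only taken on the mesh positions.

```python
import torch
from torch import nn


class TextPrefix(nn.Module):
    """Hypothetical prefix builder: [projected text embeddings, separator, face codes]."""

    def __init__(self, text_dim: int, dim: int):
        super().__init__()
        self.project = nn.Linear(text_dim, dim)
        # learned "start of mesh" embedding placed between the text and the mesh tokens
        self.separator = nn.Parameter(torch.randn(1, 1, dim))

    def forward(self, text_embeds: torch.Tensor, face_codes: torch.Tensor) -> torch.Tensor:
        # text_embeds: (batch, text_len, text_dim); face_codes: (batch, num_codes, dim)
        batch = face_codes.shape[0]
        prefix = self.project(text_embeds)
        separator = self.separator.expand(batch, -1, -1)
        return torch.cat((prefix, separator, face_codes), dim=1)


prefix = TextPrefix(text_dim=512, dim=256)
seq = prefix(torch.randn(2, 7, 512), torch.randn(2, 48, 256))
print(seq.shape)  # torch.Size([2, 56, 256])
```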

lucidrains commented 3 months ago

ok, I'll try one more thing, but if that doesn't work, let's go for your text prefix solution

MarcusLoppe commented 3 months ago

A little off topic, but I trained an autoencoder with 1 quantizer plus a transformer and got good results. Progression was a little slower, but I got to about 0.03 loss with the transformer. I didn't succeed in generating a mesh from 0 tokens, but when provided with 10 tokens it managed to generate a mesh :)

So that is a big win: it halves the sequence length and reduces the VRAM requirement in training from 22 GB to 8 GB (800 faces).
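The halving follows from the token layout. Assuming each triangle has 3 vertices and each vertex gets one code per residual quantizer (which matches the 6 vs 3 tokens per triangle discussed in this thread), the sequence length works out as below; the helper name is hypothetical. For full self-attention without memory-efficient kernels the score matrix grows with the square of the sequence length, so halving the sequence shrinks that term about 4x, though total VRAM also has parts that only scale linearly.

```python
def sequence_length(num_faces: int, vertices_per_face: int = 3, num_quantizers: int = 2) -> int:
    """Tokens per mesh if every vertex of every face gets one code per quantizer (assumed layout)."""
    return num_faces * vertices_per_face * num_quantizers


two_q = sequence_length(800)                    # 2 quantizers -> 6 tokens per triangle
one_q = sequence_length(800, num_quantizers=1)  # 1 quantizer  -> 3 tokens per triangle
attention_matrix_ratio = (two_q / one_q) ** 2   # n^2 growth of the attention score matrix
print(two_q, one_q, attention_matrix_ratio)
```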

MarcusLoppe commented 3 months ago

Hi again.

Here are some failed results:

I was wondering whether the decoder cross-attention layer alone could handle it, but with just the decoder layer the model couldn't handle any part of the sequence. So what are your thoughts on the cross-attention? Do you think the sos token or the cross-attention can handle the cold start? Since the issue is with the first 0-3 tokens, would it be beneficial to create some kind of embedding space that contains the first 3 tokens and is indexed by text embedding? That way the text embedding provided by the user could be used to find the nearest neighbour. It's not very novel, but it's a good way to at least kickstart the generation, although the issue might be resolved with scale later on.
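The nearest-neighbour kickstart could look roughly like this: a minimal sketch, assuming the text embeddings and the first k tokens of each training mesh are precomputed. `PrefixLookup` is a hypothetical helper, and cosine similarity is just one choice of distance.

```python
import torch
import torch.nn.functional as F

class PrefixLookup:
    def __init__(self, text_embeds: torch.Tensor, prefixes: torch.Tensor):
        # text_embeds: (num_meshes, text_dim); prefixes: (num_meshes, k) first k token ids of each mesh
        self.keys = F.normalize(text_embeds, dim=-1)
        self.prefixes = prefixes

    def __call__(self, query: torch.Tensor) -> torch.Tensor:
        # query: (b, text_dim) -> (b, k) prefix of the most similar stored embedding (cosine similarity)
        sims = F.normalize(query, dim=-1) @ self.keys.T
        return self.prefixes[sims.argmax(dim=-1)]
```

The retrieved tokens would then be used as the prompt for generation, the same way the 1-2 token prompts were used earlier in this issue.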

The best result I got was with the commit below; however, it may just be luck and not consistent behaviour. The linear attention method had similar results, but without the slowdown of adding cross-attention to the fine decoder.

Training many, many epochs using the commit "add cross attention based text conditioning for fine transformer too":
Token 1: Correct = 260, Incorrect = 95
Token 2: Correct = 319, Incorrect = 36
Token 3: Correct = 351, Incorrect = 4
Token 4: Correct = 355, Incorrect = 0

Linear layer with 4 sos tokens

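if not exists(cache):
    # split the sos tokens back off from the face codes
    sos_tokens, attended_face_codes = unpack(attended_face_codes, packed_sos_shape, 'b * d')
    # score each sos token with a linear layer, softmax over the sos-token axis
    attention_scores = F.softmax(self.attention_weights(sos_tokens), dim=1)
    # the attention-weighted sum pools the sos tokens into a single token
    pooled_sos_token = torch.sum(attention_scores * sos_tokens, dim=1, keepdim=True)
    # prepend the pooled token to the face codes
    attended_face_codes = torch.cat((pooled_sos_token, attended_face_codes), dim = 1)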
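A self-contained version of the pooling step above, assuming `self.attention_weights` is an `nn.Linear(dim, 1)` (its definition isn't shown in the snippet) and taking the already-unpacked sos tokens as input. `SosAttentionPool` is a hypothetical name.

```python
import torch
import torch.nn.functional as F
from torch import nn

class SosAttentionPool(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        # assumed definition: one scalar score per sos token
        self.attention_weights = nn.Linear(dim, 1)

    def forward(self, sos_tokens: torch.Tensor, face_codes: torch.Tensor) -> torch.Tensor:
        # sos_tokens: (b, num_sos, dim), face_codes: (b, n, dim)
        attention_scores = F.softmax(self.attention_weights(sos_tokens), dim=1)  # (b, num_sos, 1)
        pooled_sos_token = torch.sum(attention_scores * sos_tokens, dim=1, keepdim=True)  # (b, 1, dim)
        return torch.cat((pooled_sos_token, face_codes), dim=1)  # (b, 1 + n, dim)
```

Because the softmax runs over the sos axis, the pooled token is a convex combination of the sos tokens, so 4 sos tokens still contribute only one position to the sequence.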

Token 1: Correct = 237, Incorrect = 118
Token 2: Correct = 205, Incorrect = 150
Token 3: Correct = 351, Incorrect = 4
Token 4: Correct = 354, Incorrect = 1
Token 5: Correct = 355, Incorrect = 0

lucidrains commented 3 months ago

@MarcusLoppe thank you Marcus! 🙏 will get a few solutions in soon

i also realized i wasn't caching the cross attention key / values correctly 🤦 will also get that fixed this morning

MarcusLoppe commented 3 months ago

Awesome! :smile: Outside of meshgpt, have you had success before training the decoder and letting it generate from a cold start with just an embedding? E.g. train on sequences of 6 tokens where the only input is an embedding used in the decoder's cross-attention.

It works fairly well when the dataset is small (<500). I don't think it's the model size, since it can remember 10k models if it's prompted with a few tokens.

Btw, let me know if I'm doing something wrong: during my testing I just call forward_on_codes, get the logits and pick the token by argmax. I'm not sure whether this disables classifier-free guidance or not.
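For reference, classifier-free guidance normally needs both a conditional and an unconditional (text-dropped) forward pass, whose logits are then combined. A generic sketch of that combination follows; `cfg_logits` is a hypothetical helper, not a meshgpt-pytorch function, and whether `forward_on_codes` performs this step by default should be checked in the source. With `cond_scale = 1` the formula reduces to the conditional logits, so argmax over a single conditional pass behaves like guidance at scale 1.

```python
import torch

def cfg_logits(cond_logits: torch.Tensor, uncond_logits: torch.Tensor, cond_scale: float) -> torch.Tensor:
    # extrapolate from the unconditional prediction towards (and past, if cond_scale > 1) the conditional one
    return uncond_logits + cond_scale * (cond_logits - uncond_logits)

cond = torch.tensor([[2.0, 1.0, 0.5]])
uncond = torch.tensor([[1.0, 1.5, 0.5]])
next_token = cfg_logits(cond, uncond, cond_scale=3.0).argmax(dim=-1)
```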

MarcusLoppe commented 3 months ago

Hey again,

So I've noticed some strange behaviour with the cross-attention num_mem_kv that might help you resolve the issue. I had changed this value before without any noticeable changes.

However, using the commit with the fine-decoder cross-attention, I got the results below. Setting cross_attn_num_mem_kv to 16 seems to hit some kind of sweet spot (maybe related to the dataset size).

This made it possible to generate a mesh from token 0, since it seems to be hitting the correct tokens. As you can see, the mesh is hardly smooth, but at least it's selecting the correct first token! I'm currently training to see if using x5 augmentation of the same dataset yields better results, since it might be more robust. [image]

I also tested setting the fine depth to 4 or 8, but that worsened performance; the same goes for increasing attn_num_mem_kv to 16.

I also tested using 16 cross_attn_num_mem_kv on all the other solutions you've posted, but there were no noticeable changes.
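For context, memory key/values are learned key/value pairs concatenated to the keys and values computed from the context, so every query always has some extra slots to attend to besides the text embedding. A single-head sketch of the idea follows; it is a simplification for illustration, not x-transformers' actual attention code, and `CrossAttnMemKV` is a hypothetical name.

```python
import torch
from torch import nn

class CrossAttnMemKV(nn.Module):
    def __init__(self, dim: int, context_dim: int, num_mem_kv: int = 16):
        super().__init__()
        self.scale = dim ** -0.5
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_kv = nn.Linear(context_dim, dim * 2, bias=False)
        # learned key/value slots that every query can attend to, independent of the text context
        self.mem_k = nn.Parameter(torch.randn(num_mem_kv, dim))
        self.mem_v = nn.Parameter(torch.randn(num_mem_kv, dim))
        self.to_out = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # x: (b, n, dim) sequence being decoded, context: (b, m, context_dim) text embeddings
        b = x.shape[0]
        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        # prepend the memory slots to the keys/values coming from the context
        k = torch.cat((self.mem_k.expand(b, -1, -1), k), dim=1)
        v = torch.cat((self.mem_v.expand(b, -1, -1), v), dim=1)
        # (b, n, num_mem_kv + m) attention over memory slots and context tokens
        attn = (q @ k.transpose(-1, -2) * self.scale).softmax(dim=-1)
        return self.to_out(attn @ v)
```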

Commit: https://github.com/lucidrains/meshgpt-pytorch/commit/5ef6cbfeaa3c43b67c548d1b11e033069b01f590

8 cross_attn_num_mem_kv
Token 1: Correct = 6, Incorrect = 349
Token 2: Correct = 165, Incorrect = 190
Token 3: Correct = 320, Incorrect = 35
Token 4: Correct = 322, Incorrect = 33
Token 5: Correct = 341, Incorrect = 14

16 cross_attn_num_mem_kv
Token 1: Correct = 293, Incorrect = 62
Token 2: Correct = 331, Incorrect = 24
Token 3: Correct = 354, Incorrect = 1
Token 4: Correct = 354, Incorrect = 1
Token 5: Correct = 355, Incorrect = 0

16 cross_attn_num_mem_kv
8 fine_attn_depth
Token 1: Correct = 233, Incorrect = 122
Token 2: Correct = 189, Incorrect = 166
Token 3: Correct = 321, Incorrect = 34
Token 4: Correct = 313, Incorrect = 42
Token 5: Correct = 342, Incorrect = 13

32 cross_attn_num_mem_kv
Token 1: Correct = 4, Incorrect = 351
Token 2: Correct = 207, Incorrect = 148
Token 3: Correct = 345, Incorrect = 10
Token 4: Correct = 338, Incorrect = 17
Token 5: Correct = 349, Incorrect = 6

16 attn_num_mem_kv
16 cross_attn_num_mem_kv
Token 1: Correct = 5, Incorrect = 350
Token 2: Correct = 205, Incorrect = 150
Token 3: Correct = 353, Incorrect = 2
Token 4: Correct = 355, Incorrect = 0
Token 5: Correct = 355, Incorrect = 0

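The per-position counts above (each row sums to 355 meshes) can be tallied along these lines. This sketch assumes `pred` holds argmax token ids for the first few positions of each mesh and `target` the ground-truth codes; the thread doesn't say whether predictions were made with teacher forcing or autoregressively, and `per_position_counts` is a hypothetical helper.

```python
import torch

def per_position_counts(pred: torch.Tensor, target: torch.Tensor) -> list[tuple[int, int]]:
    # pred, target: (num_meshes, num_positions) integer token ids
    correct = (pred == target).sum(dim=0)
    incorrect = target.shape[0] - correct
    return [(int(c), int(w)) for c, w in zip(correct, incorrect)]

def print_report(pred: torch.Tensor, target: torch.Tensor) -> None:
    for pos, (c, w) in enumerate(per_position_counts(pred, target), start=1):
        print(f"Token {pos}: Correct = {c}, Incorrect = {w}")
```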
MarcusLoppe commented 3 months ago

Hey @lucidrains, I think I've figured something out. I made quite a lot of changes, but I had success by applying the following:

Plus a few other tricks. The training is also quite sensitive with regard to masking the text and other factors; if the model becomes overtrained, the results are just blobs again.
When the conditions are pretty good, the model will always generate a complete shape: not always what you asked for, but at least it's not a blob. Btw, I also managed to train a model using 1 quantizer, which halved the inference time (duh :) ).

I wouldn't say this issue is resolved: with a dataset of 1k unique labels, generation steers towards the most 'average' mesh according to the text embeddings. You can see this averaging effect in the second image (cond scale helps sometimes, but setting it too high turns the mesh into a blob). Hopefully this information helps you steer towards a final solution that can be used for a large number of text labels.

Possible issue / accidental feature

I'm not sure if it's a problem, but since I add the sos_token before the main decoder and then add the text embedding pooling afterwards, 2 tokens carrying 'value' are added, and with the padding this becomes 12 tokens. The first 6 extra tokens are there for the autoregression and the other 6 are due to the text embedding pool, since it's added just before pad_to_length is called.

The result is that 1 token is replaced/lost due to the right shift, since 2 tokens are added but only the sos_token is removed. So the data between the decoder and the fine decoder is shifted right and ends up in a different order. This might not be an issue for the fine decoder, since it's already out of order due to the rearranging and the added grouped_codes, so the shape goes from (b, faces, dim) to (b * (faces + 1), (quantizers * vertices_per_face), dim). But if you think of it in a linear fashion, ignoring how the network transforms the data, the output would be:
<pooled_text_embed> <mesh> <cut> <EOS> <extra tokens>
Instead of:
<mesh> <EOS> <cut> <extra tokens>
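The off-by-one can be seen with a toy list model of the description above: two tokens are prepended, only one is stripped, and the sequence is truncated back to its original length. This is an illustration with hypothetical string tokens, not the actual tensor code in meshgpt_pytorch.py.

```python
def shifted_sequence(mesh: list[str]) -> list[str]:
    original = mesh + ["<EOS>"]
    # the sos token and the pooled text embedding are both prepended...
    seq = ["<sos>", "<pooled_text_embed>"] + original
    # ...but only the sos token is stripped again
    seq = seq[1:]
    # truncating back to the original length pushes the last position past the cut
    return seq[: len(original)]

print(shifted_sequence(["m1", "m2", "m3"]))
# ['<pooled_text_embed>', 'm1', 'm2', 'm3']
```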

This is just a guess, but since during training the output covers a longer window (12 tokens into the future instead of 6), it might help with inference, because the model learns to output what it thinks comes after the EOS token. However, this output is cut off and doesn't affect the loss, so I'm not sure if it matters. I also increased the padding to 18 tokens, but performance degraded. I also tested replacing the pooled_text_embed with a Parameter of size dim, but it got worse results, so the text embedding does affect the output.

Multi-token prediction

I've been trying to understand how the transformer trains: at the end there is always 1 extra face (6 tokens), and then the sequence is cut off so that 5 tokens remain. I'm guessing this is done for the autoregression and the EOS token. But I think it can provide an additional effect by extending 'hidden' future tokens, and could be used for multi-token prediction. I'm not sure where the masking is applied during training, but as a test I increased the number of codes that are cut off and set 'append_eos' to False, to see if the model can predict multiple tokens ahead. Nothing as fancy as the Meta paper, just a weak proof of concept.

Here are some results after training for 15 epochs on the first 12 tokens of 2.8k meshes with 1000 labels:
1 token: 0.3990 loss (0.5574 loss without the text embedding pooling)
2 tokens: 0.112 loss
3 tokens: 0.24 loss
4 tokens: 0.1375 loss (woah!)
6 tokens: 0.1826 loss (0.104 loss at the 18th epoch)

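if return_loss:
    assert seq_len > 0
    # cut the last `number_of_tokens` codes off the input; the labels keep the full sequence
    codes, labels = codes[:, :-number_of_tokens], codes
# .......
# keep `number_of_tokens` extra positions beyond the code length
embed = embed[:, :(code_len + number_of_tokens)]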
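Multi-token prediction in the Meta-paper style trains the model to predict several future tokens from each position (there with separate output heads on a shared trunk). A minimal sketch of how the k-step-ahead targets for such a loss could be built, assuming PyTorch's default `ignore_index` of -100 for positions past the end of the sequence; `future_targets` is a hypothetical helper and not what the patch above does.

```python
import torch

def future_targets(codes: torch.Tensor, k: int, ignore_index: int = -100) -> torch.Tensor:
    # codes: (b, n) token ids -> (b, n, k), where [:, t, i] is the token i + 1 steps after position t
    b, n = codes.shape
    # pad with ignore_index so targets past the end of the sequence are skipped by the loss
    padded = torch.cat((codes, codes.new_full((b, k), ignore_index)), dim=1)
    return torch.stack([padded[:, i + 1 : i + 1 + n] for i in range(k)], dim=-1)
```

Each slice `targets[..., i]` could then be paired with the logits of an i-th prediction head in `F.cross_entropy`.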

500 labels with 10 models per label, 2k codebook, 2 quantizers: [image]

1000 labels with 5 models per label, 2k codebook, 2 quantizers: [image]

100 labels with 25 models per label, 16k codebook, 1 quantizer: [image]

lucidrains commented 3 months ago

@MarcusLoppe thanks Marcus for the detailed report and glad to hear you found a solution!

i'll keep chipping away at it to strengthen conditioning

next up is to probably add adaptive layer/rms normalization to x-transformers
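Adaptive normalization conditions the normalization gain on an external signal, here the text embedding. A minimal sketch of an adaptive RMSNorm under that assumption; `AdaptiveRMSNorm` is a hypothetical module, not x-transformers' implementation, and adding a second projection for a shift term would make it FiLM-style scale-and-shift.

```python
import torch
import torch.nn.functional as F
from torch import nn

class AdaptiveRMSNorm(nn.Module):
    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        self.scale = dim ** 0.5
        # projects the condition to a per-channel gain; zero-init so training starts from plain RMSNorm
        self.to_gamma = nn.Linear(cond_dim, dim)
        nn.init.zeros_(self.to_gamma.weight)
        nn.init.zeros_(self.to_gamma.bias)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: (b, n, dim), cond: (b, cond_dim), e.g. a pooled text embedding
        normed = F.normalize(x, dim=-1) * self.scale  # x / rms(x)
        gamma = self.to_gamma(cond).unsqueeze(1) + 1.0  # (b, 1, dim), broadcast over the sequence
        return normed * gamma
```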

lucidrains commented 3 months ago

@MarcusLoppe you should see a slight speed bump on inference now, as i now cache the key / values correctly for cross attention

hope i didn't break anything!

MarcusLoppe commented 3 months ago

Lovely :) I'll test the FiLM normalization method and let you know. I tried replacing the PixelNorm in the ResNet with FiLM batch normalization, but had only mild success: the std for the first token and its relationship to the label decreased from 12 to 11.

However, the sos_token isn't quite how I implemented it; I've had more success just leaving the sos_token in without unpacking it. I had the best success using the sos_token as per your commit "move the concatenation of the sos token so it is always conditioned b". That commit has had far better results than:

  • Unpacking multiple tokens + packing pooling
  • Repacking single
  • Unpacking multiple tokens and packing last token.

I tried explaining it before with my tests, but I might not have been clear enough.

Here is the implementation I've used https://github.com/MarcusLoppe/meshgpt-pytorch/blob/sos_token_test2/meshgpt_pytorch/meshgpt_pytorch.py

Regarding the cross-attention key/value caching fix: I'll give it a go :) Btw, I can create a pull request for this: when using 1 quantizer, the rounding-down in generation doesn't work. Currently it does 10 codes // 1 * 1 = 10 codes, instead of 10 codes // 3 * 3 = 9 codes.

So changing the following makes 1-quantizer generation work.
From:
round_down_code_len = code_len // self.num_quantizers * self.num_quantizers
To:
round_down_code_len = code_len // self.num_vertices_per_face * self.num_vertices_per_face
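The arithmetic of both expressions as a small sketch; `round_down` is a hypothetical helper. With 2 quantizers a face is 6 codes (3 vertices x 2 quantizers, the "6 tokens" per face mentioned in this thread), so if whole faces are what generation needs, rounding to `num_vertices_per_face` alone would not land on face boundaries there; rounding to `num_vertices_per_face * num_quantizers` is one possible generalization, untested in this thread.

```python
def round_down(code_len: int, multiple: int) -> int:
    # largest multiple of `multiple` that is <= code_len
    return code_len // multiple * multiple

num_vertices_per_face = 3
print(round_down(10, 1))                          # 10: current expression with num_quantizers = 1
print(round_down(10, num_vertices_per_face))      # 9: proposed expression
print(round_down(10, num_vertices_per_face * 2))  # 6: whole faces with 2 quantizers
```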

lucidrains commented 3 months ago

thanks for reporting the rounding down issue!

and yes, i can cleanup the multiple sos tokens code if not needed. however, by setting just 1 sos token, it should be equivalent to what you deem the best working commit