lucidrains / meshgpt-pytorch

Implementation of MeshGPT, SOTA Mesh generation using Attention, in Pytorch
MIT License
700 stars 57 forks

MeshTransformer.generate does not work with a prompt if kv cache is enabled #48

Closed Kurokabe closed 8 months ago

Kurokabe commented 8 months ago

I wanted to experiment with how well the MeshTransformer can complete a mesh when given the initial codes, but I think the prompt codes are not passed down the line correctly. Here is a small debug snippet:

import torch

# Random batch of 2 meshes, each with 100 vertices and 100 triangular faces
vertices = torch.randn(2, 100, 3)
faces = torch.randint(0, 100, (2, 100, 3))
# gpt = # Load MeshTransformer from checkpoint
codes = gpt.autoencoder.tokenize(vertices=vertices, faces=faces)
generated = gpt.generate(prompt=codes)

It gives the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[421], line 1
----> 1 generated = gpt.generate(prompt=codes)

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\x_transformers\autoregressive_wrapper.py:27, in eval_decorator.<locals>.inner(self, *args, **kwargs)
     25 was_training = self.training
     26 self.eval()
---> 27 out = fn(self, *args, **kwargs)
     28 self.train(was_training)
     29 return out

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\torch\utils\_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File <@beartype(meshgpt_pytorch.meshgpt_pytorch.MeshTransformer.generate) at 0x2239b61c9d0>:170, in generate(__beartype_func, __beartype_conf, __beartype_get_violation, __beartype_object_2352326455808, __beartype_object_2350005736832, __beartype_object_2349955153872, __beartype_object_140723080033008, __beartype_getrandbits, *args, **kwargs)

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\meshgpt_pytorch\meshgpt_pytorch.py:1238, in MeshTransformer.generate(self, prompt, batch_size, filter_logits_fn, filter_kwargs, temperature, return_codes, texts, text_embeds, cond_scale, cache_kv, face_coords_to_file)
   1233 for i in tqdm(range(curr_length, self.max_seq_len)):
   1234     # v1([q1] [q2] [q1] [q2] [q1] [q2]) v2([eos| q1] [q2] [q1] [q2] [q1] [q2]) -> 0 1 2 3 4 5 6 7 8 9 10 11 12 -> v1(F F F F F F) v2(T F F F F F) v3(T F F F F F)
   1236     can_eos = i != 0 and divisible_by(i, self.num_quantizers * 3)  # only allow for eos to be decoded at the end of each face, defined as 3 vertices with D residual VQ codes
-> 1238     output = self.forward_on_codes(
   1239         codes,
   1240         text_embeds = text_embeds,
   1241         return_loss = False,
   1242         return_cache = cache_kv,
   1243         append_eos = False,
   1244         cond_scale = cond_scale,
   1245         cfg_routed_kwargs = dict(
   1246             cache = cache
   1247         )
   1248     )
   1250     if cache_kv:
   1251         logits, cache = output

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\classifier_free_guidance_pytorch\classifier_free_guidance_pytorch.py:152, in classifier_free_guidance.<locals>.inner(self, cond_scale, rescale_phi, cfg_routed_kwargs, *args, **kwargs)
    148 null_fn_kwargs = {k: v[1] for k, v in cfg_routed_kwargs.items()}
    150 # non-null forward
--> 152 outputs = fn_maybe_with_text(self, *args, **fn_kwargs, **kwargs_without_cond_dropout)
    154 if cond_scale == 1:
    155     return outputs

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\classifier_free_guidance_pytorch\classifier_free_guidance_pytorch.py:130, in classifier_free_guidance.<locals>.inner.<locals>.fn_maybe_with_text(self, *args, **kwargs)
    127     if 'raw_text_cond' in fn_params:
    128         kwargs.update(raw_text_cond = raw_text_cond)
--> 130 return fn(self, *args, **kwargs)

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\meshgpt_pytorch\meshgpt_pytorch.py:1514, in MeshTransformer.forward_on_codes(self, codes, return_loss, return_cache, append_eos, cache, texts, text_embeds, cond_drop_prob)
   1511 if one_face:
   1512     fine_vertex_codes = fine_vertex_codes[:, :(curr_vertex_pos + 1)]
-> 1514 attended_vertex_codes, fine_cache = self.fine_decoder(
   1515     fine_vertex_codes,
   1516     cache = fine_cache,
   1517     return_hiddens = True
   1518 )
   1520 if not should_cache_fine:
   1521     fine_cache = None

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\x_transformers\x_transformers.py:1299, in AttentionLayers.forward(self, x, context, mask, context_mask, attn_mask, self_attn_kv_mask, mems, seq_start_pos, cache, cache_age, return_hiddens, rotary_pos_emb)
   1296     x = pre_norm(x)
   1298 if layer_type == 'a':
-> 1299     out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, return_intermediates = True)
   1300 elif layer_type == 'c':
   1301     out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\x_transformers\x_transformers.py:832, in Attention.forward(self, x, context, mask, context_mask, attn_mask, rel_pos, rotary_pos_emb, prev_attn, mem, return_intermediates, cache)
    829     mk, k = unpack(k, mem_packed_shape, 'b h * d')
    830     mv, v = unpack(v, mem_packed_shape, 'b h * d')
--> 832 k = torch.cat((ck, k), dim = -2)
    833 v = torch.cat((cv, v), dim = -2)
    835 if exists(mem):

RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 202 but got size 2 for tensor number 1 in the list.

However, if I disable the kv cache with generated = gpt.generate(prompt=codes, cache_kv=False), it works (albeit slowly).

With the cache enabled, in x_transformers > Attention > forward, ck.shape=[202, 16, 6, 64] and k.shape=[2, 16, 1, 64], which causes the shape mismatch error (cv and v have the same respective shapes).
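To make the mismatch concrete, here is a minimal pure-Python sketch of the rule that a concatenation along one dimension enforces: every other dimension must agree across the inputs. The helper `cat_result_shape` is hypothetical, written only for illustration, and is not part of torch or x_transformers. The two shapes are the ones reported above.

```python
from typing import Sequence, Tuple


def cat_result_shape(shapes: Sequence[Tuple[int, ...]], dim: int) -> Tuple[int, ...]:
    """Hypothetical helper: shape that concatenating along `dim` would produce.

    Raises ValueError if any dimension other than `dim` differs between inputs.
    """
    first = shapes[0]
    ndim = len(first)
    dim = dim % ndim  # normalise negative dims such as -2
    total = 0
    for idx, shape in enumerate(shapes):
        if len(shape) != ndim:
            raise ValueError(f"tensor {idx} has {len(shape)} dims, expected {ndim}")
        for axis, (expected, got) in enumerate(zip(first, shape)):
            if axis != dim and expected != got:
                raise ValueError(
                    f"sizes must match except in dimension {dim}: "
                    f"expected {expected} but got {got} at axis {axis} for tensor {idx}"
                )
        total += shape[dim]
    return first[:dim] + (total,) + first[dim + 1:]


ck_shape = (202, 16, 6, 64)  # cached keys, as reported in this issue
k_shape = (2, 16, 1, 64)     # keys for the newly decoded token

try:
    cat_result_shape([ck_shape, k_shape], dim=-2)
    mismatch = None
except ValueError as err:
    mismatch = str(err)

# If the cached keys had the same batch size as the new keys,
# concatenating along the sequence dimension would be valid.
ok_shape = cat_result_shape([(2, 16, 6, 64), k_shape], dim=-2)
```

The failure comes from axis 0 (202 versus 2), not from the sequence axis being concatenated, which suggests the cached keys and the new keys were built with different batch layouts.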

lucidrains commented 8 months ago

@Kurokabe ah ok, so the issue is with prompting and kv cache turned on

you are actually the first one to test mesh prompting! how well is it working without kv cache? can you share what you are seeing?

yea i can fix this later today, in addition to making the quantizer more memory efficient

Kurokabe commented 8 months ago

Thanks for looking into this! 😄

[Image: shape_completion_gpt_loss_3 293]

Some results on the train set with a gpt trained up to a loss of 3.293. The first row is the ground truth. The second row shows the first 100 faces of the ground truth given to the gpt. The last row is the gpt output.

lucidrains commented 8 months ago

@Kurokabe not bad! thank you Farid! will let you know once i get this fixed tonight 🚀

MarcusLoppe commented 8 months ago

Looks very good! Did you train on the 15k dataset and are you prompting through codes?

The loss seems high, but I think the paper reported a metric of 1.4; I'm not sure whether that was their transformer loss.

Kurokabe commented 8 months ago

That's right. When prompted with the start of the mesh, it is able to create meaningful content, but when asked to generate from nothing (gpt.generate(temperature=0)) it creates a messy mesh.

MarcusLoppe commented 8 months ago

I agree, it's much better to give it a little push in the right direction. I've managed to create multiple different objects using text-conditioned generation, but that needs a very low loss. I haven't tried prompting it, but I'm guessing prompting lets it generate meshes even with a relatively high training loss.

I wonder how much the cross-attention with text helps at the start; it might be interesting to test the impact of the prompt tokens when using text as well. If text + prompt gives much better results than text alone, it might be worth @lucidrains revisiting this to see whether there is a way to increase the influence of the text during mesh generation.

lucidrains commented 8 months ago

@Kurokabe hey Farid, after a few hours of debugging (hierarchical transformers are confusing), i think i finally figured out the issue

do you want to try 0.5.7 and see if it fixes your original script?

Kurokabe commented 8 months ago

It works perfectly now, thank you for fixing it so fast 🙏

lucidrains commented 8 months ago

@Kurokabe you don't happen to work at one of the startups that has been reaching out to me through email about meshgpt, do you?

Kurokabe commented 8 months ago

What's the name of the company? I haven't heard anything about that, anyway.

lucidrains commented 8 months ago

@Kurokabe ohh nvm, i'm receiving some emails from startups in the 3d space. thought you might have been part of one of them

lucidrains commented 8 months ago

@Kurokabe you are in academia?

Kurokabe commented 8 months ago

Kind of, I work at a company atm, but I'll soon start a PhD

MarcusLoppe commented 8 months ago

@lucidrains

Hey,

I tested prompt tokens + text, and using only 1 token was enough to kick-start the generation. So it seems that the cross-attention to the text has a very weak impact on the first tokens. Any ideas around that? If it's possible to increase the impact of the cross-attention so that the text has a stronger relationship with the mesh, then image-to-3D might be possible, since an image is just a bigger vector (kind of).

Text + tokens, using 0 to 5 tokens: [image] [image]

ell-hol commented 8 months ago

> @Kurokabe you don't happen to work at one of the startups that has been reaching out to me through email about meshgpt, are you?

@lucidrains I do 🚀