lucidrains / meshgpt-pytorch

Implementation of MeshGPT, SOTA Mesh generation using Attention, in Pytorch
MIT License

Classifier-Free Guidance, cond_drop_prob=1.0, attn_mask=False: Error!!! #63

fighting-Zhang closed this issue 6 months ago

fighting-Zhang commented 6 months ago

When using text conditions, even if text_condition_cond_drop_prob is set to 0.25 when initializing the MeshTransformer, it is easy to overlook the cond_drop_prob parameter of MeshTransformer.forward_on_codes, which defaults to cond_drop_prob = 0.:

```python
_, maybe_dropped_text_embeds = self.conditioner(
    text_embeds = text_embeds,
    cond_drop_prob = cond_drop_prob
)
```

Accidentally keeping the default cond_drop_prob = 0. means the text conditions are never dropped during training, which undermines the later use of Classifier-Free Guidance. I recommend changing the default of cond_drop_prob to None so it falls back to the value configured at initialization. @lucidrains
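A minimal sketch of the suggested default handling, assuming a small helper that resolves the effective drop probability (the function name and fallback value below are illustrative, not the library's actual API):

```python
# Sketch only: treat cond_drop_prob=None as "use the probability configured at init".

def resolve_cond_drop_prob(cond_drop_prob, default_cond_drop_prob):
    # caller passed nothing -> fall back to the training-time default (e.g. 0.25)
    # caller passed an explicit value (e.g. 1.0 for an unconditional pass) -> honor it
    return default_cond_drop_prob if cond_drop_prob is None else cond_drop_prob

print(resolve_cond_drop_prob(None, 0.25))  # 0.25, condition dropout preserved during training
print(resolve_cond_drop_prob(1.0, 0.25))   # 1.0, force-drop the condition for CFG sampling
```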

More importantly, when I call MeshTransformer.forward_on_codes(cond_drop_prob = 1.0), the mask for the text embeddings becomes all False, so attn_mask is False at every position during cross-attention, and the output becomes NaN.

How can these issues be resolved?

```python
import torch.nn.functional as F

out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask = mask,
    dropout_p = self.dropout if self.training else 0.,
    is_causal = causal
)
```

It seems that attn_mask cannot be False at every position.
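To illustrate the failure mode, here is a small standalone repro (shapes and values are made up): when every key position in a row of a boolean attn_mask is False, the softmax is taken over a row of -inf and the attention output turns into NaN.

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 4, 8)   # (batch, heads, query_len, head_dim)
k = torch.randn(1, 1, 6, 8)   # (batch, heads, key_len, head_dim)
v = torch.randn(1, 1, 6, 8)

# boolean mask that is False everywhere: no key may be attended to
mask = torch.zeros(1, 1, 4, 6, dtype=torch.bool)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.isnan().any())  # tensor(True): every fully-masked row becomes NaN
```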

lucidrains commented 6 months ago

@fighting-Zhang thanks for catching the first point, made the fix

as for the second, i think the latest x-transformers should handle all rows being masked out, but i added my favorite extra precautionary measure (appending a few memory key/values to cross attend to)
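For reference, a rough sketch of that precaution, assuming a plain cross-attention layer (the class and parameter names below are illustrative, not the exact x-transformers implementation): a few learned memory key/value slots are prepended and always left unmasked, so even a fully dropped text condition leaves something valid to attend to.

```python
import torch
import torch.nn.functional as F
from torch import nn

class CrossAttentionWithMemKV(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, num_mem_kv = 4):
        super().__init__()
        self.heads = heads
        inner = heads * dim_head
        self.to_q = nn.Linear(dim, inner, bias = False)
        self.to_kv = nn.Linear(dim, inner * 2, bias = False)
        self.to_out = nn.Linear(inner, dim, bias = False)
        # learned memory key / values, shared across the batch and never masked
        self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
        self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))

    def forward(self, x, context, context_mask = None):
        b, h = x.shape[0], self.heads

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        # split heads: (b, n, h * d) -> (b, h, n, d)
        split = lambda t: t.reshape(b, -1, h, t.shape[-1] // h).transpose(1, 2)
        q, k, v = map(split, (q, k, v))

        # prepend the memory key / values
        mem_k = self.mem_k.unsqueeze(0).expand(b, -1, -1, -1)
        mem_v = self.mem_v.unsqueeze(0).expand(b, -1, -1, -1)
        k = torch.cat((mem_k, k), dim = -2)
        v = torch.cat((mem_v, v), dim = -2)

        if context_mask is not None:
            # pad the mask with True for the memory slots, so even a fully
            # dropped text condition never yields an all-False attention row
            mem_mask = torch.ones(b, mem_k.shape[-2], dtype = torch.bool, device = x.device)
            mask = torch.cat((mem_mask, context_mask), dim = -1)
            mask = mask[:, None, None, :]  # broadcast over heads and query positions
        else:
            mask = None

        out = F.scaled_dot_product_attention(q, k, v, attn_mask = mask)
        out = out.transpose(1, 2).reshape(b, -1, h * q.shape[-1])
        return self.to_out(out)
```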

lucidrains commented 6 months ago

best with your research