mingyuan-zhang / MotionDiffuse

MotionDiffuse: Text-Driven Human Motion Generation with Diffusion Model
850 stars 74 forks source link

About text mask #32

Open ZeyuLing opened 1 year ago

ZeyuLing commented 1 year ago

Thanks a lot for your paper and code! In your implementation, you don't set an attention mask for the text sequence in either the textTransformer layers or the LinearTemporalCrossAttention layers. Why doesn't this affect the results? The related code is below.

def encode_text(self, text, device):
    """Encode raw text into a pooled feature and a per-token feature sequence.

    Args:
        text: input accepted by ``clip.tokenize`` (tokenized with
            ``truncate=True``).
        device: device the token ids are moved to before embedding.

    Returns:
        A ``(xf_proj, xf_out)`` pair. ``xf_proj`` is ``self.text_proj``
        applied to one selected token feature per batch element;
        ``xf_out`` is the full token sequence permuted to (B, T, D).
    """
    # CLIP stage runs without gradient tracking.
    with torch.no_grad():
        text = clip.tokenize(text, truncate=True).to(device)
        # [batch_size, n_ctx, latent_dim]
        x = self.clip.token_embedding(text).type(self.clip.dtype)
        x = x + self.clip.positional_embedding.type(self.clip.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip.transformer(x)
        x = self.clip.ln_final(x).type(self.clip.dtype)

    # T, B, D
    # These layers sit outside the no_grad block so gradients can reach them.
    x = self.text_pre_proj(x)
    # NOTE(review): the encoder is called with the sequence only; no
    # padding mask is passed, so every token position is visible to it.
    xf_out = self.textTransEncoder(x)
    xf_out = self.text_ln(xf_out)
    # Per batch element, pick the position holding the largest token id
    # (presumably CLIP's end-of-text token -- confirm against the tokenizer).
    xf_proj = self.text_proj(
        xf_out[text.argmax(dim=-1), torch.arange(xf_out.shape[1])]
    )
    # B, T, D
    xf_out = xf_out.permute(1, 0, 2)
    return xf_proj, xf_out

class LinearTemporalCrossAttention(nn.Module):
    """Cross-attention from a feature sequence ``x`` to text features ``xf``.

    Keys and values are first aggregated over the text positions into a
    per-head (HD x HD) summary, which every query position then reads;
    no (T x N) attention matrix is formed.

    NOTE(review): no text mask is applied. The key softmax normalises over
    all N text positions, so any padded positions in ``xf`` contribute to
    the summary.
    """

    def __init__(self, seq_len, latent_dim, text_latent_dim, num_head, dropout, time_embed_dim):
        # nn.Module must be initialised before submodules are assigned,
        # otherwise the first `self.<module> = ...` raises AttributeError.
        super().__init__()
        # `seq_len` is not used here; it is kept for interface compatibility.
        self.num_head = num_head
        self.norm = nn.LayerNorm(latent_dim)
        self.text_norm = nn.LayerNorm(text_latent_dim)
        self.query = nn.Linear(latent_dim, latent_dim)
        self.key = nn.Linear(text_latent_dim, latent_dim)
        self.value = nn.Linear(text_latent_dim, latent_dim)
        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

    def forward(self, x, xf, emb):
        """
        x: B, T, D
        xf: B, N, L
        emb: forwarded unchanged to ``self.proj_out``.
        Returns a tensor with the same shape as ``x`` (residual output).
        """
        B, T, D = x.shape
        N = xf.shape[1]
        H = self.num_head
        # Normalise the text features once and reuse for keys and values.
        xf_normed = self.text_norm(xf)
        # B, T, D
        query = self.query(self.norm(x))
        # B, N, D
        key = self.key(xf_normed)
        # Query: softmax over each head's feature dimension.
        query = F.softmax(query.view(B, T, H, -1), dim=-1)
        # Key: softmax over the N text positions.
        key = F.softmax(key.view(B, N, H, -1), dim=1)
        # B, N, H, HD
        value = self.value(xf_normed).view(B, N, H, -1)
        # B, H, HD, HD
        attention = torch.einsum('bnhd,bnhl->bhdl', key, value)
        # In this einsum the `n` axis indexes the T query positions.
        y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D)
        y = x + self.proj_out(y, emb)
        return y