kuai-lab / sound-guided-semantic-image-manipulation

Sound-guided Semantic Image Manipulation - Official Pytorch Code (CVPR 2022)

Computation requirement for training #7

Open HimangiM opened 2 years ago

HimangiM commented 2 years ago

Hi,

First of all, great contribution to the field of image manipulation. Could you please tell me how many GPUs you used and how long it took to train the model?

Thanks, Himangi

lsh3163 commented 2 years ago

Dear HimangiM,

Thanks for your interest. It takes one day to train the model on a single GPU, because we keep the weights of the CLIP image encoder and text encoder fixed.
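A minimal PyTorch sketch of what keeping the CLIP weights fixed means in practice: parameters with `requires_grad=False` receive no gradients and are not handed to the optimizer, so only the audio encoder is updated. The two `nn.Linear` layers are hypothetical stand-ins for the CLIP and audio encoders, and the MSE loss is only a placeholder for the contrastive loss used in the real training loop.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: one layer plays the frozen CLIP encoder,
# the other plays the trainable audio encoder.
frozen_encoder = nn.Linear(16, 8)
audio_encoder = nn.Linear(20, 8)

# Freeze the "CLIP" weights: no gradients are computed or stored for them.
frozen_encoder.requires_grad_(False)
frozen_encoder.eval()

# Only the audio encoder's parameters are given to the optimizer.
optimizer = torch.optim.Adam(audio_encoder.parameters(), lr=1e-4)


def count_trainable(module: nn.Module) -> int:
    """Number of parameters that will receive gradient updates."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# One training step; the frozen branch runs under no_grad.
image_features = torch.randn(4, 16)
audio_features = torch.randn(4, 20)
with torch.no_grad():
    target = frozen_encoder(image_features)
prediction = audio_encoder(audio_features)
loss = nn.functional.mse_loss(prediction, target)  # placeholder loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Because the frozen branch needs no backward pass and no optimizer state, each step costs little more than one forward pass through CLIP plus a full step through the much smaller audio encoder.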

Sincerely,

Allencheng97 commented 2 years ago

Hi, dear author, thank you for your great work. May I ask what your final loss values (text_contrastive_loss and image_contrastive_loss) were? I tried to train my own model, but the loss seems to decrease very slowly.

I am looking forward to hearing from you soon.

lsh3163 commented 2 years ago

I agree. Convergence is slow; I got the best audio representation at around 30 epochs. Interpolating between the image and text embeddings is also a good option. Here is the code I used:

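```python
import math

import clip
import torch

# Defined elsewhere in the training script (not shown in this comment):
# train_dataloader, audioencoder, clip_model, optimizer, device, and
# ce = torch.nn.CrossEntropyLoss()
# Note: CLIP's own logit scale is initialised to 1/0.07 (about 14.3);
# math.exp(0.07) is about 1.07.
logit_scale = math.exp(0.07)

clip_model.eval()  # the CLIP encoders stay frozen
for idx, (batch_audio, batch_audio_aug, batch_img, batch_text) in enumerate(train_dataloader):
    audio_embedding = audioencoder(batch_audio.to(device))
    audio_aug_embedding = audioencoder(batch_audio_aug.to(device))
    text_tokens = torch.cat([clip.tokenize(text) for text in batch_text]).to(device)
    with torch.no_grad():
        # .float(): clip.load() keeps fp16 weights on GPU, the audio encoder is fp32
        text_embedding = clip_model.encode_text(text_tokens).float()
        text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
        image_embedding = clip_model.encode_image(batch_img.to(device)).float()
        image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)

    audio_embedding = audio_embedding / audio_embedding.norm(dim=-1, keepdim=True)
    audio_aug_embedding = audio_aug_embedding / audio_aug_embedding.norm(dim=-1, keepdim=True)

    # Scaled cosine-similarity matrices (batch x batch)
    projection_audio_text = (audio_embedding @ text_embedding.T) * logit_scale
    projection_audio_img = (audio_embedding @ image_embedding.T) * logit_scale
    projection_self_audio = (audio_embedding @ audio_aug_embedding.T) * logit_scale

    # Use the actual batch size so a smaller final batch does not break the labels
    label = torch.arange(audio_embedding.size(0), dtype=torch.long, device=device)

    audio_contrastive_loss = (
        ce(projection_audio_text, label) + ce(projection_audio_text.T, label)
        + ce(projection_audio_img, label) + ce(projection_audio_img.T, label)
    )
    self_contrastive_loss = ce(projection_self_audio, label) + ce(projection_self_audio.T, label)
    loss = (audio_contrastive_loss + self_contrastive_loss) / 4

    # Optimizer step (not part of the original comment)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```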
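The comment does not say exactly how the image and text embeddings were interpolated, so the NumPy sketch below assumes one simple option: blend the normalized image and text embeddings with a hypothetical weight `alpha`, re-normalize, and use the result as the contrastive target for the audio embedding, scored with the same symmetric cross-entropy pattern as the loop above. `interpolate_targets`, `aligned_loss_floor`, and `alpha` are illustrative names, not code from the repository. The last lines compare the best one-direction loss that perfectly aligned embeddings can reach at the snippet's `math.exp(0.07)` scale (about 1.07) and at CLIP's initial logit scale of 1/0.07 (about 14.3).

```python
import math

import numpy as np


def l2_normalize(x, eps=1e-8):
    """Scale each row to unit L2 norm, like x / x.norm(dim=-1, keepdim=True)."""
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)


def interpolate_targets(image_emb, text_emb, alpha=0.5):
    """Hypothetical helper: blend image and text embeddings, then re-normalize.
    alpha=1.0 gives image-only targets, alpha=0.0 gives text-only targets."""
    mixed = alpha * l2_normalize(image_emb) + (1.0 - alpha) * l2_normalize(text_emb)
    return l2_normalize(mixed)


def cross_entropy(logits, labels):
    """Mean softmax cross-entropy, matching torch.nn.CrossEntropyLoss()."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def symmetric_contrastive_loss(audio_emb, target_emb, logit_scale):
    """Symmetric InfoNCE: row i of audio_emb should match row i of target_emb."""
    logits = logit_scale * (l2_normalize(audio_emb) @ l2_normalize(target_emb).T)
    labels = np.arange(logits.shape[0])
    return cross_entropy(logits, labels) + cross_entropy(logits.T, labels)


def aligned_loss_floor(batch_size, logit_scale):
    """One-direction loss when every positive pair has cosine 1 and the batch
    forms a regular simplex (all negatives at cosine -1/(B-1))."""
    pos = math.exp(logit_scale)
    neg = (batch_size - 1) * math.exp(-logit_scale / (batch_size - 1))
    return -math.log(pos / (pos + neg))


rng = np.random.default_rng(0)
batch, dim = 8, 16
image = rng.normal(size=(batch, dim))
text = rng.normal(size=(batch, dim))
target = interpolate_targets(image, text, alpha=0.5)

aligned = symmetric_contrastive_loss(target, target, logit_scale=1 / 0.07)
unrelated = symmetric_contrastive_loss(rng.normal(size=(batch, dim)), target, logit_scale=1 / 0.07)
print(f"aligned: {aligned:.4f}, unrelated: {unrelated:.4f}")

for scale in (math.exp(0.07), 1 / 0.07):
    floor = aligned_loss_floor(32, scale)
    print(f"scale {scale:.2f}: floor {floor:.3f}, chance {math.log(32):.3f}")
```

With a batch of 32 and a scale of about 1.07, even perfectly aligned embeddings give a one-direction loss of roughly 2.42, against log(32) ≈ 3.47 at chance, so a loss that seems to move slowly may already be close to its floor. Whether a larger scale trains a better audio encoder in this setting is not tested here.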