salesforce / ULIP


Code for cross-modal retrieval #29

Open YuerGu opened 1 year ago

YuerGu commented 1 year ago

Hi, I'm interested in the code for cross-modal retrieval. Could you share it with us?

Tycho-Xue commented 1 year ago

Hi @YuerGu , thanks for your interest in our work! I will upload it soon.

LiXinghui-666 commented 1 year ago

Hi @Tycho-Xue, would it be convenient for you to upload the code for the multimodal retrieval task? I'm very interested.

Tycho-Xue commented 1 year ago

Hi @LiXinghui-666, sorry for the delay; I've been short on bandwidth to clean up that part of the code recently. To unblock you now, please use the snippet below as a reference. The retrieval part is straightforward to implement: you just need to modify the zero-shot classification code so that the image feature is matched against the whole dataset you want to retrieve from, instead of against the class names. Let me know if this still blocks you, and I apologize for the inconvenience. This is an example of image-to-point-cloud retrieval on ModelNet40:


import numpy as np
import torch

# `model`, `images`, `test_loader`, `args`, `utils`, and `progress` come from the
# surrounding ULIP evaluation script; only the retrieval logic is shown here.
with torch.no_grad():
    # encode and normalize every query image
    image_feats = []
    target_all = []
    image_pc_similarity_logits_all_images = []
    for idx, image in enumerate(images):
        image_feat = model.encode_image(image[None, ...].cuda())
        image_feat = image_feat / image_feat.norm(dim=-1, keepdim=True)
        image_feats.append(image_feat)

        # similarities between this image and every point cloud in the test set
        image_pc_similarity = []
        for i, (pc, target, target_name) in enumerate(test_loader):
            if idx == 0:
                target_all.extend(target_name)

            pc = pc.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            # encode and normalize the point-cloud features
            pc_features = utils.get_model(model).encode_pc(pc)
            pc_features = pc_features / pc_features.norm(dim=-1, keepdim=True)

            # cosine similarity between this batch of point clouds and the query image
            image_pc_similarity_batch = pc_features @ image_feat.t()
            image_pc_similarity.append(image_pc_similarity_batch.squeeze())

            if i % args.print_freq == 0:
                progress.display(i)

        image_pc_similarity_logits = torch.cat(image_pc_similarity, dim=0)
        image_pc_similarity_logits_all_images.append(image_pc_similarity_logits)

    # ensemble over all query images by taking the element-wise max similarity
    image_pc_similarity_logits_all_images_ensemble = torch.stack(image_pc_similarity_logits_all_images, dim=0).max(dim=0, keepdim=True)[0]
    topk_logits, topk_indices = image_pc_similarity_logits_all_images_ensemble.topk(args.topk)
    topk_indices = topk_indices.cpu().numpy()
    target_all = np.array(target_all)
    topk_classes = target_all[topk_indices]
    topk_classes_logits = topk_logits.cpu().numpy()

    progress.synchronize()
    print("topk classes are:")
    print(topk_classes)
    print('topk logits are:')
    print(topk_classes_logits)

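To make the "class names → whole dataset" change concrete, here is a toy contrast using random placeholder tensors (this is just a sketch, not code from the repo):

import torch
import torch.nn.functional as F

D, C, N = 512, 40, 2468  # embedding dim, #class names, #candidate point clouds (placeholder sizes)
img_feat = F.normalize(torch.randn(1, D), dim=-1)     # one query image embedding
text_feats = F.normalize(torch.randn(C, D), dim=-1)   # class-name text embeddings
pc_feats = F.normalize(torch.randn(N, D), dim=-1)     # embeddings of the whole retrieval set

pred_class = (img_feat @ text_feats.t()).argmax(dim=-1)              # zero-shot classification
topk_retrieved = (img_feat @ pc_feats.t()).topk(10, dim=-1).indices  # cross-modal retrieval
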
githubthunder commented 1 year ago

Hi @Tycho-Xue, I have tried to reproduce the cross-modal retrieval experiment from the paper "ULIP: Learning a Unified Representation of Language, Images, and Point Clouds for 3D Understanding".

The experimental settings are as follows,

  1. query: all the images with the label "airplane" from the Caltech101 dataset (~800 images)
  2. retrieval pool: the point clouds in the ModelNet40 test set (~2500)
  3. pre-trained model: PointBERT_ULIP-2, used to extract the image and point-cloud embeddings
  4. retrieve the top-10 point clouds for each image (the sketch after this list is roughly how I score the retrievals)
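
For reference, this is roughly how I score the retrievals once the embeddings are extracted and L2-normalized; the function and variable names below are placeholders, not code from the repo:

import torch

def retrieval_precision_at_k(image_feats, pc_feats, pc_labels, query_label, k=10):
    # image_feats: (Q, D) normalized query-image embeddings (Caltech101 "airplane" images)
    # pc_feats:    (N, D) normalized point-cloud embeddings (ModelNet40 test set)
    # pc_labels:   (N,)   integer class labels of the point clouds
    # query_label: class id that the queries are expected to retrieve
    sims = image_feats @ pc_feats.t()          # (Q, N) cosine similarities
    topk_idx = sims.topk(k, dim=1).indices     # (Q, k) indices of the retrieved point clouds
    hits = (pc_labels[topk_idx] == query_label).float()
    return hits.mean().item()                  # fraction of retrieved shapes in the query class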

The results are not as good as expected. Some results are as follows,

Did I miss something? Have you encountered the same problem?

Tycho-Xue commented 1 year ago

@githubthunder We haven't tried ULIP-2 for image-to-3D retrieval yet. We did try it with ULIP-1 and it worked well; we didn't report quantitative results for image-to-3D retrieval because we were not able to find a proper benchmark for this task. The results you show here do look odd, and I would need more context to figure out why it is not working well for you.