ibrahimethemhamamci / CT-CLIP

A foundation model utilizing chest CT volumes and radiology reports for supervised-level zero-shot detection of abnormalities

Shape Mismatch #21

Open chasehmathis opened 3 weeks ago

chasehmathis commented 3 weeks ago
        enc_image = enc_image.view(enc_image.shape[0], -1)
        print(enc_image.shape)

        # early return of encodings, if needed (for DALL-E2)

        if return_encodings:
            return enc_text, enc_image

        # depending on whether to do fine-grained CLIP or not, select either all tokens, or CLS tokens only

        if self.use_all_token_embeds:
            assert enc_text.ndim == 3, 'encoded text must have 3 dimensions (batch, seq, features)'
            assert enc_image.ndim == 3, 'encoded image must have 3 dimensions (batch, seq [height x width], features)'
            text_embeds = enc_text[:, 1:] if self.text_has_cls_token else enc_text
            image_embeds = enc_image[:, 1:] if self.visual_has_cls_token else enc_image
        else:
            text_embeds = enc_text[:, :] if enc_text.ndim == 3 else enc_text
            image_embeds = enc_image[:, :] if enc_image.ndim == 3 else enc_image

        # project to latents
        #text_embeds = text_embeds.view(text_embeds.shape[0], -1)

        text_embeds = text_embeds[:,0,:]

        #text_embeds = torch.mean(text_embeds, dim=1)
        text_latents = self.to_text_latent(text_embeds)

On lines 706-731, I believe there is an issue with how the tensors are handled. If we choose to select only the CLS token, we should not execute the line enc_image = enc_image.view(enc_image.shape[0], -1); instead we should take the last (CLS) token of each tensor:

I believe image_embeds = enc_image[:, -1, :] would work better here.

chasehmathis commented 3 weeks ago

The class token is the first entry, so it should actually be image_embeds = enc_image[:, 0, :].
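
To illustrate what I mean, a minimal sketch (dummy shapes, and assuming the image encoder prepends a CLS token as token 0, which may not match this model's encoder):

    import torch

    batch, seq, dim = 2, 10, 512                 # dummy shapes for illustration
    enc_image = torch.randn(batch, seq, dim)     # (batch, tokens, features)

    # what the current code does: flatten every token into one long vector
    flattened = enc_image.view(enc_image.shape[0], -1)    # -> (batch, seq * dim)

    # what I am suggesting, assuming token 0 is the CLS token
    image_embeds = enc_image[:, 0, :]                     # -> (batch, dim)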

sezginerr commented 2 weeks ago

Hi @chasehmathis, that is correct, but our encoder does not have a CLS token. We actually use all tokens to train the CLIP model. We know this is a little redundant, and we will push another model with a better approach in the coming days. If you use a different encoder with a CLS token, you can use the second approach you sent.
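
To make the distinction concrete, a minimal sketch (dummy shapes, not the actual model dimensions, and assuming a plain nn.Linear latent projection as in typical CLIP implementations): with all tokens flattened, the projection must accept seq * dim features, while a CLS-token encoder would only need a projection over dim features.

    import torch
    from torch import nn

    batch, seq, dim, latent_dim = 2, 10, 512, 256
    enc_image = torch.randn(batch, seq, dim)     # all image tokens, no CLS token

    # all-token approach (as used for the released model): flatten, then project
    to_image_latent_all = nn.Linear(seq * dim, latent_dim)
    image_latents = to_image_latent_all(enc_image.view(batch, -1))   # -> (batch, latent_dim)

    # CLS-token approach, only valid for an encoder that emits a CLS token
    to_image_latent_cls = nn.Linear(dim, latent_dim)
    image_latents_cls = to_image_latent_cls(enc_image[:, 0, :])      # -> (batch, latent_dim)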