orrzohar / FOMO

Official PyTorch code for Open World Object Detection in the Era of Foundation Models
Apache License 2.0

Several Questions #6

Closed YukunLi99 closed 8 months ago

YukunLi99 commented 8 months ago

Thank you for your outstanding work! Could you kindly help me with the following three specific questions?

1) Why is an additional dimension concatenation required in this case?

code: https://github.com/orrzohar/FOMO/blob/main/models/FOMO.py#L351

if args.use_attributes:
    # Append the attribute-weighted mean embedding as an extra query; it later serves as the
    # generic "unknown object" logit (see the author's first answer below).
    self.att_embeds = torch.cat([self.att_embeds, torch.matmul(self.att_embeds.squeeze().T, self.att_W).mean(1, keepdim=True).T.unsqueeze(0)], dim=1)

2) FOMO needs attributes selected for each category, but the current implementation doesn't guarantee an equal number of selected attributes per category.

code: https://github.com/orrzohar/FOMO/blob/main/models/FOMO.py#L441

_, top_indices = torch.topk(self.att_W.view(-1), num_classes * self.num_attributes_per_class)

3) Are the training and evaluation stages consistent in how they compute the attribute scores?

Training stage (without the learnable parameters logit_shift and logit_scale); code for attribute_refinement: https://github.com/orrzohar/FOMO/blob/main/models/FOMO.py#L394 and for attribute_selection: https://github.com/orrzohar/FOMO/blob/main/models/FOMO.py#L431

cos_sim = cosine_similarity(image_embeddings, self.att_embeds, dim=-1)

Eval stage (with the learnable parameters logit_shift and logit_scale): `pred_logits = (pred_logits + logit_shift) * logit_scale`

code: https://github.com/orrzohar/FOMO/blob/main/models/FOMO.py#L643

(pred_logits, class_embeds) = self.model.class_predictor(
    image_feats, self.att_embeds.repeat(batch_size, 1, 1), self.att_query_mask
)

def class_predictor(
        self,
        image_feats: torch.FloatTensor,
        query_embeds: Optional[torch.FloatTensor] = None,
        query_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            image_feats:
                Features extracted from the `image_text_embedder`.
            query_embeds:
                Text query embeddings.
            query_mask:
                Must be provided with query_embeddings. A mask indicating which query embeddings are valid.
        """
        (pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask)

        return (pred_logits, image_class_embeds)

class OwlViTClassPredictionHead(nn.Module):
    def __init__(self, config: OwlViTConfig):
        super().__init__()

        out_dim = config.text_config.hidden_size
        self.query_dim = config.vision_config.hidden_size

        self.dense0 = nn.Linear(self.query_dim, out_dim)
        self.logit_shift = nn.Linear(self.query_dim, 1)
        self.logit_scale = nn.Linear(self.query_dim, 1)
        self.elu = nn.ELU()

    def forward(
        self,
        image_embeds: torch.FloatTensor,
        query_embeds: Optional[torch.FloatTensor],
        query_mask: Optional[torch.Tensor],
    ) -> Tuple[torch.FloatTensor]:
        image_class_embeds = self.dense0(image_embeds)
        if query_embeds is None:
            device = image_class_embeds.device
            batch_size, num_patches = image_class_embeds.shape[:2]
            pred_logits = torch.zeros((batch_size, num_patches, self.query_dim)).to(device)
            return (pred_logits, image_class_embeds)

        # Normalize image and text features
        image_class_embeds /= torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6
        query_embeds /= torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6

        # Get class predictions
        pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)

        # Apply a learnable shift and scale to logits
        logit_shift = self.logit_shift(image_embeds)
        logit_scale = self.logit_scale(image_embeds)
        logit_scale = self.elu(logit_scale) + 1
        pred_logits = (pred_logits + logit_shift) * logit_scale

        if query_mask is not None:
            if query_mask.ndim > 1:
                query_mask = torch.unsqueeze(query_mask, dim=-2)

            pred_logits = pred_logits.to(torch.float64)
            pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
            pred_logits = pred_logits.to(torch.float32)

        return (pred_logits, image_class_embeds)

Your excellent work will be a great help to my research!

orrzohar commented 8 months ago

Hi @YukunLi99, I would be happy to help:

1) The concatenation is required so that I have a general "unknown object" logit. This is important because I need something to represent the unknown-object predictions. I then multiply this logit by the "objectness" ($= p_{\text{OOD}} \cdot p_{\text{ID}}$, see here). I set this logit to be the average of all the attributes -- I do this so that the scale of this logit is broadly similar to that of the other logits (while the attribute logits are mathematically in $[0,1]$ due to the sigmoid, statistically you will find that they concentrate around some mean and std, e.g., 0.8 and 0.04, so simply using a constant value like 0.5 or 1 would make the model harder to calibrate).
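
For reference, here is a minimal, self-contained sketch of what that concatenation computes (the shapes are assumed purely for illustration; this is not the exact repository code):

import torch

# Assumed shapes, for illustration only:
#   att_embeds: (1, num_attributes, embed_dim) -- one embedding per attribute
#   att_W:      (num_attributes, num_classes)  -- attribute-to-class weights
att_embeds = torch.randn(1, 25, 512)
att_W = torch.rand(25, 80)

# A weighted combination of attribute embeddings per class, averaged over classes, gives a
# single "generic object" embedding whose scale is similar to the individual attribute embeddings.
unknown_embed = torch.matmul(att_embeds.squeeze().T, att_W).mean(1, keepdim=True).T.unsqueeze(0)

# Appending it as one extra query makes the class head produce an extra "unknown object" logit.
att_embeds = torch.cat([att_embeds, unknown_embed], dim=1)
print(att_embeds.shape)  # torch.Size([1, 26, 512])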

2) I don't require the model to select an attribute from each category -- I just use the categories as a convenient way to prompt an LLM for attributes, and I don't think this is a critical aspect of OWD. I think what is important is getting good attributes that are relevant to the task.
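
For reference, the selection in the code you linked is a single global top-k over the flattened weight matrix; a per-class variant (not what the repository does, shown only to make the difference concrete, with att_W assumed to have shape (num_attributes, num_classes)) would look like this:

import torch

num_attributes, num_classes, num_attributes_per_class = 25, 80, 5
att_W = torch.rand(num_attributes, num_classes)

# Global selection (as in the repository): the k*C largest weights anywhere in the matrix,
# so some classes may contribute more selected attributes than others.
_, top_indices_global = torch.topk(att_W.view(-1), num_classes * num_attributes_per_class)

# Per-class alternative: the k largest weights in every column, guaranteeing an equal count per class.
_, top_indices_per_class = torch.topk(att_W, num_attributes_per_class, dim=0)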

3) Yes, it is consistent -- but it is important to note what everything does. In the attribute selection and refinement phases, you are given a set of N object examples (image + bounding box). The code you sent at the end of your message is how we extract the embedding that represents each such object. Briefly, we input the image in question into the model and get the K corresponding embeddings (after the cls head; I believe K > 2000 for the large model), where each embedding can represent a different bbox prediction. We then filter out all predictions that do not sufficiently overlap with the GT bbox and select the embedding farthest from the mean. This is where the "scale" and "shift" are accounted for. Once we obtain the different embeddings, we return to our problem -- where we use just the normal cosine similarity between the attribute embeddings and the prediction embedding. To summarize, the "prediction logit" is only used to select the right embedding for the example object we were given, not during inference/training/evaluation.
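
A rough sketch of that selection step (in the spirit of OWL-ViT's image-conditioned query embedding; the helper name, threshold, and shapes are hypothetical, not the exact FOMO/transformers code):

import torch
from torchvision.ops import box_iou

def select_example_embedding(pred_boxes, class_embeds, gt_box, iou_thresh=0.65):
    """Pick the per-patch embedding that represents an annotated example object.

    pred_boxes:   (num_patches, 4) predicted boxes in (x1, y1, x2, y2) format
    class_embeds: (num_patches, dim) embeddings from the class head
    gt_box:       (1, 4) ground-truth box of the example object, same format
    """
    # Keep only predictions that sufficiently overlap the ground-truth box.
    ious = box_iou(pred_boxes, gt_box).squeeze(-1)   # (num_patches,)
    keep = ious >= iou_thresh
    if keep.sum() == 0:                              # fall back to the best-overlapping box
        keep = ious >= ious.max()
    candidates = class_embeds[keep]                  # (n_keep, dim)

    # Select the candidate farthest from the mean candidate embedding
    # (i.e., least similar to the average), as described above.
    mean_embed = candidates.mean(dim=0)
    sims = candidates @ mean_embed
    return candidates[sims.argmin()]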

Hope this helps! Orr

YukunLi99 commented 8 months ago

Thank you for your reply!

I still have some doubts regarding the first and third questions:

1) Should the computation of $p_{\text{OOD}}$ and $p_{\text{ID}}$ be based only on the attributes selected from known classes? The computation of $p_{\text{ID}}$ includes attributes from the general unknown class, whereas the computation of $p_{\text{OOD}}$ excludes attributes from this unknown class. See here for the calculation of sm_logits ($p_{\text{ID}}$) and mcm ($p_{\text{OOD}}$):

class UnkDetHead(nn.Module):
    def __init__(self, method, known_dims, att_W, **kwargs):
        super(UnkDetHead, self).__init__()
        print("UnkDetHead", method)
        self.method = method
        self.known_dims = known_dims
        self.att_W = att_W
        self.process_mcm = nn.Softmax(dim=-1)

        if "sigmoid" in method:
            self.process_logits = nn.Sigmoid()
            self.proc_obj = True
        elif "softmax" in method:
            self.process_logits = nn.Softmax(dim=-1)
            self.proc_obj = True
        else:
            self.proc_obj = False

    def forward(self, att_logits):
        # Map attribute logits to class logits; logits beyond known_dims are treated as unknown-class logits.
        logits = att_logits @ self.att_W
        k_logits = logits[..., :self.known_dims]
        unk_logits = logits[..., self.known_dims:].max(dim=-1, keepdim=True)[0]
        logits = torch.cat([k_logits, unk_logits], dim=-1)
        objectness = torch.ones_like(unk_logits).squeeze(-1)

        if "mean" in self.method:
            # p_ID: average processed attribute score, taken over all attribute logits
            sm_logits = self.process_logits(att_logits)
            objectness = sm_logits.mean(dim=-1, keepdim=True)[0]

        elif "max" in self.method:
            sm_logits = self.process_logits(att_logits)
            objectness = sm_logits.max(dim=-1, keepdim=True)[0]

        if "mcm" in self.method:
            # p_OOD: 1 - max softmax over the known-class logits only
            mcm = self.process_mcm(k_logits).max(dim=-1, keepdim=True)[0]
            objectness *= (1 - mcm)

        if self.proc_obj:
            # Standardize objectness and squash it to (0, 1)
            objectness -= objectness.mean()
            objectness /= objectness.std()
            objectness = torch.sigmoid(objectness)

        return logits, objectness.squeeze(-1)

2) I noticed that the forward function of the FOMO model is called during evaluation (see [here](https://github.com/orrzohar/FOMO/blob/a7539685396daaa0e213d9172051b1fc95572c82/engine.py#L36)). During execution, the forward function of OwlViTClassPredictionHead (self.model.class_predictor, see [here](https://github.com/orrzohar/FOMO/blob/a7539685396daaa0e213d9172051b1fc95572c82/models/FOMO.py#L643)) is still invoked. I'm uncertain whether this behavior during evaluation is intended; while debugging, I observed that a learnable shift and scale are applied to the logits.

orrzohar commented 8 months ago

Hi @YukunLi99, That is fine, I am happy to go through any aspect of this work.

1) That is a good point, and I will look into the effect of excluding that last logit from the objectness calculation. Note that $p_{\text{OOD}}$ is calculated on the class logits and $p_{\text{ID}}$ on the attribute logits. Conceptually, it does not seem problematic to me to require that an unknown object be ID w.r.t. both the known-object attributes and their average, while being OOD w.r.t. the known-object logits - but excluding it may be cleaner.
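
For concreteness, reading the UnkDetHead code quoted above with the "sigmoid", "mean", and "mcm" options enabled, the objectness reduces to roughly (notation mine):

$$p_{\text{ID}} = \frac{1}{M}\sum_{j=1}^{M} \sigma(\text{att\_logits}_j), \qquad p_{\text{OOD}} = 1 - \max_{c \le C_{\text{known}}} \mathrm{softmax}(\text{logits})_c, \qquad \text{objectness} = p_{\text{ID}} \cdot p_{\text{OOD}}$$

where $j$ runs over all $M$ attribute logits (including the appended average), and the result is then standardized and passed through a sigmoid.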

2) I don't think I understand the issue. self.embed_image_query returns the selected class_embeds from the same self.model.class_predictor function. These embeddings are then used as the supervision (either directly with an L2 loss or via classification during refinement). During forward, the self.model.class_predictor function is called to get predicted embeddings, which are then fed to the unk_head to derive the final class prediction. Notice that in training we use these derived embeddings and the same forward process the unk_head uses (multiplying attributes by W). Notice that we can't do anything for unknown objects, as we don't have any supervision for them. Note that the parts you mention about 'scaling' are used in a separate problem setting, where you have an image and a bbox and want to extract the relevant class embedding -- in normal inference you do not have this information.

If I didn't answer your concerns, let me know and we could perhaps discuss this more. Orr

YukunLi99 commented 8 months ago

Thank you for your detailed response! I still haven't fully understood the last question. My previous question might not have been clear in its wording, and I apologize for that. I understand that during the training phase, to obtain image embeddings for each class based on few-shot samples, the self.embed_image_query function is used, which returns the selected class_embeds from the self.model.class_predictor function. This operation corresponds to the Image-Conditioned Detection of OWL-ViT. During the evaluation phase, the unk_head receives att_logits as input, which is the cosine similarity between the object visual embedding and the attribute embeddings. att_logits is obtained from the return value (pred_logits) of the self.model.class_predictor function, as shown below:

(pred_logits, class_embeds) = self.model.class_predictor(
    image_feats, self.att_embeds.repeat(batch_size, 1, 1), self.att_query_mask
)

out = OwlViTObjectDetectionOutput(
    image_embeds=image_embeds,
    text_embeds=self.att_embeds,
    pred_boxes=pred_boxes,
    logits=pred_logits,
    class_embeds=class_embeds,
    vision_model_output=vision_outputs,
)

out.att_logits = out.logits  # TODO: remove later
out.logits, out.obj = self.unk_head(out.logits)

For the forward function of OwlViTClassPredictionHead from transformers/models/owlvit/modeling_owlvit.py, since query_embeds (self.att_embeds.repeat(batch_size, 1, 1)) is not None, the following operation is performed (applying a learnable shift and scale to logits):

def forward(
        self,
        image_embeds: torch.FloatTensor,
        query_embeds: Optional[torch.FloatTensor],
        query_mask: Optional[torch.Tensor],
    ) -> Tuple[torch.FloatTensor]:
        image_class_embeds = self.dense0(image_embeds)
        if query_embeds is None:
            device = image_class_embeds.device
            batch_size, num_patches = image_class_embeds.shape[:2]
            pred_logits = torch.zeros((batch_size, num_patches, self.query_dim)).to(device)
            return (pred_logits, image_class_embeds)

        # Normalize image and text features
        image_class_embeds /= torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6
        query_embeds /= torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6

        # Get class predictions
        pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)

        # Apply a learnable shift and scale to logits
        logit_shift = self.logit_shift(image_embeds)
        logit_scale = self.logit_scale(image_embeds)
        logit_scale = self.elu(logit_scale) + 1
        pred_logits = (pred_logits + logit_shift) * logit_scale

orrzohar commented 8 months ago

Hi @YukunLi99,

OK, I understand what you mean. I honestly never tried using the scaling during attribute refinement (so there is no 'deep' reason why it is not applied) -- and I actually wonder what would happen if I did use it in refinement. I will try this and let you know whether it affects the results in any meaningful way.

I personally looked at the 'training' as a separate optimization process where I wanted $\text{cosSim}(A_{\text{att}}, E_{\text{cls}}) \times W$ to be classified to the right class (and thus benefit from supervision from other classes as well as the class in question).
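
A minimal sketch of that refinement objective as described (variable names and shapes are hypothetical; this is not the repository's exact training loop):

import torch
import torch.nn.functional as F

def refinement_loss(example_embeds, example_labels, att_embeds, att_W):
    """
    example_embeds: (N, dim)              -- embeddings extracted from the N example objects
    example_labels: (N,) long tensor      -- known-class index of each example
    att_embeds:     (num_attributes, dim) -- attribute text embeddings (frozen)
    att_W:          (num_attributes, num_classes) learnable attribute-to-class weights
    """
    # cosSim(A_att, E_cls): cosine similarity between each example embedding and each attribute.
    att_logits = F.normalize(example_embeds, dim=-1) @ F.normalize(att_embeds, dim=-1).T  # (N, num_attributes)
    # Multiply by W and require each example to be classified to its own class.
    class_logits = att_logits @ att_W                                                     # (N, num_classes)
    return F.cross_entropy(class_logits, example_labels)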

In any case, note that the scale and shift apply a global transformation to the logits (i.e., all logits of a given prediction are multiplied by one scalar and offset by another). I think this may be more relevant in the actual forward inference -- there you have some 2000 possible predictions (and notice that each one gets its own scalars). Unlike that, when we refine the attributes, each embedding belongs to one class (there are no "no object" embeddings), and down-weighting such non-object predictions is probably the purpose of this shift and scale (increasing the weight of 'objects').

I'll let you know if the scaling does anything to the fine-tuning; I currently have other priorities, but I can get to this in early Feb.

Best, Orr

orrzohar commented 8 months ago

Hi @YukunLi99, Haven't heard back from you in a few weeks. Closing this issue, feel free to re-open if relevant. Best, Orr