IDEA-Research / detrex

detrex is a research platform for DETR-based object detection, segmentation, pose estimation and other visual recognition tasks.
https://detrex.readthedocs.io/en/latest/
Apache License 2.0

Some questions about HungarianMatcher for focal_loss_cost #196

Open · powermano opened this issue 1 year ago

powermano commented 1 year ago

In the HungarianMatcher:

elif self.cost_class_type == "focal_loss_cost":
    alpha = self.alpha
    gamma = self.gamma
    # out_prob: sigmoid class probabilities, shape [batch_size * num_queries, num_classes]
    neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
    pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
    cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]

If my dataset only has one category, then the shape of out_prob is [batch_size * num_queries, 1], and neg_cost_class and pos_cost_class have that same shape. When calculating the focal_loss_cost, we would only need to use cost_class = pos_cost_class[:, tgt_ids]. But this does not affect the final matching result, because all targets index the same class, so every column of cost_class is identical.
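A quick way to check the single-category case (a minimal, self-contained sketch; the shapes and names here are illustrative, not the matcher's actual variables):

```python
import torch

alpha, gamma = 0.25, 2.0
out_prob = torch.randn(5, 1).sigmoid()  # num_classes == 1: [num_queries, 1]
tgt_ids = torch.tensor([0, 0, 0])       # every gt is class 0

neg_cost_class = (1 - alpha) * out_prob**gamma * (-(1 - out_prob + 1e-8).log())
pos_cost_class = alpha * (1 - out_prob) ** gamma * (-(out_prob + 1e-8).log())
cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]

# all targets index the same (only) class, so every column is identical
assert (cost_class == cost_class[:, :1]).all()
```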

If my dataset has multiple categories, then the shape of out_prob is [batch_size * num_queries, num_classes], and neg_cost_class and pos_cost_class again have that same shape. When calculating the focal_loss_cost, it should instead be:

cost_class = pos_cost_class[:, tgt_ids] + torch.sum(neg_cost_class, axis=-1, keepdim=True) - neg_cost_class[:, tgt_ids]

Although torch.sum(neg_cost_class, axis=-1) is a fixed value with respect to which gt category a query is matched to, the value differs from query to query.
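This can be verified directly. In the sketch below (illustrative shapes), the difference between the two costs is exactly neg_cost_class.sum(-1, keepdim=True): one constant per query row, repeated across all target columns:

```python
import torch

alpha, gamma = 0.25, 2.0
out_prob = torch.randn(4, 5).sigmoid()  # 4 queries, 5 classes
tgt_ids = torch.tensor([1, 3, 2])

neg = (1 - alpha) * out_prob**gamma * (-(1 - out_prob + 1e-8).log())
pos = alpha * (1 - out_prob) ** gamma * (-(out_prob + 1e-8).log())

simplified = pos[:, tgt_ids] - neg[:, tgt_ids]
full = pos[:, tgt_ids] + neg.sum(-1, keepdim=True) - neg[:, tgt_ids]

diff = full - simplified  # constant across targets, but different per query
print(diff)
```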

As mentioned in the DETR paper, if using CE for classification:

> In the matching cost we use probabilities $\hat{p}_{\hat{\sigma}(i)}(c_i)$ instead of log-probabilities. This makes the class prediction term commensurable to $\mathcal{L}_{box}(\cdot,\cdot)$ (described below), and we observed better empirical performances.
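For comparison, the class cost in the original DETR matcher is just the negated softmax probability of the gt class; a minimal sketch of that cost (illustrative shapes):

```python
import torch

out_prob = torch.randn(4, 5).softmax(-1)  # DETR uses softmax probabilities
tgt_ids = torch.tensor([1, 3, 2])

# probability instead of log-probability, so each entry lies in [-1, 0]
cost_class = -out_prob[:, tgt_ids]
```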

But I cannot find any explanation of why the focal loss in the matching cost uses cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids].

Can someone give a proper explanation?

powermano commented 1 year ago

My test code is as follows:

import torch
from scipy.optimize import linear_sum_assignment

num_classes = 5
num_queries = 100
bs = 2

# multiple classes
if num_classes > 1:
    tgt_ids1 = torch.tensor([1, 3])
    tgt_ids2 = torch.tensor([2, 3, 4])
else:
    tgt_ids1 = torch.tensor([0, 0])
    tgt_ids2 = torch.tensor([0, 0, 0])

targets = [tgt_ids1, tgt_ids2]
tgt_ids = torch.cat((tgt_ids1, tgt_ids2))  # tensor([1, 3, 2, 3, 4])

alpha = 0.25
gamma = 2.0

out_prob = torch.randn((bs * num_queries, num_classes)).sigmoid()

neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())  # [bs * num_queries, num_classes]
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())    # [bs * num_queries, num_classes]
cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]               # [bs * num_queries, len(tgt_ids)]

# the whole focal loss
cost_class1 = pos_cost_class[:, tgt_ids] + torch.sum(neg_cost_class, axis=-1, keepdim=True) - neg_cost_class[:, tgt_ids]  # [bs * num_queries, len(tgt_ids)]

C = cost_class                            # (bs*num_q, len(tgt_ids))
C = C.view(bs, num_queries, -1).cpu()     # (bs, num_q, len(tgt_ids))

C1 = cost_class1                          # (bs*num_q, len(tgt_ids))
C1 = C1.view(bs, num_queries, -1).cpu()   # (bs, num_q, len(tgt_ids))

sizes = [len(v) for v in targets]  # [2, 3]; targets [[1, 3], [2, 3, 4]] were concatenated into [1, 3, 2, 3, 4]
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
print(indices)

indices1 = [linear_sum_assignment(c[i]) for i, c in enumerate(C1.split(sizes, -1))]
print(indices1)

SlongLiu commented 1 year ago

I agree that the focal_cost is not equal to the focal loss. Focal loss:

$FL(p) = -(1-p)^{\gamma}\, y \log(p) - p^{\gamma}\, (1-y) \log(1-p)$

it should be $FL(p) = -(1-p)^{\gamma}\, y \log(p)$ if $y = 1$.

While in the matching:

$C(p) = -\alpha (1-p)^{\gamma} \log p + (1-\alpha)\, p^{\gamma} \log(1-p)$

The implementation is borrowed from Deformable DETR. I think it is a good question to explore different variants of focal costs.
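To make the difference concrete, here is a minimal numerical sketch (illustrative names, using the $\alpha$-balanced form for both) comparing the positive-label focal loss with the matching cost $C(p)$ at the gt class:

```python
import torch

alpha, gamma = 0.25, 2.0
p = torch.tensor([0.1, 0.5, 0.9])  # toy probabilities for the gt class

# alpha-balanced focal loss for a positive label (y = 1)
fl_pos = -alpha * (1 - p) ** gamma * (p + 1e-8).log()

# matching cost C(p): the y = 1 focal term minus the negative focal term
cost = fl_pos + (1 - alpha) * p**gamma * (1 - p + 1e-8).log()

print(fl_pos)  # always >= 0
print(cost)    # shifted down by the negative term; can be negative
```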

powermano commented 1 year ago

> I agree that the focal_cost is not equal to the focal loss. Focal loss:
>
> $FL(p) = -(1-p)^{\gamma}\, y \log(p) - p^{\gamma}\, (1-y) \log(1-p)$
>
> it should be $FL(p) = -(1-p)^{\gamma}\, y \log(p)$ if $y = 1$.
>
> While in the matching:
>
> $C(p) = -\alpha (1-p)^{\gamma} \log p + (1-\alpha)\, p^{\gamma} \log(1-p)$
>
> The implementation is borrowed from Deformable DETR. I think it is a good question to explore different variants of focal costs.

Thanks. In my opinion, the implementation from Deformable DETR aims to make the class prediction term commensurable to $\mathcal{L}_{box}$, as mentioned in the DETR paper (see the passage quoted above).

In the original DETR implementation the magnitude of the classification cost lies in $[0, 1]$, so Deformable DETR presumably wants to keep the focal cost within a comparable range. Maybe there will be some better variants of the focal cost.
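If this commensurability argument is right, the omitted sum term should dominate as the number of classes grows. A minimal sketch (illustrative names, a COCO-like num_classes) comparing the magnitudes of the simplified cost and the "whole" focal loss:

```python
import torch

torch.manual_seed(0)
num_classes, alpha, gamma = 80, 0.25, 2.0  # COCO-like number of classes
out_prob = torch.randn(1000, num_classes).sigmoid()

neg = (1 - alpha) * out_prob**gamma * (-(1 - out_prob + 1e-8).log())
pos = alpha * (1 - out_prob) ** gamma * (-(out_prob + 1e-8).log())

simplified = pos - neg                        # cost the matcher actually uses
full = pos + neg.sum(-1, keepdim=True) - neg  # "whole" focal loss per class

# the sum over all classes grows roughly linearly with num_classes, so the
# full cost is far larger in magnitude than the simplified one
print(simplified.abs().mean().item(), full.abs().mean().item())
```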