Little-Podi / GRM

[CVPR'23] The official PyTorch implementation of our CVPR 2023 paper: "Generalized Relation Modeling for Transformer Tracking".

About the 'threshold' inference #6

Closed · kun-dragon closed this issue 1 year ago

kun-dragon commented 1 year ago

Could you please explain why certain datasets require a threshold, and why the thresholds differ across datasets during inference?

```python
if self.training:
    # During training
    decision = F.gumbel_softmax(divide_prediction, hard=True)
else:
    # During inference
    if threshold:
        # Manual rank based selection
        decision_rank = (F.softmax(divide_prediction, dim=-1)[:, :, 0] < threshold).long()
    else:
        # Auto rank based selection
        decision_rank = torch.argsort(divide_prediction, dim=-1, descending=True)[:, :, 0]

    decision = F.one_hot(decision_rank, num_classes=2)
```
Little-Podi commented 1 year ago

Hi. This threshold controls how many search tokens interact with the template tokens. On some benchmarks, the challenging scenarios make it hard to precisely determine which region should interact, so we may need to loosen the constraint to improve recall. Conversely, on other datasets the division results may be consistently confident; there we can be stricter to promote the feature interaction process and thus improve the tracking performance. For a binary classification formulation, the default threshold is 0.5, which I implemented as "auto rank" here. In my experience, that already works well enough, and minor tuning may improve the overall performance a little.
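
To illustrate why the "auto rank" path corresponds to a 0.5 threshold, here is a minimal standalone sketch (not from the repo): the tensor name `divide_prediction` mirrors the snippet above, but the random logits, the shapes, and the assumption that class 0 marks the tokens kept for template interaction are mine. With two classes, picking the higher-scoring class via `argsort` selects class 0 exactly when its softmax probability exceeds 0.5 (ignoring exact ties), so the two branches agree; lowering the threshold then loosens the constraint and keeps more tokens.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical two-class division logits: batch of 2, 16 search tokens each.
divide_prediction = torch.randn(2, 16, 2)

# "Auto rank": take the higher-scoring class per token (argmax via argsort).
auto_rank = torch.argsort(divide_prediction, dim=-1, descending=True)[:, :, 0]

# Manual selection with threshold = 0.5 reproduces the same decision, since
# class 0's softmax probability exceeds 0.5 exactly when its logit is larger.
manual_rank = (F.softmax(divide_prediction, dim=-1)[:, :, 0] < 0.5).long()
assert torch.equal(auto_rank, manual_rank)

# Assuming rank 0 means "interact with the template": a lower threshold keeps
# more tokens for interaction (looser, higher recall); a higher one is stricter.
for t in (0.3, 0.5, 0.7):
    rank = (F.softmax(divide_prediction, dim=-1)[:, :, 0] < t).long()
    kept = (rank == 0).float().mean().item()
    print(f"threshold={t}: fraction of tokens kept for interaction = {kept:.2f}")
```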