uzh-rpg / svit

Official implementation of "SViT: Revisiting Token Pruning for Object Detection and Instance Segmentation"
Apache License 2.0

Question about code #6

Closed: GewelsJI closed this issue 3 months ago

GewelsJI commented 4 months ago

Hey, authors,

Thanks for open-sourcing this nice work. I have a small question about your code:

Why do you use F.gumbel_softmax during training, but torch.argmin during inference?

Hope to receive your response. :)

Best, Daniel.

kaikai23 commented 4 months ago

During training you need some random behavior, so that even when the mask probability is less than 0.5 the mask can still sometimes be True (or 1). During inference it is preferable to have deterministic predictions, so a probability > 0.5 produces a True mask and anything else a False mask.

You can actually use F.gumbel_softmax at inference time as well, with no noticeable impact on accuracy.
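A minimal sketch of the behavior described above, assuming a two-channel mask head; the function name, channel order, and temperature are illustrative assumptions, not the repo's actual code:

```python
import torch
import torch.nn.functional as F

def token_keep_mask(logits: torch.Tensor, training: bool) -> torch.Tensor:
    """Sketch of the train/inference masking described above.

    logits: (B, N, 2) per-token scores; channel 0 = keep, channel 1 = drop
    (this channel order is an assumption for illustration).
    """
    if training:
        # Stochastic and differentiable: even a token whose keep-probability
        # is below 0.5 is occasionally sampled as kept.
        one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)
        return one_hot[..., 0]          # 1.0 = keep, 0.0 = drop
    else:
        # Deterministic: keep a token iff its keep-probability exceeds 0.5,
        # which for two channels means picking the larger logit.
        probs = F.softmax(logits, dim=-1)
        return (probs[..., 0] > 0.5).float()
```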

GewelsJI commented 3 months ago

Thanks for your quick reply. That's great.

Best, Daniel.

King4819 commented 3 months ago

> During training you need some random behavior, so that even when the mask probability is less than 0.5 the mask can still sometimes be True (or 1). During inference it is preferable to have deterministic predictions, so a probability > 0.5 produces a True mask and anything else a False mask.
>
> You can actually use F.gumbel_softmax at inference time as well, with no noticeable impact on accuracy.

I want to ask why argmin is used instead of argmax. I think a True mask should correspond to the larger probability, so shouldn't it be argmax?
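To make the question concrete: with a two-channel head, argmin and argmax differ only in which channel is read as "keep". A tiny sanity check (the [keep, drop] channel order here is a made-up assumption, not necessarily the repo's convention):

```python
import torch

# Four tokens, two channels; the [keep, drop] column order is hypothetical.
logits = torch.tensor([[ 2.0, -1.0],
                       [ 0.3,  0.8],
                       [-0.5,  1.2],
                       [ 1.5,  0.1]])

# If channel 0 means "keep", a kept token has the larger logit in channel 0:
keep_via_argmax = logits.argmax(dim=-1) == 0
# Equivalently, its channel-1 ("drop") logit is the smaller one:
keep_via_argmin = logits.argmin(dim=-1) == 1
assert torch.equal(keep_via_argmax, keep_via_argmin)
```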

Hope to get your response, thanks!

King4819 commented 2 months ago

@kaikai23 Hi, in my experiment, when I use argmin at the inference stage, the number of kept tokens at the final ViT layer drops to zero. Do you have any suggestions? Hope to get your reply, thanks!