youweiliang / evit

Python code for ICLR 2022 spotlight paper EViT: Expediting Vision Transformers via Token Reorganizations
Apache License 2.0
170 stars 19 forks source link

forward_features(self, x, keep_rate=None, tokens=None, get_idx=False) #12

Open 60wanjinbing opened 2 years ago

60wanjinbing commented 2 years ago

with torch.cuda.amp.autocast(): outputs = model(samples, keep_rate) loss = criterion(samples, outputs, targets) 这段代码是在train_one_epoch函数中调用的,你的model没有传token参数,按照你这个应该会报错的,请问你这个token是在哪里传过去的?

youweiliang commented 2 years ago

你说的tokens参数是用于外面的caller控制剩余token数目的。一般情况下,我们用的是keep_rate这个参数来控制token数目,是在下面这几行计算出来的。 https://github.com/youweiliang/evit/blob/cc1993ddbd49bf3bf84aa39a7488dfdad95ad50a/evit.py#L209-L211