Open 60wanjinbing opened 2 years ago
with torch.cuda.amp.autocast(): outputs = model(samples, keep_rate) loss = criterion(samples, outputs, targets) 这段代码是在train_one_epoch函数中调用的,你的model没有传token参数,按照你这个应该会报错的,请问你这个token是在哪里传过去的?
你说的tokens参数是用于外面的caller控制剩余token数目的。一般情况下,我们用的是keep_rate这个参数来控制token数目,是在下面这几行计算出来的。 https://github.com/youweiliang/evit/blob/cc1993ddbd49bf3bf84aa39a7488dfdad95ad50a/evit.py#L209-L211
with torch.cuda.amp.autocast(): outputs = model(samples, keep_rate) loss = criterion(samples, outputs, targets) 这段代码是在train_one_epoch函数中调用的,你的model没有传token参数,按照你这个应该会报错的,请问你这个token是在哪里传过去的?