dancingpipi opened 4 years ago
I added the following code to online_demo/main.py:
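```python
import code
import os

import cv2
import torch
from PIL import Image

from mobilenet_v2_tsm import MobileNetV2  # online_demo/mobilenet_v2_tsm.py

# get_transform() is defined elsewhere in online_demo/main.py.


def main(model_path, d_path):
    net = MobileNetV2(n_class=27)  # Jester has 27 gesture classes
    net.load_state_dict(torch.load(model_path))
    transform = get_transform()

    # Zero-initialized caches of shifted channels from the previous frame,
    # one per temporal-shift module
    shift_buffer = [torch.zeros([1, 3, 56, 56]),
                    torch.zeros([1, 4, 28, 28]),
                    torch.zeros([1, 4, 28, 28]),
                    torch.zeros([1, 8, 14, 14]),
                    torch.zeros([1, 8, 14, 14]),
                    torch.zeros([1, 8, 14, 14]),
                    torch.zeros([1, 12, 14, 14]),
                    torch.zeros([1, 12, 14, 14]),
                    torch.zeros([1, 20, 7, 7]),
                    torch.zeros([1, 20, 7, 7])]

    # Feed the frames of one video to the network in filename order
    for fname in sorted(os.listdir(d_path)):
        fpath = os.path.join(d_path, fname)
        image = cv2.imread(fpath)  # note: OpenCV returns pixels in BGR order
        image = transform([Image.fromarray(image).convert('RGB')])
        image = image.view(1, 3, image.size(1), image.size(2))  # add batch dim
        # The network returns the logits followed by the updated shift buffers
        p, *shift_buffer = net(image, *shift_buffer)
        idx = torch.argmax(p.squeeze()).item()
        code.interact(local=locals())  # inspect p / idx for each frame
```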
I tested it on Jester data, but every predicted idx is 21. Could you help me solve the problem? Any help would be appreciated.
Try adding `net.eval()` to switch the model to eval mode:

```python
net.eval()  # call once after load_state_dict(), before running inference
```
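As a hedged illustration of why the mode matters: layers such as batch normalization behave differently in training and evaluation mode. The sketch is a hypothetical, pure-Python stand-in (`TinyBatchNorm` is not a PyTorch class) that mimics that split for a single feature. In training mode it normalizes with the current batch's statistics and overwrites its running statistics as a side effect; in eval mode it uses the stored statistics. With a batch of one sample, training mode maps every input to the same value, which is one simplified way a model left in training mode can give the same prediction for every frame. The real MobileNetV2 has many more layers, so treat this as intuition, not a diagnosis.

```python
import math


class TinyBatchNorm:
    """Simplified batch norm over a list of scalars (one feature).

    Hypothetical illustration only: it mimics the train/eval split of
    torch.nn.BatchNorm layers but is not PyTorch code.
    """

    def __init__(self, running_mean=0.0, running_var=1.0, momentum=0.1, eps=1e-5):
        self.running_mean = running_mean
        self.running_var = running_var
        self.momentum = momentum
        self.eps = eps
        self.training = True  # like nn.Module, start in training mode

    def eval(self):
        self.training = False
        return self

    def __call__(self, batch):
        if self.training:
            # Training mode: normalize with the statistics of this batch...
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)
            # ...and update the stored running statistics as a side effect
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            # Eval mode: use the statistics accumulated during training
            mean, var = self.running_mean, self.running_var
        return [(x - mean) / math.sqrt(var + self.eps) for x in batch]


if __name__ == "__main__":
    train_bn = TinyBatchNorm(running_mean=2.0, running_var=4.0)
    print(train_bn([5.0]), train_bn([-3.0]))  # [0.0] [0.0]: input is ignored

    eval_bn = TinyBatchNorm(running_mean=2.0, running_var=4.0).eval()
    print(eval_bn([5.0]), eval_bn([-3.0]))  # roughly [1.5] and [-2.5]
```

In PyTorch, `nn.Module.eval()` plays the same role by setting `training=False` on every submodule, so calling it once after `load_state_dict()` is enough; wrapping inference in `torch.no_grad()` additionally skips gradient bookkeeping.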
Oh my god, I forgot it! I will try it.
Hi, did you resolve the issue with the above suggestion? Thanks!