crawlingcub closed this issue 3 years ago
PyTorch resnet18 is tested on every CI job (https://github.com/apache/tvm/blob/874ea7a81d91857a2495892c598a8b3b87a6da64/tests/python/frontend/pytorch/test_forward.py#L2030), so I don't expect any accuracy difference.
Can you try evaluating the model without serializing it to disk? When PyTorch jit models are serialized, PyTorch erases all type information. This issue has caused problems for us in the past.
Hi, just to clarify, this is a mutant derived from a resnet-18 model, so the model structure is a bit different. We are testing the behavior of tvm when running some simple variants of well-tested models like resnet.
Can you try evaluating the model that is not serialized to disk?
What do you mean by this? The results are the same as before I serialized this model to disk.
I meant, instead of `model = torch.load(sys.argv[1]+'/model.pt')`, create the model directly from a Python script. But if you already tried that, something is a bit off indeed.
Can you also try exporting to ONNX and running it through our ONNX frontend? That would tell us whether this is a frontend-specific issue.
Ok, I will try that out
Hi,
I tried exporting the original model to ONNX and then running it with TVM's ONNX frontend. The results are accurate with ONNX: actually identical to what I get with PyTorch. So this seems like a bug in the PyTorch frontend?
Ok. Can you send me an ONNX file, and if possible the pytorch model source?
Hi,
I have updated the link above to include both the PyTorch and ONNX models. Regarding the model source, I used the pretrained ResNet-18 model from torchvision and applied some simple mutations on top of it, such as adding noise to some weights, replacing an activation function, and adding a new layer. I can send you a model summary if needed.
I can confirm that `model.pt` and a TVM model converted via ONNX give the same output. It is hard to compare the two TVM models, one coming from the PyTorch frontend and the other from ONNX, since the ONNX export folds batch norm into convolution, so there is no batch norm in the ONNX model.
Since the differences from resnet18 seem small, and we know there is no issue with resnet18, maybe you can start from resnet18 and add your changes one by one until the TVM result diverges? I'm using a script modified from https://github.com/apache/tvm/blob/main/tutorials/frontend/from_pytorch.py to test the accuracy, and there is no need to train the model.
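The incremental strategy suggested here can be sketched in a framework-agnostic way: apply the mutations cumulatively and report the first one after which the TVM and PyTorch outputs stop agreeing. Everything below (the mutation list, the consistency check) is a hypothetical stand-in, not the actual models:

```python
def find_first_divergence(base, mutations, is_consistent):
    """Apply mutations cumulatively; return the index of the first
    mutation after which is_consistent(model) fails (e.g. TVM and
    PyTorch outputs no longer agree), or None if all steps pass."""
    model = base
    assert is_consistent(model), "the unmutated baseline should pass"
    for i, mutate in enumerate(mutations):
        model = mutate(model)
        if not is_consistent(model):
            return i
    return None

# Toy usage: the "model" is just a set of applied mutation names,
# and consistency breaks as soon as the ELU swap is applied.
mutations = [
    lambda m: m | {"weight_noise"},
    lambda m: m | {"relu_to_elu"},
    lambda m: m | {"extra_layer"},
]
first_bad = find_first_divergence(
    set(), mutations, is_consistent=lambda m: "relu_to_elu" not in m
)
print(first_bad)  # index of the mutation that broke agreement
```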
I looked at the trace of changes. It seems replacing one ReLU activation with ELU introduced a big change in accuracy; all changes before that did not affect the results much. Maybe there is an issue with the ELU implementation in the PyTorch frontend?
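For reference, ELU agrees with ReLU on non-negative inputs but differs on negative ones, which is exactly where a frontend conversion bug would show up. A minimal pure-Python sketch of the standard definitions (alpha = 1.0 is PyTorch's default):

```python
import math

def relu(x: float) -> float:
    # ReLU: identity for positive inputs, zero otherwise.
    return x if x > 0.0 else 0.0

def elu(x: float, alpha: float = 1.0) -> float:
    # ELU: identity for positive inputs; for negative inputs it
    # saturates smoothly toward -alpha via an exponential.
    return x if x > 0.0 else alpha * (math.exp(x) - 1.0)

# The two activations agree for x >= 0 but diverge for x < 0,
# so a mis-converted ELU shifts every negative pre-activation.
print(relu(-2.0), elu(-2.0))
```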
Bingo! Fixed in https://github.com/apache/tvm/pull/8699
Thanks for reporting.
Awesome! Thanks for your help!
Hi,
I am getting lower accuracy with TVM, targeting both cuda and cpu, compared to running the PyTorch model directly. This is a variant of a ResNet-18 model. Find the link to download the model below.
You will have to download the ImageNet validation dataset and extract/sort it into a folder. Replace `imagenet/data` with the name of that folder. You can download the model from here, untar it, and pass the path to the script below.
Environment:
Code:
Output:
Please let me know if you need more info. Thanks!