ramiro050 closed this 2 years ago.
On an A100 GPU:
$ python bert.py --device cuda --iters 10 --warmup-iters 5
Compiled iteration times
Median: 7.0052565 ms
10%ile: 6.879468 ms
90%ile: 7.0854697 ms
Total: 69.8995590209961 ms
Eager iteration times
Median: 17.3365375 ms
10%ile: 17.039004300000002 ms
90%ile: 17.6673279 ms
Total: 173.361328125 ms
Compiled result matches eager result!
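The summary above reports a median, 10th and 90th percentiles, and a total over the timed iterations, after some untimed warm-up iterations. The following is a minimal stdlib sketch of how such a summary could be produced. It is not the code in bert.py: the helper names `time_iterations`, `summarize`, and `print_summary` are hypothetical, and the use of linearly interpolated percentiles is an assumption. On a GPU, the timed callable would also need to synchronize (for example with `torch.cuda.synchronize()`) before returning, because CUDA kernels launch asynchronously.

```python
import statistics
import time
from typing import Callable, Dict, List


def time_iterations(fn: Callable[[], object], iters: int, warmup_iters: int) -> List[float]:
    """Run `fn` warmup_iters times untimed, then return per-iteration wall times in ms.

    Hypothetical helper; assumes `fn` blocks until its work is finished.
    """
    for _ in range(warmup_iters):
        fn()  # warm-up: e.g. lets a compiled model finish compiling/caching
    times_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - start) * 1000.0)
    return times_ms


def summarize(times_ms: List[float]) -> Dict[str, float]:
    """Median, 10th/90th percentiles (linear interpolation, assumed) and total, in ms."""
    # "inclusive" matches the common linear-interpolation percentile definition.
    deciles = statistics.quantiles(times_ms, n=10, method="inclusive")
    return {
        "median": statistics.median(times_ms),
        "p10": deciles[0],
        "p90": deciles[-1],
        "total": sum(times_ms),
    }


def print_summary(label: str, times_ms: List[float]) -> None:
    """Print a summary in the same layout as the console output above."""
    s = summarize(times_ms)
    print(f"{label} iteration times")
    print(f"Median: {s['median']} ms")
    print(f"10%ile: {s['p10']} ms")
    print(f"90%ile: {s['p90']} ms")
    print(f"Total: {s['total']} ms")
```

Using the medians reported above, the compiled model runs about 17.34 / 7.01 ≈ 2.47× faster per iteration than eager mode on this run.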
This commit adds argparse support to bert.py, with flags to set the device and the number of timed and warm-up iterations to run the model for (`--device`, `--iters`, `--warmup-iters`). It also refactors bert.py.
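A minimal sketch of a parser accepting the three flags used in the command above. The default values and help strings are assumptions for illustration and are not taken from bert.py. Note that argparse exposes `--warmup-iters` as the attribute `warmup_iters`.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse benchmark flags; defaults below are hypothetical."""
    parser = argparse.ArgumentParser(
        description="Compare compiled vs. eager BERT iteration times."
    )
    parser.add_argument("--device", default="cpu",
                        help="Device to run on, e.g. 'cpu' or 'cuda'.")
    parser.add_argument("--iters", type=int, default=10,
                        help="Number of timed iterations.")
    parser.add_argument("--warmup-iters", type=int, default=5,
                        help="Number of untimed warm-up iterations run first.")
    return parser.parse_args(argv)  # argv=None reads sys.argv[1:]
```

For example, `parse_args(["--device", "cuda", "--iters", "10", "--warmup-iters", "5"])` mirrors the A100 invocation shown above.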