aws-neuron / aws-neuron-sdk

Powering AWS purpose-built machine learning chips. Blazing fast and cost effective, natively integrated into PyTorch and TensorFlow and integrated with your favorite AWS services
https://aws.amazon.com/machine-learning/neuron/

neuron-cc puts warnings in compilation of ResNext #319

Closed Tagussan closed 3 years ago

Tagussan commented 3 years ago

Hi,

After the issue "Torch-Neuron fails to compile ResNeXt" was resolved, I tried to compile it. However, when I apply torch.neuron.trace to the ResNeXt model, I get the following messages:

INFO:Neuron:100.00% of all operations (including primitives) (3363 of 3363) are supported
INFO:Neuron:100.00% of arithmetic operations (348 of 348) are supported
INFO:Neuron:All operators are compiled by neuron-cc (this does not guarantee that neuron-cc will successfully compile)
INFO:Neuron:Number of arithmetic operators (pre-compilation) before = 348, fused = 348, percent fused = 100.0%
INFO:Neuron:Compiling function _NeuronGraph$3100 with neuron-cc
INFO:Neuron:Compiling with command line: '/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/bin/neuron-cc compile /tmp/tmp7z9qzdma/graph_def.pb --framework TENSORFLOW --pipeline compile SaveTemps --output /tmp/tmp7z9qzdma/graph_def.neff --io-config {"inputs": {"0:0": [[2, 3, 224, 224], "float32"]}, "outputs": ["Add_241:0"]} --verbose 35'
.[09:45:07] /opt/workspace/DmlcTvm/src/relay/pass/do_replication.cc:66: Replication is not applied to group convolution. Kernel and data input channels do not match[3, 3, 4, 256]  [2, 58, 58, 256]
[09:45:07] /opt/workspace/DmlcTvm/src/relay/pass/do_replication.cc:66: Replication is not applied to group convolution. Kernel and data input channels do not match[3, 3, 4, 256]  [2, 58, 58, 256]
[09:45:07] /opt/workspace/DmlcTvm/src/relay/pass/do_replication.cc:66: Replication is not applied to group convolution. Kernel and data input channels do not match[3, 3, 4, 256]  [2, 58, 58, 256]
[09:45:07] /opt/workspace/DmlcTvm/src/relay/pass/do_replication.cc:66: Replication is not applied to group convolution. Kernel and data input channels do not match[3, 3, 8, 512]  [2, 58, 58, 512]
[09:45:07] /opt/workspace/DmlcTvm/src/relay/pass/do_replication.cc:66: Replication is not applied to group convolution. Kernel and data input channels do not match[3, 3, 8, 512]  [2, 30, 30, 512]
[09:45:07] /opt/workspace/DmlcTvm/src/relay/pass/do_replication.cc:66: Replication is not applied to group convolution. Kernel and data input channels do not match[3, 3, 8, 512]  [2, 30, 30, 512]
[09:45:07] /opt/workspace/DmlcTvm/src/relay/pass/do_replication.cc:66: Replication is not applied to group convolution. Kernel and data input channels do not match[3, 3, 8, 512]  [2, 30, 30, 512]
[09:45:07] /opt/workspace/DmlcTvm/src/relay/pass/do_replication.cc:66: Replication is not applied to group convolution. Kernel and data input channels do not match[3, 3, 16, 1024]  [2, 30, 30, 1024]
[09:45:07] /opt/workspace/DmlcTvm/src/relay/pass/do_replication.cc:66: Replication is not applied to group convolution. Kernel and data input channels do not match[3, 3, 16, 1024]  [2, 16, 16, 1024]
[09:45:07] /opt/workspace/DmlcTvm/src/relay/pass/do_replication.cc:66: Replication is not applied to group convolution. Kernel and data input channels do not match[3, 3, 16, 1024]  [2, 16, 16, 1024]
[09:45:07] /opt/workspace/DmlcTvm/src/relay/pass/do_replication.cc:66: Replication is not applied to group convolution. Kernel and data input channels do not match[3, 3, 16, 1024]  [2, 16, 16, 1024]
[09:45:07] /opt/workspace/DmlcTvm/src/relay/pass/do_replication.cc:66: Replication is not applied to group convolution. Kernel and data input channels do not match[3, 3, 16, 1024]  [2, 16, 16, 1024]

Here's the entire code I use. To reproduce this, please run the following commands:

git clone https://github.com/Tagussan/pretrained-models.pytorch
cd pretrained-models.pytorch
python pretrainedmodels/models/resnext_compile_neuron.py
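
(For context, here is a minimal sketch of the kind of trace call that resnext_compile_neuron.py performs; the torchvision ResNeXt and the save path below are illustrative stand-ins, not the actual script.)

# Sketch only: torchvision's resnext50_32x4d stands in for the pretrainedmodels variant
import torch
import torch_neuron  # registers torch.neuron
import torchvision.models as models

model = models.resnext50_32x4d(pretrained=True).eval()
example = torch.rand(2, 3, 224, 224)  # matches the [2, 3, 224, 224] input in the io-config above

# Trace the model; torch-neuron hands the supported subgraph to neuron-cc
model_neuron = torch.neuron.trace(model, example_inputs=[example])
model_neuron.save('resnext_neuron.pt')  # illustrative output path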

Is this message ( Replication is not applied to group convolution. Kernel and data input channels do not match ) normal and okay to ignore, or is some error hidden behind it?

After compilation finished, I ran the same model on CPU and on Neuron. The results are slightly different (and this difference is not negligible for me):

.....
tensor([[-2.4844,  1.1328,  0.6797,  ..., -1.1797, -1.6406,  3.4062],
        [-2.4531,  0.7969,  0.9414,  ..., -1.1797, -1.2266,  2.3281]]) #neuron
tensor([[-2.5033,  1.1765,  0.6876,  ..., -1.1826, -1.7040,  3.4087],
        [-2.4803,  0.8026,  0.9759,  ..., -1.2000, -1.2469,  2.3515]], #CPU
       grad_fn=<AddmmBackward>)
tensor([[-2.1406,  0.4688,  0.6992,  ..., -1.1797, -1.5547,  2.6250],
        [-2.5625,  0.9883,  0.7578,  ..., -0.7891, -1.5781,  3.4688]]) #neuron
tensor([[-2.2070,  0.4666,  0.7328,  ..., -1.2209, -1.6023,  2.6719],
        [-2.5828,  0.9492,  0.7438,  ..., -0.7833, -1.6137,  3.4735]], #CPU
       grad_fn=<AddmmBackward>)
......

To reproduce this, please run the following command in the same directory: python pretrainedmodels/models/resnext_test.py

I'm using an inf1.xlarge instance with the Deep Learning AMI (Ubuntu 18.04) Version 48.0 to compile this. The neuron-cc compiler version is 1.5.5.0+3cc38c60b.

Is there a fix for these?

aws-zejdaj commented 3 years ago

Hi @Tagussan, happy to hear you were able to compile the model! The "replication" warning refers to a performance optimization and will not impact numerical accuracy. It may affect overall performance on models with large group convolutions.

Since your ResNeXt model is FP32, the most likely source of the lower accuracy is the auto-cast to BF16. Please rerun with --fp32-cast=matmult and let us know. You can find more details about the Neuron SDK casting options at https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/perf/performance-tuning.html#mixed-precision
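
As an example (a sketch only, reusing the illustrative names from the compile sketch above), the flag can be passed through the compiler_args parameter of torch.neuron.trace:

# Sketch: forward neuron-cc flags via compiler_args when tracing
model_neuron = torch.neuron.trace(
    model,
    example_inputs=[example],
    compiler_args=['--fp32-cast=matmult'],
)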

Tagussan commented 3 years ago

Hi @aws-zejdaj, thank you for your quick response! I understand that the "replication" message does not affect numerical accuracy.

I compiled the model again with the flag --fp32-cast=matmult. Here's the console output:

INFO:Neuron:Compiling with command line: '/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/bin/neuron-cc compile /tmp/tmp9fhh9a9w/graph_def.pb --framework TENSORFLOW --pipeline compile SaveTemps --output /tmp/tmp9fhh9a9w/graph_def.neff --io-config {"inputs": {"0:0": [[2, 3, 224, 224], "float32"]}, "outputs": ["Add_241:0"]} --fp32-cast=matmult --verbose 35'

And I ran the inference again:

....
tensor([[-2.5762,  0.6562,  0.8711,  ..., -1.4170, -1.4932,  2.6699],
        [-2.4492,  1.3848,  0.8560,  ..., -1.6299, -1.6797,  3.1602]]) #neuron
tensor([[-2.6108,  0.6671,  0.8822,  ..., -1.3947, -1.4970,  2.6480],
        [-2.4638,  1.4298,  0.8776,  ..., -1.6093, -1.6866,  3.1416]], #CPU
       grad_fn=<AddmmBackward>)
tensor([[-2.2598,  0.8794,  1.0449,  ..., -1.1162, -1.5898,  2.5781],
        [-2.6973,  1.0146,  0.7241,  ..., -1.3037, -1.8096,  3.5195]]) #neuron
tensor([[-2.2531,  0.8510,  1.0342,  ..., -1.1321, -1.6138,  2.6068],
        [-2.7459,  1.0023,  0.7449,  ..., -1.2888, -1.8594,  3.5549]], #CPU
       grad_fn=<AddmmBackward>)
....

Unfortunately, there is still a difference in the output, and it is about the same size as without the --fp32-cast=matmult flag.

aws-zejdaj commented 3 years ago

@Tagussan Did the accuracy improve with --fp32-cast=matmult, or stay the same? If it improved, you can try the highest-accuracy setting (per the earlier link), --fp32-cast=matmult-fp16.

In parallel I'll be happy to debug locally. Is this the most up-to-date model / way to reproduce?

  git clone https://github.com/Tagussan/pretrained-models.pytorch
  cd pretrained-models.pytorch
  python pretrainedmodels/models/resnext_test.py
Tagussan commented 3 years ago

@aws-zejdaj I implemented some code to compare the accuracy quantitatively. For 100 random inputs I computed the cosine between the two output vectors, one from the CPU model and one from the Neuron model (here's the code). The mean and standard deviation of the cosines are then used to compare which compilation gives the better accuracy.
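
A minimal sketch of this comparison (not the exact linked script; model and model_neuron are the CPU and Neuron-compiled models as above, and the batched outputs are flattened before taking the cosine):

# Sketch: cosine similarity between CPU and Neuron outputs over random inputs
import torch

cosines = []
for _ in range(100):
    x = torch.rand(2, 3, 224, 224)
    with torch.no_grad():
        out_cpu = model(x).flatten()
        out_neuron = model_neuron(x).flatten()
    cos = torch.dot(out_cpu, out_neuron) / (out_cpu.norm() * out_neuron.norm())
    cosines.append(cos.item())

cosines = torch.tensor(cosines)
print('Avg:', cosines.mean().item())
print('Std:', cosines.std().item())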

I ran each measurement 3 times. Here are the results (cosine = 1 means the two vectors are identical):

No flag

Avg: 0.9997505527734757
Std: 9.108214240505213e-05
Avg: 0.9997314190864564
Std: 0.00013324180114077962
Avg: 0.9997494441270828
Std: 0.0001082998758929156

--fp32-cast=matmult

Avg: 0.9997505104541778
Std: 0.00011865711647356751
Avg: 0.9997607046365737
Std: 0.00011890440884484417
Avg: 0.9997695034742355
Std: 0.00011385409148532757

--fp32-cast=matmult-fp16

Avg: 0.9999963063001632
Std: 1.8007796275936768e-06
Avg: 0.9999960535764694
Std: 2.1913490999438255e-06
Avg: 0.9999961918592453
Std: 1.716239026511515e-06

The flag --fp32-cast=matmult improved the accuracy slightly, but the difference is small. However, the flag --fp32-cast=matmult-fp16 gave a result that is two orders of magnitude more accurate! This accuracy is sufficient for me.

Yes, that is the latest procedure to reproduce. Please note that you need to run resnext_compile_neuron.py before running resnext_test.py.

Since all of my concerns have been resolved, I'd like to close this issue. Thank you for your kind help!