I'm using PyTorch 1.6.0.
Thank you for your issue! I will investigate it after the 15th of August.
I was able to trace the issue down to lines 89-90 in batchnorm.py; specifically, the JIT compiler did not like
cov_uu, cov_uv = running_cov[0, 0], running_cov[0, 1]
cov_vu, cov_vv = running_cov[1, 0], running_cov[1, 1]
which is extremely odd to me, especially since running_cov here is either None or a 3-d tensor of shape 2 x 2 x ..., and thus cannot be indexed outside of its range. I was able to resolve the particular case you provided with the following patch:
cov_uu, cov_uv, cov_vu, cov_vv = running_cov.reshape(4, -1)
which I don't particularly like, since it introduces an unnecessary reshaping and iteration (over dim 0).
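For reference, a minimal sanity check (with a hypothetical num_features of 8, not the actual cplxmodule buffers) that the reshape-based unpacking yields exactly the same four slices as the original indexing:

```python
import torch

# Hypothetical buffer mimicking the layout described above: 2 x 2 x num_features.
running_cov = torch.randn(2, 2, 8)

# Original indexing, which the JIT compiler rejected during export:
cov_uu, cov_uv = running_cov[0, 0], running_cov[0, 1]
cov_vu, cov_vv = running_cov[1, 0], running_cov[1, 1]

# Patched version: one reshape, then unpacking along dim 0.
r_uu, r_uv, r_vu, r_vv = running_cov.reshape(4, -1)

# The row-major reshape visits [0, 0], [0, 1], [1, 0], [1, 1] in order,
# so both paths must produce identical slices.
assert torch.equal(cov_uu, r_uu) and torch.equal(cov_uv, r_uv)
assert torch.equal(cov_vu, r_vu) and torch.equal(cov_vv, r_vv)
```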
Furthermore, it turns out that the torch.var operator is not supported by the ONNX opset. This can be circumvented by using
var = (tensor * tensor).mean(dim=axes) + nugget
instead of
var = tensor.var(dim=axes, unbiased=False) + nugget
This indeed computes the variance: the tensor is centred during the computation of the batch average, so its mean is zero and the mean of its square coincides with the variance.
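A quick numeric check of this substitution (the axes and shapes here are illustrative assumptions, not the actual batchnorm.py values):

```python
import torch

torch.manual_seed(0)
x = torch.randn(16, 8)   # illustrative batch
axes = 0
nugget = 1e-5            # small stabilising constant, as in the snippets above

# Centre the tensor, mirroring what happens during the batch-average step.
centred = x - x.mean(dim=axes, keepdim=True)

var_builtin = centred.var(dim=axes, unbiased=False) + nugget  # unsupported by the ONNX opset
var_manual = (centred * centred).mean(dim=axes) + nugget      # exportable replacement

assert torch.allclose(var_builtin, var_manual)
```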
However, this sort of quirk seems to me to signify that some understanding is missing on my part, and that many other objects in cplxmodule might need testing for ONNX/JIT compatibility. I am composing a unit test suite for ONNX compatibility.
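For illustration, such a test could be as simple as asserting that torch.onnx.export runs to completion; check_onnx_exportable below is a hypothetical helper, not an existing cplxmodule utility:

```python
import io
import torch

def check_onnx_exportable(module: torch.nn.Module, *inputs: torch.Tensor) -> None:
    # The test passes iff export raises no exception; the exported model
    # is discarded, since only opset/JIT compatibility is being checked.
    buffer = io.BytesIO()
    torch.onnx.export(module, inputs, buffer, opset_version=11)

# Stand-in usage; the real suite would target cplxmodule layers such as
# the batch norm discussed above.
check_onnx_exportable(torch.nn.Linear(4, 2), torch.randn(1, 4))
```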
Thank you, this is great.
So applying the following fix works:
cov_uu, cov_uv, cov_vu, cov_vv = running_cov.reshape(4, -1)
Any reason why this can't be committed?
Sod it, I've added a PR: https://github.com/ivannz/cplxmodule/pull/12
The issue has been fixed by PR #14.
Steps to reproduce
I get this stack trace: