rballester / tntorch

Tensor Network Learning with PyTorch
https://tntorch.readthedocs.io/
GNU Lesser General Public License v3.0
283 stars, 42 forks

tntorch.metrics.dot() does not perform the calculation described in the API documentation #47

Open zhaoran0072004 opened 8 months ago

zhaoran0072004 commented 8 months ago

tntorch.metrics.dot() does not perform the calculation described in the API documentation.

The documentation gives this example: suppose t1 has shape 3 x 4 and t2 has shape 3 x 4 x 5 x 6. Then, tn.dot(t1, t2) will have shape 5 x 6.
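For reference, the contraction the documentation describes can be sketched on plain dense tensors with `torch.tensordot`. This is not tntorch's implementation (tntorch works on tensor-network formats); it only illustrates the expected semantics: every mode of t1 is contracted against the leading modes of t2, and t2's trailing modes remain.

```python
import torch

# Dense illustration of the documented behaviour: the modes of t1 are
# matched against the *leading* modes of t2; t2's trailing modes survive.
t1 = torch.randn(3, 4)
t2 = torch.randn(3, 4, 5, 6)

# dims=t1.dim() contracts the last t1.dim() modes of t1 (i.e. all of them)
# with the first t1.dim() modes of t2.
result = torch.tensordot(t1, t2, dims=t1.dim())
print(result.shape)  # torch.Size([5, 6])

# Equivalent formulation: view t1 as a vector of length 12 and t2 as a
# 12 x 30 matrix, multiply, then restore the trailing 5 x 6 shape.
n = t1.numel()
alt = (t1.reshape(n) @ t2.reshape(n, -1)).reshape(t2.shape[t1.dim():])
print(torch.allclose(result, alt, atol=1e-5))  # True
```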

Instead of performing this contraction, the function raises a RuntimeError:

    >>> a = torch.randn(3, 4)
    >>> b = torch.randn(3, 4, 5, 6)
    >>> tn.metrics.dot(a, b)
    Traceback (most recent call last):
      File "", line 1, in
      File "D:\anaconda\envs\TT-PINN\Lib\site-packages\tntorch\metrics.py", line 67, in dot
        return t1.flatten().dot(t2.flatten())
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    RuntimeError: inconsistent tensor size, expected tensor [12] and src [360] to have the same number of elements, but got 12 and 360 elements respectively
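The traceback shows that, for these inputs, `dot` ends up at `t1.flatten().dot(t2.flatten())`, which is a full inner product and therefore requires both tensors to have the same number of elements (12 vs. 360 here). The sketch below reproduces that failure and contrasts it with a partial contraction. `flatten_dot` and `partial_dot` are hypothetical helpers written for illustration only; neither is part of tntorch, and `partial_dot` is an assumed dense workaround, not a proposed fix for the library.

```python
import torch

def flatten_dot(t1, t2):
    # Mirrors the expression in the traceback (metrics.py, line 67):
    # a full inner product, valid only when numel() matches.
    return t1.flatten().dot(t2.flatten())

def partial_dot(t1, t2):
    # Hypothetical helper (not tntorch API): contract all modes of t1
    # with the leading modes of t2, as the documentation describes.
    k = t1.dim()
    if t2.shape[:k] != t1.shape:
        raise ValueError(f"leading modes of t2 {tuple(t2.shape[:k])} "
                         f"do not match t1 {tuple(t1.shape)}")
    return torch.tensordot(t1, t2, dims=k)

a = torch.randn(3, 4)
b = torch.randn(3, 4, 5, 6)

# The full inner product fails: 12 elements vs. 360 elements.
err = None
try:
    flatten_dot(a, b)
except RuntimeError as e:
    err = e
print(type(err).__name__)  # RuntimeError

# The partial contraction yields the documented 5 x 6 result.
print(partial_dot(a, b).shape)  # torch.Size([5, 6])

# For identically shaped inputs both reduce to the same scalar.
c = torch.randn(3, 4)
print(torch.allclose(flatten_dot(a, c), partial_dot(a, c), atol=1e-5))  # True
```

When t1 and t2 have the same shape, the partial contraction degenerates to the ordinary inner product, so the two behaviours only diverge when t2 has extra trailing modes, which is exactly the case in the documentation example.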