I installed the module with:
!pip3 install -qU git+https://github.com/harvardnlp/genbmm
Then I ran the example code:
import torch  # needed for torch.rand below
import genbmm
a = torch.rand(10, 3, 4).cuda().requires_grad_(True)
b = torch.rand(10, 4, 5).cuda().requires_grad_(True)
# Log-Sum-Exp
c = genbmm.logbmm(a, b)
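For reference, as I understand it, logbmm is a batched log-sum-exp "matrix multiply": for a of shape (B, N, M) and b of shape (B, M, P) it returns c of shape (B, N, P) with c[n, i, k] = logsumexp over j of (a[n, i, j] + b[n, j, k]). Below is a minimal pure-PyTorch sketch of that operation, assuming those semantics. logbmm_reference is a hypothetical helper name, not part of genbmm. It runs on CPU without the compiled extension, so it can be used to check the expected output or as a slow fallback.

```python
import torch

def logbmm_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical pure-PyTorch log-sum-exp batched matmul (not genbmm's API).

    a: (B, N, M), b: (B, M, P) -> returns (B, N, P), where
    out[n, i, k] = logsumexp_j(a[n, i, j] + b[n, j, k]).
    """
    # (B, N, M, 1) + (B, 1, M, P) broadcasts to (B, N, M, P)
    summed = a.unsqueeze(-1) + b.unsqueeze(-3)
    # Reduce over the shared inner dimension M
    return summed.logsumexp(dim=-2)

# Same shapes as the example above, but on CPU
a = torch.rand(10, 3, 4, requires_grad=True)
b = torch.rand(10, 4, 5, requires_grad=True)
c = logbmm_reference(a, b)
print(c.shape)  # torch.Size([10, 3, 5])

# Autograd works through logsumexp, so gradients reach a and b
c.sum().backward()
```

Materializing the (B, N, M, P) intermediate costs extra memory, which is presumably what the fused CUDA kernel in genbmm avoids.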
When I run it, I get the following error: