harryseely closed this issue 1 year ago.
Sorry, I have not tested the code with AMP, so I can't currently offer more in-depth suggestions.
That's ok, thank you for all your hard work on this!
If you do get a chance to test with AMP, that would be great, since it can provide substantial speed benefits.
Thanks for your advice. I will adapt the code to support AMP when I have time in the future.
Pull requests are always welcomed!
I'll see what I can do to get it working...
Hello, I am using PyTorch Lightning 16-bit automatic mixed precision (AMP) training with OCNN HRNet and am running into the following error:
grad_out = octree_conv.backward_gemm(grad_out, grad, weights)
buffer = torch.mm(grad[start:end], weights.flatten(0, 1).t())
RuntimeError: expected scalar type Half but found Float
However, when I run the same code with OCNN-LeNet, it works just fine. Any idea why this might be happening?
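One possible explanation, offered as an assumption rather than a confirmed diagnosis: under AMP the gradient reaching the backward pass may be float16 while the convolution weights stay float32, and `torch.mm` requires both operands to share a dtype. The sketch below mirrors only the `torch.mm` call from the traceback. `backward_gemm_sketch`, its `buffer_size` chunking, and the `(K, C_in, C_out)` weight layout are hypothetical illustrations, not the actual O-CNN implementation.

```python
import torch


def backward_gemm_sketch(grad: torch.Tensor, weights: torch.Tensor,
                         buffer_size: int = 2) -> torch.Tensor:
    """Hypothetical chunked GEMM modeled on the traceback's call:
    buffer = torch.mm(grad[start:end], weights.flatten(0, 1).t())

    grad:    (N, C_out) gradient w.r.t. the conv output
    weights: (K, C_in, C_out) kernel weights (assumed layout)
    returns: (N, K * C_in) gradient buffer, in grad's dtype
    """
    # Assumption: under AMP, grad can be float16 while weights remain float32.
    # torch.mm rejects mixed dtypes, so cast the weights to grad's dtype first.
    w = weights.flatten(0, 1).t().to(grad.dtype)  # (C_out, K * C_in)
    out = grad.new_empty((grad.shape[0], w.shape[1]))
    # Process rows in chunks, as the original code does with start/end.
    for start in range(0, grad.shape[0], buffer_size):
        end = min(start + buffer_size, grad.shape[0])
        out[start:end] = torch.mm(grad[start:end], w)
    return out
```

Casting the weights down keeps the GEMM in half precision; casting `grad` up to `weights.dtype` instead would trade some speed for accuracy. If `octree_conv` is implemented as a custom `torch.autograd.Function`, PyTorch's documented `torch.cuda.amp.custom_fwd` / `custom_bwd` decorators (`torch.amp.custom_fwd(device_type="cuda")` in recent releases) may be a cleaner way to make it autocast-aware.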