octree-nn / ocnn-pytorch

Octree-based Sparse Convolutional Neural Networks
MIT License

HRNet does not work with AMP but LeNet Does #23

Closed: harryseely closed this issue 1 year ago

harryseely commented 1 year ago

Hello, I am using PyTorch Lightning 16-bit automatic mixed precision (AMP) training with OCNN HRNet, and I am running into the following error:

    grad_out = octree_conv.backward_gemm(grad_out, grad, weights)
    buffer = torch.mm(grad[start:end], weights.flatten(0, 1).t())
    RuntimeError: expected scalar type Half but found Float

However, when I run the same code with OCNN-LeNet, it works just fine. Any idea why this might be happening?
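The traceback shows `torch.mm` receiving a float16 gradient and float32 weights inside a custom backward pass, and `torch.mm` requires both operands to share a dtype. One possible workaround is to cast both operands to a common dtype before each matrix multiply. The sketch below is not ocnn's actual code: `mm_matching_dtypes` and `ToyLinearFunction` are hypothetical names, and the toy op only mirrors the gemm pattern in the traceback. PyTorch also documents the `custom_fwd`/`custom_bwd` decorators for custom autograd functions used under autocast, which may be a cleaner fix.

```python
import torch


def mm_matching_dtypes(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: matrix-multiply after casting to a common dtype.

    torch.mm raises a RuntimeError when its operands differ in dtype
    (e.g. a float16 gradient under AMP and float32 weights), so both
    operands are promoted first (float16 + float32 -> float32).
    """
    common = torch.promote_types(a.dtype, b.dtype)
    return torch.mm(a.to(common), b.to(common))


class ToyLinearFunction(torch.autograd.Function):
    """Hypothetical custom op mirroring the gemm pattern in the traceback.

    Computes y = x @ weights^T with x: (N, C_in), weights: (C_out, C_in).
    """

    @staticmethod
    def forward(ctx, x, weights):
        ctx.save_for_backward(x, weights)
        return mm_matching_dtypes(x, weights.t())

    @staticmethod
    def backward(ctx, grad_out):
        x, weights = ctx.saved_tensors
        # dL/dx = grad_out @ W, cast back to the input's dtype.
        grad_x = mm_matching_dtypes(grad_out, weights).to(x.dtype)
        # dL/dW = grad_out^T @ x, cast back to the weights' dtype.
        grad_w = mm_matching_dtypes(grad_out.t(), x).to(weights.dtype)
        return grad_x, grad_w
```

For example, `ToyLinearFunction.apply(x, w)` with a float16 `x` and float32 `w` runs forward and backward without the dtype error, and each gradient keeps the dtype of its input. Promoting to float32 trades some of AMP's speed for correctness in that one op.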

wang-ps commented 1 year ago

Sorry, I have not tested the code with AMP, so I cannot currently provide more in-depth suggestions.

harryseely commented 1 year ago

That's ok, thank you for all your hard work on this!

If you do get a chance to test with AMP, that would be great, since it can provide substantial speed benefits.

wang-ps commented 1 year ago

Thanks for your advice. I will update the code to support AMP when I have time in the future.

Pull requests are always welcome!

harryseely commented 1 year ago

I'll see what I can do to get it working...