Closed: kizinfo closed this issue 5 years ago.
Yes, I've run into that problem with bfloat16 as well. I haven't found a satisfying solution yet, but I think casting the model to float32 during prediction might work. I'll look into it more when I have time.
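A minimal sketch of that workaround: cast the low-precision weights and inputs up to float32 before inference, so ops like `Rsqrt` run with a kernel that exists on CPU/GPU. NumPy has no native bfloat16 type, so float16 stands in for it here; the layer-norm function is just an invented example of an op that uses `1/sqrt` internally.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Rsqrt (1/sqrt) is the op reported as missing a bfloat16 kernel;
    # running it in float32 sidesteps the missing-kernel error.
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

# Weights "trained" in a low-precision dtype (float16 as a stand-in
# for bfloat16, which NumPy does not support).
rng = np.random.default_rng(0)
w = rng.normal(size=(4, 8)).astype(np.float16)
x = rng.normal(size=(2, 4)).astype(np.float16)

# Cast everything up to float32 before prediction.
y = layer_norm(x.astype(np.float32) @ w.astype(np.float32))
print(y.dtype)  # float32
```

In TensorFlow terms this corresponds to building the prediction graph with float32 variables and casting the restored bfloat16 checkpoint values on load, rather than running the bfloat16 graph directly.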
I'm not sure how tricky that would be to do on TPUs, and I currently have no intention of implementing it, I'm afraid.
Thanks. You've done some fantastic work here.
I can train a bfloat16 model, but prediction on either GPU or CPU fails with a missing bfloat16 kernel for the 'Rsqrt' op. Have you been able to run prediction with bfloat16 models?
Also, would it be possible to do gradient accumulation (averaging gradients over several smaller batches before applying an update) to simulate a larger batch size on TPU without requiring more memory?
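The accumulation idea above can be sketched framework-agnostically. This is a minimal NumPy illustration, not TPU code: the linear model, loss, micro-batch size, and learning rate are all invented for the example. The key point is that gradients from several micro-batches are summed, then averaged, and only then used for a single optimizer step, so memory never exceeds one micro-batch's worth of activations.

```python
import numpy as np

def loss_grad(w, x, y):
    # Gradient of mean squared error for a linear model (illustration only).
    pred = x @ w
    return 2 * x.T @ (pred - y) / len(x)

rng = np.random.default_rng(0)
w = np.zeros(3)
x_all = rng.normal(size=(32, 3))
y_all = x_all @ np.array([1.0, -2.0, 0.5])

accum_steps = 4   # 4 micro-batches of 8 simulate an effective batch of 32
micro = 8
lr = 0.1
grad_sum = np.zeros_like(w)

for i in range(accum_steps):
    xb = x_all[i * micro:(i + 1) * micro]
    yb = y_all[i * micro:(i + 1) * micro]
    grad_sum += loss_grad(w, xb, yb)   # accumulate, do not update yet

# Average the accumulated gradients, then apply one update, as if the
# full batch of 32 examples had been processed in a single step.
w -= lr * (grad_sum / accum_steps)
print(w.shape)  # (3,)
```

In a real TensorFlow/TPU setup the same pattern would be implemented by accumulating gradients into non-trainable variables across steps and calling the optimizer's apply step every `accum_steps` iterations.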