ConnorJL / GPT2

An implementation of GPT-2 training that supports TPUs
MIT License

Unable to predict with bfloat16 model #1

Closed: kizinfo closed this issue 5 years ago

kizinfo commented 5 years ago

I can train a bfloat16 model, but prediction on either GPU or CPU fails with a missing kernel for the 'Rsqrt' op in bfloat16. Have you been able to run prediction with bfloat16 models?
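For reference, the failure doesn't seem specific to the model code; a minimal sketch like the following (assuming TensorFlow 1.x, as this repo targets) should hit the same missing kernel, since layer normalization calls rsqrt on the variance:

```python
# Minimal repro sketch (assumes TensorFlow 1.x).
# Layer norm computes rsqrt(variance + epsilon); with bfloat16 tensors this
# should trigger the same missing 'Rsqrt' kernel on CPU/GPU.
import tensorflow as tf

x = tf.constant([1.0, 4.0, 9.0], dtype=tf.bfloat16)
y = tf.math.rsqrt(x)

with tf.Session() as sess:
    # Expected on CPU/GPU: "No OpKernel was registered to support Op 'Rsqrt'
    # ... [T=DT_BFLOAT16]"
    print(sess.run(y))
```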

Also, would it be possible to do gradient accumulation (averaging gradients over several batches) to simulate a larger batch size on the TPU without requiring more memory?
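For context, this is the kind of accumulation I mean; just a rough sketch against TensorFlow 1.x, none of these names come from this repo, and whether it fits into the TPU training loop is the open question:

```python
# Hedged sketch of gradient accumulation: sum gradients over `accum_steps`
# micro-batches, then apply the averaged update once, which simulates a batch
# `accum_steps` times larger without holding more activations in memory.
import tensorflow as tf

def accumulated_train_ops(loss, optimizer, accum_steps):
    tvars = tf.trainable_variables()
    grads = tf.gradients(loss, tvars)  # assumes every variable gets a gradient

    # Non-trainable buffers holding the running gradient sums.
    accum = [tf.Variable(tf.zeros_like(v), trainable=False) for v in tvars]

    zero_op = tf.group(*[a.assign(tf.zeros_like(a)) for a in accum])
    accum_op = tf.group(*[a.assign_add(g) for a, g in zip(accum, grads)])
    apply_op = optimizer.apply_gradients(
        [(a / accum_steps, v) for a, v in zip(accum, tvars)])

    return zero_op, accum_op, apply_op

# Usage: run zero_op, then accum_op on accum_steps micro-batches, then apply_op.
```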

ConnorJL commented 5 years ago

Yes, I've run into that problem with bfloat16 as well. I haven't yet found a satisfying solution, but I think setting the model to float32 during prediction might work. I'll look into it more when I have time.
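For anyone who hits this, one possible workaround is sketched below. It is not part of this repo and assumes the checkpoint variables were actually saved in bfloat16: rewrite the checkpoint to float32 once, then build the prediction graph in float32.

```python
# Hedged sketch (hypothetical helper, not from this repo): copy a checkpoint,
# casting any bfloat16 variables to float32, so prediction can run on CPU/GPU
# without the missing bfloat16 'Rsqrt' kernel.
import tensorflow as tf

def cast_checkpoint_to_float32(ckpt_path, out_path):
    reader = tf.train.load_checkpoint(ckpt_path)
    shape_map = reader.get_variable_to_shape_map()
    with tf.Graph().as_default():
        new_vars = []
        for name in shape_map:
            value = reader.get_tensor(name)
            if value.dtype == tf.bfloat16.as_numpy_dtype:
                value = value.astype('float32')
            new_vars.append(tf.Variable(value, name=name))
        saver = tf.train.Saver(new_vars)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver.save(sess, out_path)
```

If the variables are already stored in float32 and only the activations are cast to bfloat16, then rebuilding the prediction graph with float32 activations should be enough and no checkpoint rewrite is needed.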

I'm not sure how tricky that would be to do on TPUs, and I currently have no intention of implementing it, I'm afraid.

kizinfo commented 5 years ago

Thanks, this is some fantastic work you've done here.