kuprel / min-dalle

min(DALL·E) is a fast, minimal port of DALL·E Mini to PyTorch
MIT License

Bfloat16 - precision loss #95

Closed: ArulselvanMadhavan closed this issue 1 year ago.

ArulselvanMadhavan commented 1 year ago

Hi, thanks for this library. I found it very helpful for understanding the DALL-E architecture.

https://github.com/kuprel/min-dalle/blob/main/min_dalle/models/dalle_bart_decoder.py#L172

I have been trying to reproduce the results in bfloat16. Everything works fine except for this matmul in the decoder. The differences between the fp32 and bfloat16 results are large enough to change the classifier-free guidance output, and that in turn affects the generated images. Do you have any suggestions for minimizing the discrepancy between fp32 and bfloat16?
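A minimal sketch of why a small bfloat16 error in one matmul can matter here. If guidance is written as (1 - s) * u + s * c, where u and c are the unconditional and conditional logits, then any matmul error in u and c is scaled by (1 - s) and s. The guided logits can therefore be off by up to |1 - s| + |s| times the raw logit error. This is a pure-Python illustration under stated assumptions. bfloat16 is emulated by rounding float32 bit patterns. The sizes, the guidance scale of 10, and the names `to_bf16`, `matvec`, `guide` and `compare` are hypothetical. The sketch does not reproduce the decoder code at the linked line. The third path (bf16 inputs, fp32 output) is just one experiment to try, not a fix confirmed for min-dalle.

```python
import random
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 value, ties to even.

    bfloat16 keeps the top 16 bits of a float32 (1 sign, 8 exponent and
    7 mantissa bits). NaN/inf inputs are not handled in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-half-to-even on the dropped 16 bits
    return struct.unpack("<f", struct.pack("<I", (bits >> 16) << 16))[0]


def matvec(weight, x, in_round, out_round):
    """Return W @ x with inputs rounded by in_round and outputs by out_round.

    Products are summed in Python floats (float64), loosely mimicking GPU
    kernels that accumulate low-precision products in a wider type.
    """
    xr = [in_round(v) for v in x]
    return [out_round(sum(in_round(w) * v for w, v in zip(row, xr))) for row in weight]


def guide(uncond, cond, scale):
    """Classifier-free guidance written as (1 - s) * u + s * c.

    Kept in full precision so that only the matmul's rounding is measured.
    """
    return [u * (1 - scale) + c * scale for u, c in zip(uncond, cond)]


def max_abs_diff(a, b):
    return max(abs(x - y) for x, y in zip(a, b))


def compare(vocab=256, dim=64, scale=10.0, seed=0):
    """Return {path: (raw logit error, guided logit error)} relative to fp32.

    vocab, dim and scale are arbitrary toy values, not the model's real
    sizes or its actual guidance factor.
    """
    rng = random.Random(seed)
    weight = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(vocab)]
    h_uncond = [rng.gauss(0, 1) for _ in range(dim)]
    h_cond = [rng.gauss(0, 1) for _ in range(dim)]

    paths = {
        "fp32": (to_fp32, to_fp32),
        "bf16": (to_bf16, to_bf16),
        # Hypothetical mixed variant: inputs arrive as bf16 from the rest of
        # the model, but this one matmul writes fp32 logits.
        "bf16 in, fp32 out": (to_bf16, to_fp32),
    }
    logits = {}
    for name, (in_round, out_round) in paths.items():
        u = matvec(weight, h_uncond, in_round, out_round)
        c = matvec(weight, h_cond, in_round, out_round)
        logits[name] = (u, c, guide(u, c, scale))

    u_ref, c_ref, g_ref = logits["fp32"]
    errors = {}
    for name, (u, c, g) in logits.items():
        raw = max(max_abs_diff(u, u_ref), max_abs_diff(c, c_ref))
        errors[name] = (raw, max_abs_diff(g, g_ref))
    return errors


if __name__ == "__main__":
    for name, (raw, guided) in compare().items():
        print(f"{name:>18}: raw logit err {raw:.4g}, guided logit err {guided:.4g}")
```

Running it prints each path's maximum absolute error against the fp32 path. The exact numbers depend on the seed and the toy sizes. In PyTorch, the analogous experiment would be to upcast the operands of just that matmul with `.float()` and to do the guidance combination in fp32 before casting back. Whether that closes the gap for the real model is an open question.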

Thank you