Closed Mistsink closed 4 months ago
Hi, we don't have plans to implement the kernels in BF16 - but PRs are welcome if you'd like to contribute.
As a workaround, perhaps you could try explicitly casting the inputs to float32 before calling ULTRA, then casting the output back from float32 to bfloat16.
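The cast-and-cast-back workaround could look roughly like the sketch below. `call_in_float32` and the use of `torch.relu` as a stand-in for the ULTRA forward pass are illustrative assumptions, not part of the ULTRA codebase:

```python
import torch

def call_in_float32(fn, *tensors):
    # Hypothetical wrapper: upcast any bfloat16 inputs to float32 before
    # the kernel call, then cast the result back to bfloat16.
    upcast = [t.float() if t.dtype == torch.bfloat16 else t for t in tensors]
    out = fn(*upcast)
    return out.to(torch.bfloat16)

# `torch.relu` stands in for the ULTRA forward pass here.
x = torch.randn(4, 8, dtype=torch.bfloat16)
y = call_in_float32(torch.relu, x)
```

The autograd graph still flows through the casts, so this should also work during fine-tuning.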
Thank you very much. I did do that, but I'm concerned about possibly encountering NaN or Inf when converting from float32 back to bf16.
I wouldn't worry about Inf or NaN: no part of the code is that sensitive to 16-bit vs 32-bit precision. The GNNs inside mostly use additions and multiplications with LayerNorm.
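One supporting fact worth noting: bfloat16 keeps float32's 8-bit exponent, so a float32 to bfloat16 cast only loses mantissa precision and cannot overflow to Inf on its own. A quick sanity check (the specific values are just illustrative):

```python
import torch

# Values near the top and bottom of the float32 range: bfloat16 shares
# float32's exponent range, so the cast stays finite and only rounds
# away mantissa bits.
x = torch.tensor([3.0e38, -3.0e38, 1e-30], dtype=torch.float32)
y = x.to(torch.bfloat16)
```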
Thank you very much for the clarification. I have learned a lot from it and will now close this issue.
Hello, thank you very much for your work on foundation models. In my follow-up experiments I incorporated information from an LLM. During fine-tuning I used PEFT-related techniques and set the dtype to bf16, but I encountered the error described in the title when executing RSPMM. Would it be possible to fix this?