This PR adds a small utility for converting a model to a different precision. Interestingly, when the model is cast to bfloat16 or float16 and the input data is also float16 or bfloat16, the compiler raises an error; with float32 input data, the same code runs fine. I suspect this is because some type promotion happens inside the code: with float32 inputs, the low-precision parameters are promoted back to float32 before the computation runs.
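For reviewers, here is a minimal sketch of what this kind of precision conversion and the suspected promotion look like in JAX. The helper name `cast_params` and the toy `{"w", "b", "step"}` parameter dict are hypothetical, not the code in this PR. The sketch only shows JAX's standard promotion rule (bfloat16 combined with float32 yields float32), which is consistent with the hypothesis above; it does not pinpoint which op fails to compile.

```python
import jax
import jax.numpy as jnp

def cast_params(params, dtype):
    """Cast every floating-point leaf of a parameter pytree to `dtype`.

    Hypothetical helper for illustration; non-float leaves (e.g. an
    integer step counter) are returned unchanged.
    """
    def _cast(x):
        x = jnp.asarray(x)
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(dtype)
        return x
    return jax.tree_util.tree_map(_cast, params)

# Toy parameters (hypothetical): a dense layer plus an integer counter.
params = {
    "w": jnp.ones((3, 2), jnp.float32),
    "b": jnp.zeros((2,), jnp.float32),
    "step": jnp.array(0),
}
params_bf16 = cast_params(params, jnp.bfloat16)

x32 = jnp.ones((4, 3), jnp.float32)
x16 = jnp.ones((4, 3), jnp.bfloat16)

# float32 input: the bfloat16 params are promoted, so the result is float32.
y32 = x32 @ params_bf16["w"] + params_bf16["b"]
# bfloat16 input: no promotion, so the whole computation stays in bfloat16.
y16 = x16 @ params_bf16["w"] + params_bf16["b"]
print(y32.dtype, y16.dtype)
```

In other words, with float32 data the model effectively runs in float32 again, which could explain why that path sidesteps whatever low-precision op the compiler rejects.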
This PR also bumps the jax version.