phu0ngng closed this 2 weeks ago
Hi, I added another small change that fixes this failure:
FAILED tests/jax/test_single_gpu_mnist.py::TestMNIST::test_te_bf16 - ValueError: The repository for mnist contains custom code which must be executed to correctly ...
This failure is unrelated to the original issue, but without this fix CI won't come back clean, so I pushed it to the same PR.
/te-ci jax
Description
In JAX, we call `dbias_cast_transpose` and `dact_dbias_cast_transpose` with empty tensors to query the workspace size before calling the actual functions, so both functions must accept empty tensors. This fix is needed for that.
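The workspace-query pattern described above can be sketched in Python as follows. This is a minimal illustration under stated assumptions: `dbias_cast_transpose_sketch`, `BLOCK_ROWS`, and the per-block workspace layout are hypothetical and do not reflect Transformer Engine's actual API or kernels. The only point shown is that, when called with an empty workspace (and possibly empty inputs), the function reports the workspace shape and returns without touching its input data.

```python
import numpy as np

# Hypothetical: number of rows reduced per block when accumulating dbias.
BLOCK_ROWS = 4


def dbias_cast_transpose_sketch(inp_shape, inp, workspace):
    """Hypothetical stand-in for a fused dbias + cast + transpose op.

    Query mode: if `workspace` is empty, return only the required
    workspace shape and do not read `inp` (it may be empty too).
    Execute mode: fill the workspace with per-block partial column sums,
    reduce them into dbias, and return the fp16-cast transpose.
    """
    rows, cols = inp_shape
    num_blocks = -(-rows // BLOCK_ROWS)  # ceiling division
    ws_shape = (num_blocks, cols)

    if workspace.size == 0:
        # Query mode: inputs are placeholders, so never dereference them.
        return ws_shape, None, None

    # Execute mode: partial column sums for each block of rows.
    for b in range(num_blocks):
        workspace[b] = inp[b * BLOCK_ROWS:(b + 1) * BLOCK_ROWS].sum(axis=0)
    dbias = workspace.sum(axis=0)
    out_t = inp.astype(np.float16).T
    return ws_shape, out_t, dbias


# Usage: query with empty arrays, allocate, then run for real.
x = np.arange(12, dtype=np.float32).reshape(6, 2)
empty = np.empty((0,), dtype=np.float32)
ws_shape, _, _ = dbias_cast_transpose_sketch(x.shape, empty, empty)
ws = np.zeros(ws_shape, dtype=np.float32)
_, out_t, dbias = dbias_cast_transpose_sketch(x.shape, x, ws)
```

The input shape is passed separately in this sketch because an empty NumPy array cannot carry the real shape; the actual JAX bindings may convey shape differently.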