Open · xmax1 opened this issue 1 year ago
Hi @xmax1
JAX doesn't recognise TF types as jnp arrays. This line in get_dataset prefetches your batches to the devices and shards them for you, so you shouldn't need to change input_dtype as long as prefetch_to_device runs correctly (I just ran the notebook with a GPU on Colab and it seems fine):
it = jax_utils.prefetch_to_device(it, 2)
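For reference, here is a minimal sketch of what that prefetch step does. It assumes batches arrive as dicts of NumPy arrays (eager tf.Tensor values also work, since np.asarray can convert them). `shard_batch` and `prefetch_to_device_sketch` are hypothetical names. The second is a simplified stand-in modelled on flax.jax_utils.prefetch_to_device, which places each batch on the local devices with jax.device_put_sharded. After that step every leaf is a JAX array, so no manual dtype/type conversion should be needed.

```python
import collections
import itertools

import jax
import jax.numpy as jnp
import numpy as np


def shard_batch(batch):
    """Reshape each leaf from (global_batch, ...) to (n_devices, per_device, ...)."""
    n = jax.local_device_count()
    # np.asarray converts host data (NumPy arrays, eager tf.Tensors) to NumPy first.
    return jax.tree_util.tree_map(
        lambda x: np.asarray(x).reshape((n, -1) + np.shape(x)[1:]), batch
    )


def prefetch_to_device_sketch(iterator, size):
    """Simplified stand-in for flax.jax_utils.prefetch_to_device (an assumption,
    not the library code): keeps up to `size` batches queued on the local
    devices, splitting each leaf's leading axis across them."""
    iterator = iter(iterator)  # islice must consume one shared iterator
    devices = jax.local_devices()
    queue = collections.deque()

    def _put(batch):
        # device_put_sharded takes one slice per device and returns a JAX array.
        return jax.tree_util.tree_map(
            lambda x: jax.device_put_sharded(list(x), devices), batch
        )

    def _enqueue(n):
        for batch in itertools.islice(iterator, n):
            queue.append(_put(batch))

    _enqueue(size)  # fill the buffer
    while queue:
        yield queue.popleft()
        _enqueue(1)  # top the buffer back up


# Hypothetical stand-in for the TF pipeline: three NumPy batches of size 8.
batches = (
    {"image": np.ones((8, 4), np.float32), "label": np.arange(8)} for _ in range(3)
)
it = prefetch_to_device_sketch(map(shard_batch, batches), size=2)
```

The global batch size has to be divisible by jax.local_device_count(), otherwise the reshape in `shard_batch` fails.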
Hi, thanks for the great work!
I get an assertion error when checking the dataset. This is confusing because, as far as I understand, it should fail for everyone.
Possibly a version issue (maybe some version of JAX recognises TF types as jnp arrays?).
get_dataset is shown below, with the fixing lines commented out.
For anyone reading: I'm using version 0.3.21 with CUDA (not TPU).
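A quick way to see which side of such a type check a batch is on (a hypothetical snippet, not the notebook's actual assertion). In recent JAX releases, where jnp.ndarray is an alias of jax.Array, a host-side NumPy array is not a jnp.ndarray until it has been placed on a device.

```python
import jax
import jax.numpy as jnp
import numpy as np

host_batch = np.zeros((8, 4), np.float32)  # stands in for a TF/NumPy batch on the host
device_batch = jax.device_put(host_batch)  # now a JAX array on the default device

print(isinstance(host_batch, jnp.ndarray))    # False in recent JAX
print(isinstance(device_batch, jnp.ndarray))  # True
```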