Switch the dataset format to TFRecords. The current version just stores `.pt` files in a directory. TFRecords should make streaming possible, which may be important for real training runs. This is slightly awkward because the preprocessing is in PyTorch and the training is in JAX. I'm told the best way to read TFRecords during training is through `tf.data`.
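Here is a minimal sketch of the PyTorch-to-TFRecord-to-JAX handoff. It assumes each preprocessed example is a fixed-size float vector plus an integer label. The feature keys `x`/`y`, `FEATURE_DIM`, and the helpers `write_shard`/`make_dataset` are hypothetical placeholders, not an existing API in this repo. On the PyTorch side, tensors would be converted with `tensor.numpy()` before serializing. On the JAX side, the training loop consumes the NumPy batches that `tf.data` yields.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

FEATURE_DIM = 16  # hypothetical: length of each preprocessed feature vector


def _bytes_feature(value: bytes) -> tf.train.Feature:
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_feature(value: int) -> tf.train.Feature:
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def write_shard(path, examples):
    """Serialize (features, label) pairs into one TFRecord shard.

    In the real pipeline each `x` would come from the PyTorch preprocessing
    (e.g. `tensor.numpy()`) instead of being saved as a .pt file.
    """
    with tf.io.TFRecordWriter(path) as writer:
        for x, y in examples:
            example = tf.train.Example(features=tf.train.Features(feature={
                # Store the vector as raw float32 bytes; decoded with decode_raw below.
                "x": _bytes_feature(np.asarray(x, dtype=np.float32).tobytes()),
                "y": _int64_feature(int(y)),
            }))
            writer.write(example.SerializeToString())


def _parse(record):
    """Turn one serialized tf.train.Example back into (x, y) tensors."""
    parsed = tf.io.parse_single_example(record, {
        "x": tf.io.FixedLenFeature([], tf.string),
        "y": tf.io.FixedLenFeature([], tf.int64),
    })
    x = tf.reshape(tf.io.decode_raw(parsed["x"], tf.float32), [FEATURE_DIM])
    return x, parsed["y"]


def make_dataset(paths, batch_size):
    """Stream shards with tf.data and yield NumPy batches for a JAX train step.

    Records are read lazily from disk, so the full dataset never has to fit
    in memory. Add .shuffle(...) before .batch(...) for real training.
    """
    ds = tf.data.TFRecordDataset(paths)
    ds = ds.map(_parse, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds.as_numpy_iterator()


if __name__ == "__main__":
    tmp = tempfile.mkdtemp()
    rng = np.random.default_rng(0)
    paths = []
    for shard in range(2):
        path = os.path.join(tmp, f"train-{shard:05d}.tfrecord")
        write_shard(path, [(rng.standard_normal(FEATURE_DIM), i) for i in range(8)])
        paths.append(path)
    for xb, yb in make_dataset(paths, batch_size=4):
        # xb and yb are NumPy arrays; pass them to a jitted JAX step here.
        print(xb.shape, yb)
```

Splitting the data into several shard files is what makes streaming practical. `tf.data` reads them sequentially (or in parallel via `interleave`) without loading everything up front. Raw bytes with `decode_raw` keep the format simple for fixed-shape arrays. Variable-shape tensors would need their shape stored as an extra feature, or `tf.io.serialize_tensor` instead.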