Open BenjaminDug opened 5 months ago
Most likely this has to do with GPU memory prefetching. Both tf.data and torch DataLoader can do prefetching, but when using a different backend they have to convert to the right tensor type in CPU memory, which cancels the benefits of prefetching. I believe this could be optimized further. @haifeng-jin to advise.
I tried removing the prefetch from the tf dataset; the epoch time went from 5s to 6s with the torch backend.
I observe that my CPUs are fully loaded with the torch backend and the tf dataset pipeline:
Below is the picture of the tensorflow backend with the same tf dataset pipeline:
I will look into it.
BTW, model.fit() should work with a torch DataLoader directly. @BenjaminDug
Thank you !
Yes, I know that with Keras 3.0 we can use a DataLoader in .fit(), but in my use case I have TFRecords and I need to use a torch model. With TensorFlow, the best pipeline is tf.dataset, and my hope with Keras 3.0 is to use tf.dataset with a torch model, so I can load my TFRecords efficiently under the torch backend.
I hope that the data stays at a low level and that no Python instructions are needed to convert the data coming out of tf.dataset. Some time ago, I wrote a DataLoader that loaded TFRecords and converted the tf.tensor data to torch.tensor. It was really slow because of this conversion in Python, so I had to give up TFRecords at the time.
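For context, here is a minimal, self-contained sketch of the kind of TFRecord + tf.data pipeline described above (assuming TensorFlow is installed; the feature name "x" and the toy records are made up for illustration):

```python
import os
import tempfile

import tensorflow as tf

path = os.path.join(tempfile.mkdtemp(), "toy.tfrecord")

# Write three records, each holding a single int64 feature "x".
with tf.io.TFRecordWriter(path) as w:
    for i in range(3):
        ex = tf.train.Example(features=tf.train.Features(
            feature={"x": tf.train.Feature(
                int64_list=tf.train.Int64List(value=[i]))}))
        w.write(ex.SerializeToString())

# Read them back with a tf.data pipeline, decoding in parallel and prefetching.
spec = {"x": tf.io.FixedLenFeature([], tf.int64)}
ds = (tf.data.TFRecordDataset(path)
      .map(lambda r: tf.io.parse_single_example(r, spec),
           num_parallel_calls=tf.data.AUTOTUNE)
      .prefetch(tf.data.AUTOTUNE))

values = [int(b["x"]) for b in ds]
print(values)  # [0, 1, 2]
```

The question in this issue is essentially how to feed such a pipeline to a torch model without a slow Python-level tensor conversion at the boundary.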
@BenjaminDug my team had the same setup, and what we found is that you need to make sure the numpy->torch copy overlaps with the compute. tf does it natively with a tf dataset, but if you go through tf.dataset.as_numpy_iterator() for torch, you need to handle the to-GPU buffering yourself.
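On the host side, the numpy->torch step itself can be made essentially free: torch.from_numpy wraps the existing NumPy buffer instead of copying element by element. A minimal sketch (the batch shape is hypothetical, standing in for one batch from as_numpy_iterator()):

```python
import numpy as np
import torch

# Stand-in for one batch coming out of tf.data's as_numpy_iterator().
batch = np.random.rand(32, 28, 28, 1).astype(np.float32)

# torch.from_numpy wraps the existing buffer: no per-element Python
# conversion, and no host-side copy at all.
t = torch.from_numpy(batch)

# The two objects share memory, which shows the conversion made no copy.
batch[0, 0, 0, 0] = 7.0
print(t.shape, t[0, 0, 0, 0].item())
```

What remains expensive is the host-to-GPU transfer, which is why the overlap with compute described above matters.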
a code snippet would look like below:
from concurrent.futures import ThreadPoolExecutor

import torch

def pin_mem_fn(b: dict[str, torch.Tensor]):
    # Pinned host memory lets a later .to("cuda", non_blocking=True)
    # overlap the transfer with compute.
    return {k: v.pin_memory() for k, v in b.items()}

executor = ThreadPoolExecutor(max_workers=1)
batch_iter = iter(dataset)  # e.g. as_numpy_iterator() mapped to torch tensors
future_batch = executor.submit(pin_mem_fn, next(batch_iter))
for step in range(num_steps):
    batch = future_batch.result()
    future_batch = executor.submit(pin_mem_fn, next(batch_iter))
    # <do work with batch>
this is similar to what DataLoader does internally, but it needs to be done manually if you are not using a torch DataLoader. If you are okay with using a torch DataLoader, you can also wrap the tf dataset in a torch IterableDataset and then use a DataLoader.
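A sketch of that wrapping approach (the wrapper class and the fake iterator are made up for illustration; in the real case make_numpy_iter would be something like lambda: ds.as_numpy_iterator()):

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, IterableDataset


class TfDatasetWrapper(IterableDataset):
    """Expose an iterable of numpy-dict batches (e.g. a tf.data.Dataset via
    as_numpy_iterator) as a torch IterableDataset, so a DataLoader can add
    pin_memory and worker prefetching on top."""

    def __init__(self, make_numpy_iter):
        self.make_numpy_iter = make_numpy_iter

    def __iter__(self):
        for batch in self.make_numpy_iter():
            # from_numpy shares memory, so this conversion is cheap.
            yield {k: torch.from_numpy(v) for k, v in batch.items()}


# Stand-in for ds.as_numpy_iterator(): three dict batches of numpy arrays.
def fake_numpy_iter():
    for i in range(3):
        yield {"x": np.full((4, 2), i, dtype=np.float32)}


# batch_size=None: the tf pipeline already batched, so don't re-batch.
loader = DataLoader(TfDatasetWrapper(fake_numpy_iter), batch_size=None)
shapes = [b["x"].shape for b in loader]
print(shapes)
```

Note batch_size=None disables the DataLoader's automatic batching, which is what you want when the tf pipeline already produced batches.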
I believe @hertschuh has more insights on this issue. It was either resolved already or very hard to resolve.
Assigning to @hertschuh temporarily. Feel free to assign it back.
@BenjaminDug ,
I did some rework of the "DataAdapters" after this bug was created, however, I don't believe the performance of your specific use case has changed.
After doing some benchmarking, I came to the conclusion that there is no easy way to improve the performance of feeding a tf.data.Dataset to a Torch model. The slowdown comes mostly from the copying of tensors.
Converting a tf.Tensor to a torch tensor requires a copy. Otherwise, Torch detects that you're creating a mutable tensor from memory that Torch doesn't own and prints a warning that "mutating leads to undefined behavior". I don't believe this is an actual issue in the training scenario, but I didn't want to enable a behavior that creates a warning.
The bottom line is that you get better performance with a Torch DataLoader.
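That copy-vs-warning tradeoff can be reproduced in isolation. A sketch, assuming torch and numpy are installed, using a read-only NumPy array to stand in for memory owned by another framework:

```python
import warnings

import numpy as np
import torch

# Memory that "Torch doesn't own": a read-only NumPy array standing in
# for the buffer backing a tf.Tensor.
arr = np.arange(6, dtype=np.float32)
arr.flags.writeable = False

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    t = torch.from_numpy(arr)  # zero-copy, but over non-writable memory

# PyTorch typically emits a UserWarning here about non-writable tensors;
# copying first (torch.from_numpy(arr.copy())) avoids the warning at the
# cost of exactly the copy this issue is about.
print(t.shape, len(caught))
```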
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
Hello,
First, I want to thank you for the release of this amazing framework. I am a TensorFlow user, but sometimes I have no choice but to use pytorch, and I don't like the DataLoader (I prefer tf.dataset, which is, in my opinion, the fastest when there is a lot of I/O).
I decided to try to implement a very simple case on MNIST where I use different backends with the same data pipeline, and also a torch model with a DataLoader:
1 - jax backend with tf.dataset - model is full keras layers - virtualenv from your requirement-jax-gpu.txt
2 - tensorflow backend with tf.dataset - model is full keras layers - virtualenv from your requirement-tensorflow-gpu.txt
3 - torch backend with tf.dataset - model is nn module in the init part of a keras model - virtualenv from your requirement-torch-gpu.txt
4 - torch backend with a dataloader and a torch training loop (no keras here) - virtualenv from your requirement-torch-gpu.txt
hardware: gpu rtx 2070 - cpu i5 9400f
install: ubuntu 22.04 - cuda 12.2 - cudnn 8.9.5.30 - driver 535.129.03
The code for cases 1, 2 and 3 is just below:
The code for case 4 is just below:
I have the following stdout for 1 (jax backend with tf.dataset - model is full keras layers - virtualenv from your requirement-jax-gpu.txt):
I have the following stdout for 2 (tensorflow backend with tf.dataset - model is full keras layers - virtualenv from your requirement-tensorflow-gpu.txt):
I have the following stdout for 3 (torch backend with tf.dataset - model is nn module in the init part of a keras model - virtualenv from your requirement-torch-gpu.txt):
Of course, I changed the os.environ['KERAS_BACKEND'] for each backend.
We can see that Jax is pretty good (1s per epoch), tensorflow is good too (2s per epoch), but torch with tf.dataset is much slower (5s per epoch).
I have the following stdout for 4 (torch backend with a dataloader and a torch training loop (no keras here) - virtualenv from your requirement-torch-gpu.txt):
We can see that the torch backend with a DataLoader and a torch training loop runs at around 2s per epoch, like tensorflow with the tf.dataset pipeline. The latter has the same speed as the DataLoader because everything is already in CPU RAM. I know that tf.dataset is really better when we have TFRecords.
I know that the torch model is a bit different (padding is SAME in tensorflow whereas it is VALID in torch) and the channel order is not the same between tensorflow and torch. But the comparison is between the same torch model with two different pipelines.
Why is the torch backend with a torch model in keras with tf.dataset slower than the same model with a DataLoader in a pytorch training loop?
Is there some kind of conversion between the tf.dataset outputs (tf.tensor, maybe?) and the model inputs, which are torch tensors?
Maybe I have done something wrong in the code? Maybe this case is too simple to make a good comparison?
I really want to use tf.dataset with the torch backend, as tensorflow can.
Thank you for your help