Closed: Tieftaucher closed this issue 4 years ago.
Hi! Thanks for your contribution! Great first issue!
This is also affecting me. I notice that when using torch-geometric, the batch size is also not affected; i.e. if I set batch_size to 16, for example, each call to forward of the LightningModule is given 16 data points on "CPU", and the data points in each forward call on the different GPUs are all the same! Using pl 0.9.0 on Ubuntu 20.04.
@Tieftaucher I fixed it here: #2335
As long as your data type implements .to(device), it will call that directly. You don't have to override transfer_batch_to_device in this case.
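A minimal, framework-free sketch of the dispatch rule described above, assuming the rule is "if the object has a callable `.to()`, call it; otherwise recurse into lists, tuples and dicts". `move_to_device` and `GraphBatch` are hypothetical names for illustration only; this is not Lightning's actual implementation, and `GraphBatch` merely stands in for a third-party type such as torch_geometric's `Batch`.

```python
from typing import Any


class GraphBatch:
    """Hypothetical stand-in for a third-party batch type with its own .to()."""

    def __init__(self, x: list, device: str = "cpu") -> None:
        self.x = x
        self.device = device

    def to(self, device: str) -> "GraphBatch":
        # Return a copy "on" the new device, mirroring the tensor-style .to() contract.
        return GraphBatch(self.x, device)


def move_to_device(batch: Any, device: str) -> Any:
    """Recursively move `batch` to `device` (hypothetical helper, duck-typed)."""
    # Any object exposing a callable .to(): delegate to it directly.
    to_method = getattr(batch, "to", None)
    if callable(to_method):
        return to_method(device)
    # Dicts: move each value, keep the keys.
    if isinstance(batch, dict):
        return {k: move_to_device(v, device) for k, v in batch.items()}
    # Lists and tuples: move each element, keep the container kind.
    if isinstance(batch, list):
        return [move_to_device(v, device) for v in batch]
    if isinstance(batch, tuple):
        return tuple(move_to_device(v, device) for v in batch)
    # Anything else (ints, strings, None, ...) is returned unchanged.
    return batch


batch = {"graph": GraphBatch([1, 2, 3]), "ids": [GraphBatch([4]), 7], "name": "g0"}
moved = move_to_device(batch, "cuda:0")
print(moved["graph"].device)  # cuda:0
```

Checking for `.to()` before walking containers is what makes arbitrary third-party types work without registering them anywhere; the trade-off is that any object with an unrelated `to` attribute would also be called.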
@nghorbani, could you open a separate issue about this? If you provide some code, I can help. But note that, last time I checked, torchgeometry did not support distributed multi-GPU (scatter, gather).
🚀 Feature
My proposal is to support third-party data structures as batches. At the moment, you need to override the transfer_batch_to_device method of your model if your batch is not a collection or one of the other supported data types. My suggestion would be to accept any data type, as long as it has a to(device) method.
Motivation
I want to use pytorch_geometric, but it uses its own dataloader and its own Batch data type, so I had some trouble using it together with PyTorch Lightning. After some struggling, I figured out that I had to override the transfer_batch_to_device method like this:
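A minimal sketch of this kind of override (not necessarily the exact code used here), assuming the Lightning 0.8/0.9 hook signature `transfer_batch_to_device(self, batch, device)`. To keep it runnable without Lightning or torch_geometric installed, `BaseModule` and `GraphBatch` are hypothetical stand-ins for `LightningModule` and torch_geometric's `Batch`; the stand-in default hook is deliberately simplified and only walks lists.

```python
from typing import Any


class BaseModule:
    """Hypothetical, simplified stand-in for LightningModule's default hook."""

    def transfer_batch_to_device(self, batch: Any, device: str) -> Any:
        # The default only knows how to walk lists; unknown objects pass through
        # unchanged, i.e. they silently stay on the CPU.
        if isinstance(batch, list):
            return [self.transfer_batch_to_device(b, device) for b in batch]
        return batch


class GraphBatch:
    """Hypothetical stand-in for torch_geometric's Batch, which has its own .to()."""

    def __init__(self, num_nodes: int, device: str = "cpu") -> None:
        self.num_nodes = num_nodes
        self.device = device

    def to(self, device: str) -> "GraphBatch":
        return GraphBatch(self.num_nodes, device)


class GraphModel(BaseModule):
    def transfer_batch_to_device(self, batch: Any, device: str) -> Any:
        # Handle the custom batch type ourselves by delegating to its .to().
        if isinstance(batch, GraphBatch):
            return batch.to(device)
        # Fall back to the default behaviour for everything else.
        return super().transfer_batch_to_device(batch, device)


model = GraphModel()
print(model.transfer_batch_to_device(GraphBatch(5), "cuda:0").device)  # cuda:0
```

Falling back to `super()` keeps the default handling for everything else, and because the default recurses through `self`, a `GraphBatch` nested inside a list is still routed through the override. In a real project, `GraphModel` would subclass `pytorch_lightning.LightningModule` and check against `torch_geometric.data.Batch` instead.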
At the very least, I think it would be nice to mention this requirement in the docs, or to change the default behaviour of transfer_batch_to_device so that overriding it is no longer necessary.
Pitch
transfer_batch_to_device should accept all data types that implement a to(device) method.
Alternatives
Alternatively, the documentation should explain how to use non-default dataloaders and batch types.
Additional context
I saw https://github.com/PyTorchLightning/pytorch-lightning/pull/1756 but couldn't figure out whether it solves my problem and just hasn't been merged yet. If it does, sorry for the extra work.
Thank you for the nice library and all your work =)