Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

support of other datatypes for Batches #2465

Closed Tieftaucher closed 4 years ago

Tieftaucher commented 4 years ago

🚀 Feature

My proposal is to support third-party data structures as batches. At the moment you need to override the transfer_batch_to_device method of your model if your batch is not a collection or one of the other supported data types. My suggestion is to accept any data type, as long as it has a to(device) method.
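A minimal sketch of the kind of fallback I have in mind (move_to_device is just an illustrative name, not Lightning's actual internals):

import torch

def move_to_device(batch, device):
    # handle the already supported container types first ...
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(b, device) for b in batch)
    if isinstance(batch, dict):
        return {k: move_to_device(v, device) for k, v in batch.items()}
    # ... then fall back to duck typing: anything exposing a .to(device) method
    if hasattr(batch, "to"):
        return batch.to(device)
    return batch  # leave unknown types untouched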

Motivation

I want to use pytorch_geometric, but it comes with its own DataLoader and its own Batch data type, so I had some trouble using it together with PyTorch Lightning. After some struggling I figured out that I can override the transfer_batch_to_device method like this:

class Net(pl.LightningModule):
    ...

    def transfer_batch_to_device(self, batch, device):
        # torch_geometric's Batch implements .to(device) itself
        return batch.to(device)
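For context, this is roughly how such batches arise; a sketch only, the dataset and shapes are made up:

import torch
from torch_geometric.data import Data, DataLoader  # PyG's own loader, not torch.utils.data

# each sample is a small graph; the loader collates them into a single Batch object
dataset = [Data(x=torch.randn(3, 16), edge_index=torch.tensor([[0, 1], [1, 2]]))
           for _ in range(8)]
loader = DataLoader(dataset, batch_size=4)

batch = next(iter(loader))  # a torch_geometric Batch, not a tensor, list, or dict
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch = batch.to(device)    # but it does implement .to(device)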

At the very least I think it would be nice to mention this necessity in the docs, or to change the default behaviour of transfer_batch_to_device so that overriding it is no longer necessary.

Pitch

transfer_batch_to_device should accept any data type that provides a to(device) method.

Alternatives

Alternatively, the documentation should mention how to use non-default DataLoaders and batch types.

Additional context

I saw https://github.com/PyTorchLightning/pytorch-lightning/pull/1756 but couldn't figure out whether it solves my problem and is simply not merged yet. If it does, sorry for the extra work.

Thank you for the nice library and all your work =)

github-actions[bot] commented 4 years ago

Hi! Thanks for your contribution, great first issue!

nghorbani commented 4 years ago

This is also affecting me. I notice that when using torch-geometric the batch size is not affected either; i.e. if I set batch_size to 16, for example, each call to forward of the LightningModule is given 16 data points on CPU, and all of the data points in each forward call on different GPUs are the same! Using PL 0.9.0 on Ubuntu 20.04.

awaelchli commented 4 years ago

@Tieftaucher I fixed it here: #2335. As long as your data type implements .to(device), Lightning will call that directly, so you don't have to override transfer_batch_to_device in this case.
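For example, something along these lines should work without any override once the fix is in your installed version (CustomBatch is just an illustrative stand-in for a third-party batch type such as torch_geometric's Batch):

import torch
import pytorch_lightning as pl

class CustomBatch:
    # toy stand-in for a third-party batch type
    def __init__(self, x):
        self.x = x

    def to(self, device):
        self.x = self.x.to(device)
        return self

class Dummy(pl.LightningModule):
    def forward(self, x):
        return x

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch = Dummy().transfer_batch_to_device(CustomBatch(torch.randn(4, 3)), device)
print(batch.x.device)  # should follow `device` via the .to fallback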

@nghorbani could you open a separate issue about this? If you provide me some code I can help. But note that, last time I checked, torch_geometric did not support distributed multi-GPU (scatter, gather).