tristandeleu / pytorch-meta

A collection of extensions and data-loaders for few-shot learning & meta-learning in PyTorch
https://tristandeleu.github.io/pytorch-meta/
MIT License

Is Torchmeta compatible with DataParallel? #43

Closed smounsav closed 4 years ago

smounsav commented 4 years ago

Hi, is Torchmeta compatible with DataParallel? I have trouble training a MetaModule wrapped in DataParallel. It seems get_subdict has problems with layer names that start with the "module." prefix added by the wrapper. Thanks for your help!
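For concreteness, here is a rough sketch of what I mean (the model, the layer sizes, and the import path for get_subdict are just my guesses to illustrate the problem):

import torch.nn as nn
from collections import OrderedDict

from torchmeta.modules import MetaSequential, MetaLinear
from torchmeta.modules.utils import get_subdict

# A toy MetaModule, wrapped in the standard torch.nn.DataParallel
model = MetaSequential(
    MetaLinear(2, 3),
    nn.ReLU(),
    MetaLinear(3, 1))
parallel = nn.DataParallel(model)

params = OrderedDict(parallel.named_parameters())
print(list(params))  # ['module.0.weight', 'module.0.bias', ...]

# get_subdict filters keys by the prefix '0.', so the extra 'module.'
# prefix added by DataParallel makes the result come back empty
print(get_subdict(params, '0'))  # OrderedDict() instead of layer 0's parameters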

mahf93 commented 4 years ago

I also have the same problem, would appreciate an answer.

tristandeleu commented 4 years ago

Hi, sorry for the late reply. I have received similar feedback regarding compatibility between MetaModule and DataParallel. I will look into it.

tristandeleu commented 4 years ago

Can you provide a minimal example of this issue?

Qunxiang commented 4 years ago

> Can you provide a minimal example of this issue?

Hi, when I use multiple GPUs to train the model like this:

model = torch.nn.DataParallel(model)

it raises AttributeError: 'DataParallel' object has no attribute 'meta_parameters'
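A minimal repro might look like this (the layer and sizes are arbitrary; the point is just that torch.nn.DataParallel does not expose the MetaModule-specific methods of the wrapped model):

import torch

from torchmeta.modules import MetaLinear

model = MetaLinear(2, 1)
model = torch.nn.DataParallel(model)

# DataParallel only stores the wrapped model under its 'module' attribute
# and does not forward MetaModule-specific lookups, so this raises
# AttributeError: 'DataParallel' object has no attribute 'meta_parameters'
model.meta_parameters()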

Thanks for your help!

tristandeleu commented 4 years ago

Thank you! I will test this, but it looks like this might require an extension of DataParallel in Torchmeta to handle the meta_parameters attribute.
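For reference, the kind of extension I have in mind might look roughly like this (MetaDataParallel is a hypothetical name, and this sketch ignores other details, such as the params argument in forward() and the 'module.' name prefix):

import torch.nn as nn

# Hypothetical sketch: delegate the MetaModule-specific accessor to the
# wrapped module instead of relying on attribute forwarding
class MetaDataParallel(nn.DataParallel):
    def meta_parameters(self, recurse=True):
        return self.module.meta_parameters(recurse=recurse)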

tristandeleu commented 4 years ago

I have added support for DataParallel. Unfortunately it does not work out of the box with torch.nn.DataParallel, but I have added a DataParallel class in torchmeta.modules that is identical to the torch version, with full support for MetaModule instances (you can pass a params argument to forward() and access meta_parameters()). This means you will probably have to update your code, replacing torch.nn.DataParallel with torchmeta.modules.DataParallel. Here is a basic example:

import torch
import torch.nn as nn

from torchmeta.modules import MetaSequential, MetaLinear
from torchmeta.modules import DataParallel

# A small MetaModule: two meta linear layers with a ReLU in between
model = MetaSequential(
    MetaLinear(2, 3),
    nn.ReLU(),
    MetaLinear(3, 1))
model = DataParallel(model)  # Instead of model = nn.DataParallel(model)
model.to(device=torch.device('cuda:0'))

Please let me know if that works for you, and if you find any bugs. I do not usually use multi-GPU training, so I won't be able to fully test it myself.
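For example, a MAML-style inner-loop step through the wrapper could look like this, continuing from the example above (the data shapes and step size are illustrative, and I am assuming gradient_update_parameters from torchmeta.utils.gradient_based behaves with the wrapper the same way it does with a plain MetaModule):

from torchmeta.utils.gradient_based import gradient_update_parameters

# Dummy data; shapes match the toy model above
inputs = torch.randn(8, 2, device='cuda:0')
targets = torch.randn(8, 1, device='cuda:0')

# Inner-loop adaptation: meta_parameters() is accessible on the wrapper
inner_loss = ((model(inputs) - targets) ** 2).mean()
params = gradient_update_parameters(model, inner_loss, step_size=0.1)

# Forward pass with the adapted parameters, still split across the GPUs
outer_loss = ((model(inputs, params=params) - targets) ** 2).mean()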

smounsav commented 4 years ago

Thanks a lot! I'll try it and let you know if I find any bugs.