elvisyjlin / AttGAN-PyTorch

AttGAN PyTorch Arbitrary Facial Attribute Editing: Only Change What You Want
MIT License

multi_gpu #6

Closed: thinkerthinker closed this issue 5 years ago

thinkerthinker commented 5 years ago

I still have a problem with multi_gpu. I am running your modified code, but zs_a = self.G.encode(img_a) raises AttributeError: 'DataParallel' object has no attribute 'encode'. When I changed zs_a = self.G.encode(img_a) to zs_a = self.G.module.encode(img_a), the code runs but the loss diverges. I don't know where I made the mistake.
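The AttributeError happens because nn.DataParallel only wraps forward(). Custom methods such as encode() stay on the wrapped model, reachable as .module. Calling .module.encode() runs, but it skips DataParallel's scatter/gather and executes on the wrapped module's single device. The minimal sketch below shows this. TinyAutoencoder is a hypothetical stand-in for AttGAN's generator, not the repository's code.

```python
import torch
import torch.nn as nn

class TinyAutoencoder(nn.Module):
    """Hypothetical model with a custom encode() method (not AttGAN's generator)."""

    def __init__(self):
        super().__init__()
        self.enc = nn.Linear(8, 4)

    def encode(self, x):
        return self.enc(x)

    def forward(self, x):
        return self.encode(x)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
G = nn.DataParallel(TinyAutoencoder().to(device))
x = torch.randn(2, 8, device=device)

# DataParallel does not proxy custom methods, so this attribute lookup fails
has_encode = hasattr(G, 'encode')
# Works, but bypasses DataParallel: the call runs on the wrapped module's device only
z_single = G.module.encode(x)
# Goes through DataParallel.forward(), which scatters the batch across available GPUs
z_parallel = G(x)
```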

elvisyjlin commented 5 years ago

In my recent commit 5c9643b967dd8c051c71413f03872a3cb80a79ab, I updated the generator's forward function so that different operations can be selected through it:

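img_2 = G(img, attr, mode='enc-dec')  # encode, then decode with the target attributes
latent = G(img, mode='enc')           # encode only
img_2 = G(latent, attr, mode='dec')   # decode a latent with the target attributes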
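Routing every operation through forward() matters because forward() is the only entry point that nn.DataParallel scatters and gathers. The sketch below illustrates the pattern with a hypothetical ToyGenerator. The layer sizes, the 13-attribute vector, and the attribute-tiling scheme are assumptions for illustration, not the actual AttGAN generator from commit 5c9643b.

```python
import torch
import torch.nn as nn

class ToyGenerator(nn.Module):
    """Hypothetical stand-in for AttGAN's generator, not the repository's code."""

    def __init__(self, channels=3, n_attrs=13, latent=16):
        super().__init__()
        self.enc = nn.Conv2d(channels, latent, kernel_size=4, stride=2, padding=1)
        self.dec = nn.ConvTranspose2d(latent + n_attrs, channels,
                                      kernel_size=4, stride=2, padding=1)

    def encode(self, x):
        return self.enc(x)

    def decode(self, z, a):
        # Tile the attribute vector over the spatial grid and concatenate it with z
        a_tile = a.view(a.size(0), -1, 1, 1).expand(-1, -1, z.size(2), z.size(3))
        return torch.tanh(self.dec(torch.cat([z, a_tile], dim=1)))

    def forward(self, x, a=None, mode='enc-dec'):
        # Dispatch on mode so every call goes through forward() and thus DataParallel
        if mode == 'enc-dec':
            return self.decode(self.encode(x), a)
        if mode == 'enc':
            return self.encode(x)
        if mode == 'dec':
            return self.decode(x, a)
        raise ValueError('unknown mode: ' + mode)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
G = nn.DataParallel(ToyGenerator().to(device))
img = torch.randn(4, 3, 32, 32, device=device)
attr = torch.rand(4, 13, device=device)

latent = G(img, mode='enc')          # (4, 16, 16, 16)
img_2 = G(latent, attr, mode='dec')  # (4, 3, 32, 32)
img_3 = G(img, attr, mode='enc-dec')
```

With this layout the training code never needs G.module, so the same calls work whether or not G is wrapped in nn.DataParallel.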

However, this doesn't solve the gradient explosion problem in multi-GPU training. That problem is again caused by PyTorch 1.0.0, and it hasn't been fixed at all (see this issue). Since nn.DataParallel() has many problems in v1.0.0, I suggest using PyTorch v0.4.0 instead of v1.0.0 or above. I've tested multi-GPU training under v0.4.0.

Please run:

pip3 install --upgrade torch==0.4.0

If you don't want to overwrite your global pip environment, you can use virtualenv.

# Go to the AttGAN-PyTorch/ folder
pip3 install virtualenv
virtualenv myenv
source myenv/bin/activate
pip3 install -r requirements.txt
pip3 install --upgrade torch==0.4.0
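To catch an accidental upgrade, a training script could check the installed version before wrapping models in nn.DataParallel. The sketch below is one way to do this. The helper dataparallel_version_ok is hypothetical, and it only encodes the advice in this thread, which is that 0.4.0/0.4.1 were the versions found to work.

```python
import warnings
import torch

def dataparallel_version_ok(version=torch.__version__):
    """Hypothetical helper: True only for the 0.4.x releases recommended in this thread."""
    # Drop local build tags such as '+cu100', then compare major.minor
    major, minor = (int(p) for p in version.split('+')[0].split('.')[:2])
    return (major, minor) == (0, 4)

if torch.cuda.device_count() > 1 and not dataparallel_version_ok():
    warnings.warn('Multi-GPU training with nn.DataParallel was reported to diverge on '
                  'PyTorch ' + torch.__version__ + '; the author recommends torch==0.4.0')
```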
thinkerthinker commented 5 years ago

Thank you for your reply. I found that nn.DataParallel() does not work with wgan in PyTorch 1.0.0. When I changed to dcgan, the loss curve looked normal.

elvisyjlin commented 5 years ago

With regard to WGAN, I think it depends on how it is implemented. If the gradient penalty is wrapped in a function, as in this pytorch-wgan implementation, PyTorch 1.0.0 fails when you use DataParallel. Rolling back to 0.4.0 or 0.4.1 is the only solution I could find.
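For reference, a gradient penalty wrapped in a function typically looks like the minimal sketch below. It shows the standard WGAN-GP penalty, not the code from pytorch-wgan or from this repository. The toy critic and tensor shapes are assumptions for illustration. The relevant detail is create_graph=True. The penalty is itself differentiated, so the critic must support a double backward. When the critic is wrapped in nn.DataParallel, that double backward runs through the replicated modules, which is presumably the combination this thread found unreliable under PyTorch 1.0.0.

```python
import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake):
    """Standard WGAN-GP penalty: mean over the batch of (||grad_x D(x_hat)||_2 - 1)^2."""
    # One interpolation coefficient per sample, broadcast over the remaining dims
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (alpha * real + (1 - alpha) * fake).detach().requires_grad_(True)
    out = critic(x_hat)
    # create_graph=True keeps this gradient in the graph so gp.backward() can
    # differentiate it again with respect to the critic's parameters (double backward)
    grad = torch.autograd.grad(outputs=out, inputs=x_hat,
                               grad_outputs=torch.ones_like(out),
                               create_graph=True, retain_graph=True)[0]
    grad = grad.view(grad.size(0), -1)
    return ((grad.norm(2, dim=1) - 1) ** 2).mean()

# Toy critic (assumption): an 8x8 convolution that maps each 3x8x8 image to a score
critic = nn.Sequential(nn.Conv2d(3, 1, kernel_size=8))
real = torch.randn(4, 3, 8, 8)
fake = torch.randn(4, 3, 8, 8)
gp = gradient_penalty(critic, real, fake)
gp.backward()  # second-order backward into the critic's weights
```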

thinkerthinker commented 5 years ago

OK, I will try rolling back to 0.4.0 or 0.4.1. Thank you again!

elvisyjlin commented 5 years ago

Thank you for pointing out the bug in my multi-GPU code, too.