wgrathwohl / JEM

Project site for "Your Classifier is Secretly an Energy-Based Model and You Should Treat it Like One"
Apache License 2.0

Dealing with divergence #4

Open AntixK opened 4 years ago

AntixK commented 4 years ago

Hello,

Your work is inspiring! I have the following problem when I try to run your code: during training, the loss often blows up and diverges, frequently after only 2 epochs. It diverges even after turning off BatchNorm and adding warmup iterations. Could you advise on how to deal with such divergences?

Any help is appreciated. Thank you.

m-wiesner commented 4 years ago

If you read the appendix of the paper, the authors mention that most models had to be restarted from the last checkpoint with a different random seed after crashing. I ran this code and hit the same thing, but as long as I restarted from the last checkpoint, training continued; it just took some manual intervention. I also recommend using a low learning rate and bumping up the number of inner SGLD iterations per outer minibatch iteration. That should help some with stability.
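Roughly, an automated version of that restart loop would look like the sketch below (all names here, like train_with_restarts and ckpt_path, are illustrative, not from this repo):

import torch

def train_with_restarts(model, opt, loader, train_step, ckpt_path, max_restarts=10):
    for restart in range(max_restarts):
        torch.manual_seed(restart)  # a fresh seed gives a fresh SGLD noise stream after a crash
        try:
            for i, batch in enumerate(loader):
                loss = train_step(model, opt, batch)  # returns the scalar training loss
                if not torch.isfinite(loss):
                    raise FloatingPointError("loss diverged")
                if i % 1000 == 0:  # checkpoint often, so a restart loses little work
                    torch.save({"model": model.state_dict(), "opt": opt.state_dict()}, ckpt_path)
            return  # made it through without diverging
        except FloatingPointError:
            ckpt = torch.load(ckpt_path)  # roll back to the last good weights
            model.load_state_dict(ckpt["model"])
            opt.load_state_dict(ckpt["opt"])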

wgrathwohl commented 4 years ago

Hello, thanks for your kind words. Yes, as m-wiesner commented, this is the exact strategy I used and it should work. I know it is less than ideal, but EBMs are still somewhat brittle these days. Another alternative I've found to work is to place a small L2 penalty on the energy (this has been done in prior EBM work) with strength around 0.1. This should keep the energy values near zero and make training more stable.

USTC-yzy1996 commented 3 years ago

Hello, I'm new to energy-based models and I also ran into the same divergence problem while running the code. I see two solutions: one is restarting with a different random seed, and the other is placing a small L2 penalty on the energy. Could you two (m-wiesner & wgrathwohl) please tell me how to implement these in the code? For the seed: I loaded the checkpoint, but the new random seed did not seem to do anything, since the parameters were all loaded from the checkpoint file. For the L2 penalty: I can only find the three loss terms [l_p_x], [l_p_y_given_x], and [l_p_x_y]. Where in the code would the L2 penalty on the energy go? Thanks a lot! :)

m-wiesner commented 3 years ago

I don't know whether this code puts an l2 penalty on the network outputs. In my experience it didn't actually help with divergence. @wgrathwohl, what was your experience with this?

wgrathwohl commented 3 years ago

Hi,

Sorry about the slow response. Been very busy. About L2 reg... I never used it in the JEM work. Some people have found it to work; I have not had much success with it. If you want to try it, I'd recommend starting with a strength of around .0001 and going up from there.

You would compute it like this:

fd, fq = f(x_d), f(x_q)  # energies on real data x_d and SGLD samples x_q
obj = fd.mean() - fq.mean()  # maximize f on data relative to samples
loss = -obj + args.l2 * ((fd ** 2).mean() + (fq ** 2).mean())  # L2 penalty keeps energies near zero

I have had more luck with a gradient penalty placed on the energy, evaluated at the training data:

x_d = x_d.clone().requires_grad_(True)  # inputs must require grad for the penalty
fd, fq = f(x_d), f(x_q)
obj = fd.mean() - fq.mean()
# create_graph=True so the penalty term is itself differentiable w.r.t. the model parameters
gd = torch.autograd.grad(fd.sum(), x_d, create_graph=True)[0].view(x_d.size(0), -1)
loss = -obj + args.l2_grad * gd.norm(2, -1).mean()

I'd try setting l2_grad to something like .01.
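For context, the x_q in the snippets above comes from the inner SGLD loop mentioned earlier in the thread. A minimal sketch of such a sampler, with illustrative hyperparameters rather than this repo's exact settings:

import torch

def sgld_sample(f, x_init, n_steps=20, step_size=1.0, noise_std=0.01):
    # short-run SGLD: repeatedly step up the unnormalized log-density f and add Gaussian noise
    x = x_init.clone().detach().requires_grad_(True)
    for _ in range(n_steps):
        grad = torch.autograd.grad(f(x).sum(), x)[0]
        x = x + step_size * grad + noise_std * torch.randn_like(x)
        x = x.detach().requires_grad_(True)
    return x.detach()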

I hope this helps! :)
