ml-struct-bio / cryodrgn

Neural networks for cryo-EM reconstruction
http://cryodrgn.cs.princeton.edu
GNU General Public License v3.0

nan during train_vae #18

Open kimdn opened 4 years ago

kimdn commented 4 years ago

Maybe this is not a cryodrgn-specific issue, but a more general VAE problem.

Anyway, have you ever seen "nan" values like the ones below?

Thank you

--

2020-08-15 07:13:21     # =====> Epoch: 2457 Average gen loss = 0.838749, KLD = 3.326967, total loss = 0.839008; Finished in 0:03:07.132678
2020-08-15 07:17:01     # =====> Epoch: 2458 Average gen loss = 0.838748, KLD = 3.325715, total loss = 0.839007; Finished in 0:03:02.530018
2020-08-15 07:20:41     # =====> Epoch: 2459 Average gen loss = nan, KLD = 4347.587754, total loss = nan; Finished in 0:03:05.964074
2020-08-15 07:24:22     # =====> Epoch: 2460 Average gen loss = nan, KLD = 297.862892, total loss = nan; Finished in 0:03:02.167709
2020-08-15 07:28:03     # =====> Epoch: 2461 Average gen loss = 0.908481, KLD = 8.505987, total loss = 0.909143; Finished in 0:03:03.817373
2020-08-15 07:31:55     # =====> Epoch: 2462 Average gen loss = 1.12218, KLD = 7.893573, total loss = 1.122793; Finished in 0:03:10.836993
2020-08-15 07:35:33     # =====> Epoch: 2463 Average gen loss = 0.840698, KLD = 3.719698, total loss = 0.840987; Finished in 0:03:03.367703
2020-08-15 07:39:14     # =====> Epoch: 2464 Average gen loss = 0.928574, KLD = 3.505119, total loss = 0.928846; Finished in 0:03:08.811053
2020-08-15 07:42:48     # =====> Epoch: 2465 Average gen loss = 0.838928, KLD = 3.334144, total loss = 0.839187; Finished in 0:03:00.716042
2020-08-15 07:46:29     # =====> Epoch: 2466 Average gen loss = 0.8393, KLD = 3.351988, total loss = 0.839561; Finished in 0:03:02.098446
2020-08-15 07:50:07     # =====> Epoch: 2467 Average gen loss = 0.839245, KLD = 3.370683, total loss = 0.839507; Finished in 0:03:02.463680
2020-08-15 07:53:52     # =====> Epoch: 2468 Average gen loss = 0.839711, KLD = 3.369308, total loss = 0.839973; Finished in 0:03:07.060146
2020-08-15 07:57:28     # =====> Epoch: 2469 Average gen loss = 0.838877, KLD = 3.337764, total loss = 0.839137; Finished in 0:03:00.234357
2020-08-15 08:01:08     # =====> Epoch: 2470 Average gen loss = 0.839412, KLD = 3.340018, total loss = 0.839672; Finished in 0:03:06.454027
2020-08-15 08:05:04     # =====> Epoch: 2471 Average gen loss = 0.83875, KLD = 3.335607, total loss = 0.839010; Finished in 0:03:14.294994
2020-08-15 08:08:58     # =====> Epoch: 2472 Average gen loss = 0.838798, KLD = 3.332927, total loss = 0.839057; Finished in 0:03:17.291361
2020-08-15 08:12:41     # =====> Epoch: 2473 Average gen loss = 0.839363, KLD = 3.337848, total loss = 0.839623; Finished in 0:03:02.637869
2020-08-15 08:16:24     # =====> Epoch: 2474 Average gen loss = 0.83879, KLD = 3.334751, total loss = 0.839049; Finished in 0:03:06.469525
2020-08-15 08:20:06     # =====> Epoch: 2475 Average gen loss = 0.838746, KLD = 3.335152, total loss = 0.839005; Finished in 0:03:02.485381
2020-08-15 08:23:46     # =====> Epoch: 2476 Average gen loss = 0.839, KLD = 3.334727, total loss = 0.839259; Finished in 0:03:03.908909
2020-08-15 08:27:28     # =====> Epoch: 2477 Average gen loss = 0.838743, KLD = 3.335539, total loss = 0.839003; Finished in 0:03:06.945390
2020-08-15 08:31:05     # =====> Epoch: 2478 Average gen loss = 0.838759, KLD = 3.336668, total loss = 0.839019; Finished in 0:03:01.373870
2020-08-15 08:34:52     # =====> Epoch: 2479 Average gen loss = 0.838768, KLD = 3.334897, total loss = 0.839027; Finished in 0:03:04.845571
2020-08-15 08:38:32     # =====> Epoch: 2480 Average gen loss = 0.838751, KLD = 3.331835, total loss = 0.839010; Finished in 0:03:03.818442
2020-08-15 08:42:07     # =====> Epoch: 2481 Average gen loss = nan, KLD = 33385075859247759360.000000, total loss = nan; Finished in 0:03:01.271044
2020-08-15 08:45:51     # =====> Epoch: 2482 Average gen loss = nan, KLD = 319211574631489344.000000, total loss = nan; Finished in 0:03:05.973848
2020-08-15 08:49:58     # =====> Epoch: 2483 Average gen loss = nan, KLD = 64188202244660896.000000, total loss = nan; Finished in 0:03:24.483174
zhonge commented 3 years ago

Wow that is a lot of epochs! Yes, I have run into nans during training before, but only with zdim=1 and usually triggered by impurities/non-standard images in the dataset.
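If impurities or corrupted images are the suspected trigger, a quick sanity check of the particle stack before training can help rule out non-finite or degenerate images. This is not a cryodrgn command, just a rough sketch assuming the stack is a standard .mrcs file (hypothetical name "particles.mrcs") readable with the mrcfile package:

```python
# Rough sanity check of a particle stack before training (sketch only, not
# part of cryodrgn; assumes an .mrcs stack and the mrcfile + numpy packages).
import mrcfile
import numpy as np

with mrcfile.mmap("particles.mrcs", permissive=True) as mrc:
    stack = mrc.data  # memory-mapped array, shape (n_particles, D, D)
    bad = []
    for i in range(stack.shape[0]):
        img = np.asarray(stack[i], dtype=np.float32)
        # Non-finite pixels or an all-constant image are red flags for
        # corrupted or blank "particles" that can destabilize training.
        if not np.isfinite(img).all() or img.std() == 0:
            bad.append(i)

print(f"{len(bad)} suspicious particles out of {stack.shape[0]}")
```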

By the way, I typically train for much fewer epochs. Training is typically bottlenecked by model updates, so I usually stick to the default batch size of 8 for more frequent updates and train for 25 epochs or so. Of course this is completely dependent on your dataset characteristics, and I encourage experimentation.
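As a rough illustration of why a few epochs already mean many model updates (a back-of-the-envelope sketch only; the 284,133-particle count is taken from the dataset described later in this thread):

```python
# Back-of-the-envelope: gradient updates per epoch vs. batch size.
# (Illustrative only; particle count borrowed from the dataset reported below.)
n_particles = 284_133

for batch_size in (8, 64, 256):
    updates_per_epoch = n_particles // batch_size
    print(f"batch size {batch_size:>3}: ~{updates_per_epoch:,} updates/epoch, "
          f"~{25 * updates_per_epoch:,} updates over 25 epochs")
```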

Also, depending on the image size/model size/latent variable dimension, overfitting is possible, especially with more epochs of training. I can recommend some training settings if you want to share some of the details of your dataset or reach out directly.

kimdn commented 3 years ago

> Wow that is a lot of epochs! Yes, I have run into nans during training before, but only with zdim=1 and usually triggered by impurities/non-standard images in the dataset.
>
> By the way, I typically train for much fewer epochs. Training is typically bottlenecked by model updates, so I usually stick to the default batch size of 8 for more frequent updates and train for 25 epochs or so. Of course this is completely dependent on your dataset characteristics, and I encourage experimentation.
>
> Also, depending on the image size/model size/latent variable dimension, overfitting is possible, especially with more epochs of training. I can recommend some training settings if you want to share some of the details of your dataset or reach out directly.

I see. It appears that overfitting means the points in the latent space are not representative of the input data, so once decoded they likely give meaningless content, as with a classic autoencoder. This overfitting may happen when too much emphasis is placed on minimizing the reconstruction loss (as I understand from https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73).

I'm curious how cryodrgn measures/estimates whether the model is overfitted. Is it considered overfitted if the KLD is larger than a certain threshold? (Since without the Kullback-Leibler divergence term, a VAE would overfit by focusing only on minimizing the reconstruction loss?)
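For reference, the trade-off being described is the standard VAE objective: a reconstruction term plus a KLD term that regularizes the latent distribution toward a unit Gaussian. A minimal PyTorch-style sketch of that loss (generic VAE bookkeeping, not necessarily cryodrgn's exact implementation) looks like:

```python
# Generic VAE loss sketch (not cryodrgn's exact code): reconstruction + beta * KLD.
# With beta = 0 the model is free to overfit the reconstruction term alone;
# the KLD term penalizes a latent distribution that drifts far from N(0, I).
import torch

def vae_loss(recon, target, mu, logvar, beta=1.0):
    # Mean squared reconstruction error, averaged over the batch.
    gen_loss = torch.mean((recon - target) ** 2)
    # Closed-form KLD between N(mu, sigma^2) and N(0, I), averaged over the batch.
    kld = torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
    return gen_loss + beta * kld, gen_loss, kld
```

In the log above, the nan epochs coincide with the KLD exploding by many orders of magnitude, which looks more like a numerical blow-up than a gradual overfitting signal crossing a threshold.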

My dataset is large (284,133 particles of experimental cryo-EM images, not synthetically simulated ones). After refinement in cryoSPARC, I extracted this number of particles with RELION.