RasmussenLab / vamb

Variational autoencoder for metagenomic binning
MIT License

[Feature Request] Early stopping parameter and a ∆loss (delta_loss) parameter #73

Open jolespin opened 3 years ago

jolespin commented 3 years ago

I've noticed the following loss behaviour over a 500-epoch run:

epochs = list(range(1, 501))

losses = [1.022028,0.976815,0.946165,0.939372,0.936472,0.934478,0.93287,0.931121,0.928942,0.926493,0.924655,0.922991,0.921227,0.917767,0.913316,0.910343,0.908492,0.906697,0.906121,0.905017,0.904342,0.903929,0.902962,0.902666,0.902193,0.901361,0.90076,0.90011,0.900093,0.899945,0.899433,0.899403,0.899601,0.89942,0.898986,0.898858,0.898236,0.898636,0.899055,0.898108,0.897928,0.897986,0.898376,0.89772,0.897651,0.89776,0.897574,0.89698,0.897411,0.896863,0.896323,0.896965,0.896544,0.896311,0.896462,0.896242,0.896039,0.895339,0.895552,0.895349,0.895293,0.895212,0.89522,0.894844,0.894523,0.894881,0.895171,0.894391,0.894128,0.894101,0.894439,0.894151,0.894145,0.893758,0.89339,0.893335,0.893341,0.893008,0.893275,0.893021,0.893107,0.893061,0.892328,0.893147,0.893255,0.892644,0.892826,0.892899,0.892971,0.892398,0.892607,0.89266,0.892189,0.89251,0.892124,0.892491,0.892408,0.892176,0.892203,0.892202,0.892343,0.892193,0.892532,0.892136,0.892544,0.892123,0.892221,0.891833,0.891592,0.891403,0.891165,0.891721,0.891846,0.891645,0.891512,0.891779,0.891654,0.891565,0.89186,0.891657,0.891667,0.891767,0.891413,0.891671,0.891747,0.891601,0.891655,0.891868,0.89141,0.891305,0.891106,0.891713,0.891377,0.891389,0.891596,0.891472,0.891448,0.891303,0.890935,0.891604,0.891497,0.89145,0.891079,0.890887,0.891205,0.891049,0.891086,0.891106,0.891193,0.890869,0.891079,0.891242,0.890482,0.89042,0.890341,0.89071,0.890601,0.890547,0.89037,0.890647,0.890502,0.890588,0.890789,0.890421,0.890478,0.890146,0.890622,0.890411,0.890252,0.890747,0.890294,0.890442,0.890132,0.890586,0.890538,0.890602,0.890411,0.890415,0.890285,0.890281,0.890416,0.89039,0.890288,0.89056,0.890995,0.890278,0.890192,0.890172,0.890203,0.890232,0.890151,0.890436,0.890312,0.889975,0.889965,0.890603,0.890103,0.890382,0.890218,0.890059,0.889984,0.889974,0.889564,0.889994,0.889831,0.890115,0.890333,0.890418,0.890285,0.890149,0.890351,0.890123,0.890026,0.890109,0.889952,0.890074,0.889961,0.890223,0.88983,0.89032,0.890148,0.890076,0.889877,0.890347,0.889587,0.889671,0.88979,0.890046,0.889669,0.88989,0.890252,0.890074,0.889844,0.889435,0.890038,0.889261,0.890188,0.890236,0.889628,0.889953,0.890265,0.889797,0.890035,0.889825,0.889535,0.889425,0.88971,0.889936,0.889489,0.889386,0.889581,0.889352,0.889124,0.889288,0.889945,0.889821,0.889103,0.889377,0.889469,0.889235,0.889544,0.889464,0.889776,0.88989,0.889232,0.889886,0.889673,0.889206,0.889483,0.88985,0.889742,0.889741,0.889505,0.889528,0.889531,0.889515,0.889462,0.889822,0.88951,0.88931,0.889448,0.889504,0.889113,0.889857,0.889489,0.889343,0.889523,0.889436,0.889487,0.889635,0.889118,0.889679,0.889468,0.889593,0.889692,0.889344,0.889155,0.889679,0.889357,0.888993,0.889547,0.889343,0.889818,0.889234,0.889215,0.888972,0.889133,0.889105,0.889487,0.888779,0.889246,0.889286,0.889273,0.889205,0.889293,0.889177,0.889163,0.889039,0.889068,0.888869,0.88946,0.889338,0.889487,0.889043,0.889381,0.889082,0.888923,0.888913,0.889223,0.889029,0.888918,0.888963,0.888993,0.889058,0.889293,0.889354,0.88856,0.888951,0.889406,0.888799,0.88929,0.889164,0.889286,0.888706,0.888986,0.889259,0.888675,0.889209,0.889165,0.889173,0.889019,0.888761,0.888814,0.889193,0.889016,0.888918,0.88904,0.888487,0.889097,0.888745,0.88891,0.88888,0.889175,0.888673,0.889016,0.889548,0.888942,0.889002,0.888968,0.888578,0.889223,0.888606,0.888798,0.889163,0.888791,0.888836,0.88887,0.888447,0.889011,0.888855,0.889499,0.888836,0.889352,0.889002,0.888425,0.888507,0.888988,0.888632,0.888897,0.888966,0.889149,0.888547,0.888778,0.888614,0.888565,0.889483,0.8888,0.888586,
0.889119,0.888875,0.888732,0.888962,0.888605,0.888692,0.888798,0.889083,0.888845,0.889106,0.888744,0.888677,0.888997,0.888805,0.888932,0.88864,0.888793,0.888256,0.889,0.888867,0.889663,0.888706,0.888696,0.8891,0.889339,0.888779,0.889001,0.88861,0.888959,0.888453,0.888347,0.888669,0.888698,0.888691,0.888913,0.888738,0.888682,0.888446,0.889206,0.888743,0.88867,0.888772,0.888966,0.888703,0.888418,0.888974,0.888739,0.888728,0.888916,0.888737,0.888741,0.888873,0.888296,0.888665,0.88892,0.888976,0.88875,0.888548,0.888563,0.888483,0.888488,0.888764,0.889112,0.888734,0.888666,0.88886,0.888552,0.888664,0.888629,0.889034,0.888319,0.888523,0.88836,0.888964,0.888708,0.888528,0.888504,0.888295,0.888553,0.889005,0.888857,0.888943,0.888287,0.888297,0.888767,0.888168,0.888974,0.888529,0.888296,0.889171,0.888668,0.888467,0.88858,0.888166,0.888485,0.888961,0.888313,0.888575,0.888615,0.888917,0.888606,0.88864]

import matplotlib.pyplot as plt
with plt.style.context("seaborn-white"):
    fig, ax = plt.subplots()
    ax.plot(epochs, losses)
    ax.set_xlabel("Epoch", fontsize=15)
    ax.set_ylabel("Loss", fontsize=15)
    ax.set_title("vamb_N64-L16-0.001_output/log.txt")

[Figure: loss vs. epoch, as plotted above]

# Per-epoch change in loss
diff = [losses[i + 1] - losses[i] for i in range(len(losses) - 1)]

with plt.style.context("seaborn-white"):
    fig, ax = plt.subplots()
    ax.plot(epochs[:-1], diff)
    ax.set_xlabel("Epoch", fontsize=15)
    ax.set_ylabel("Loss$_{i+1}$ - Loss$_{i}$", fontsize=15)
    ax.set_title("vamb_N64-L16-0.001_output/log.txt")

[Figure: per-epoch change in loss (Loss$_{i+1}$ - Loss$_{i}$) vs. epoch, as plotted above]

print(losses[250], "-", losses[200],  "=", losses[250] - losses[200] )
# 0.889581 - 0.889984 = -0.0004030000000000422

Is there any interest in having --early_stopping and --delta_loss parameters? The motivation is that the extra compute time needed to reach 500 epochs isn't worth it, since the loss starts to converge much earlier, around epoch 200. For example, in this case, if we set --early_stopping 50 and --delta_loss 0.0004, it would notice that the loss hasn't improved by delta_loss in early_stopping epochs, cut training short, and continue with the best epoch up to that point.

I think the most useful behaviour would be: if the loss has not decreased by at least delta_loss cumulatively over the last early_stopping epochs, training is cut short.
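A minimal sketch of that rule, assuming the trainer keeps the per-epoch losses in a plain Python list (the function and option names here are illustrative, not existing vamb parameters):

def should_stop(loss_history, early_stopping=50, delta_loss=0.0004):
    # Stop once the cumulative improvement over the last `early_stopping`
    # epochs falls below `delta_loss`.
    if len(loss_history) <= early_stopping:
        return False
    improvement = loss_history[-early_stopping - 1] - loss_history[-1]
    return improvement < delta_loss

# Checked once per epoch, e.g.:
# if should_stop(losses):
#     break  # keep the model from the best epoch so far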

This would be really helpful when running large datasets and, in my case, for brute-force hyperparameter tuning.

simonrasmu commented 3 years ago

I think @jakobnissen wrote a reply to a similar issue, but I can't find it right now. He should chip in.

But we did find that the latent space keeps changing over time even when the loss no longer changes significantly. That is, in principle we are not interested in how much the loss decreases, but rather in the nature of the latent space and how it clusters. For instance, changing the optimiser from Adam to AdaBelief (https://arxiv.org/abs/2010.07468) decreased the overall loss but performed worse in terms of clustering.

jakobnissen commented 3 years ago

Good idea. This was already discussed in #57. For now, I've decreased the number of epochs for version 4.

Simon is right that the latent space continues to change even after loss has flattened out. But we could implement early stopping by tracking the latent space. Every 10 epochs, say, we could measure the latent representation of a few thousand contigs, and compare to the last measurement. If the latent representation hasn't changed much, we could stop the VAE.
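A rough sketch of that check; `encode_subset` is a placeholder for whatever function returns the latent matrix for a fixed subset of contigs, not an actual vamb function:

import numpy as np

def latent_drift(previous, current):
    # Mean Euclidean distance between two latent snapshots of the same
    # contigs (rows must correspond to the same contigs in both arrays).
    return float(np.linalg.norm(current - previous, axis=1).mean())

# Inside the training loop, every 10 epochs (placeholder names throughout):
# if epoch % 10 == 0:
#     current = encode_subset(vae, subset)  # shape: (n_contigs, latent_dim)
#     if previous is not None and latent_drift(previous, current) < tolerance:
#         break  # latent space has stopped changing much
#     previous = current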

I'd need to do a test run where the latent repr. is dumped every few epochs to see how it changes over time, though.

jolespin commented 3 years ago

It's interesting that the clusters change even with slight changes in loss. It's also good to know that there are other things to consider besides the loss in unsupervised deep learning.

Is there a measure of the latent space that could be output as a single metric for each epoch? A timestamp for each epoch would also be helpful, so users can get a rough estimate of how long a run might take.
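For what it's worth, a per-epoch log line combining both suggestions could look like this sketch; the format and the drift value (e.g. the latent_drift scalar sketched above) are illustrative, not vamb's actual log output:

import time

def log_epoch(logfile, epoch, loss, drift=None):
    # One line per epoch: wall-clock timestamp, loss, and optionally a single
    # latent-space metric such as the latent_drift value sketched above.
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    line = f"{stamp}\tepoch {epoch}\tloss {loss:.6f}"
    if drift is not None:
        line += f"\tlatent_drift {drift:.6f}"
    print(line, file=logfile)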