ElArkk / jax-unirep

Reimplementation of the UniRep protein featurization model.
GNU General Public License v3.0
104 stars 31 forks

Evotune with "iters + batches" or "epochs"? #40

Closed ericmjl closed 4 years ago

ericmjl commented 4 years ago

I'm wondering whether evotuning should work on iterations with fixed batch sizes (to make visual progress a bit faster) or on epochs. From others' empirical experience, iterations with fixed batch sizes allow for faster convergence, though I'm more concerned with model ease-of-use than with convergence speed, and visual progress is important for that.

For batching, we can totally handle this using size-batching instead of random batching. Randomly sample a size, then randomly sample K sequences, then pass them through the model.

ivanjayapurna commented 4 years ago

Hi @ericmjl, I did think it was weird that batch sizes were non-uniform when I first read through the implementation (which is why I wrote that function to print them out, to get an intuitive idea of what they look like for each dataset). Here are a couple of my thoughts:

  1. For ease of use, I think controlling epochs is better: it makes more intuitive sense and makes comparisons easier.
  2. For visualizing progress, however, especially with a large number of input sequences, looking at the loss per batch within an epoch is much better. It could be the case that after 1 epoch you're done and have reached a minimum, in which case a per-epoch visualization would be pretty terrible.
  3. Correct me if I'm wrong, but the reason the batches are so weirdly distributed right now is that they are batched by sequence length, as per your report: "By preprocessing strings in batches of the same size, and by keeping track of the original ordering, then we could (1) avoid compilation penalty, and (2) vectorize much of the tensor operations over the sample axis, before returning the representation vectors in the original order of the sequences." Would switching to fixed-size batching mean slower compute or require padding sequences?
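
The length-batching approach quoted from the report can be sketched in a few lines of plain Python: group sequence indices by length, run each same-length group as one batch, then scatter the outputs back into the original order. This is an illustrative sketch, not jax-unirep's actual code; `length_batched_apply` and the `embed_batch` callable (a stand-in for the model's forward pass) are hypothetical names.

```python
from collections import defaultdict
from typing import Any, Callable, Dict, List, Sequence


def length_batched_apply(
    seqs: Sequence[str],
    embed_batch: Callable[[List[str]], List[Any]],
) -> List[Any]:
    """Run `embed_batch` on same-length groups; return outputs in input order."""
    # Group *indices* by length so each output can be put back where it came from.
    by_length: Dict[int, List[int]] = defaultdict(list)
    for idx, seq in enumerate(seqs):
        by_length[len(seq)].append(idx)

    outputs: List[Any] = [None] * len(seqs)
    for idxs in by_length.values():
        # All sequences here share one length, so the group forms a single
        # unpadded array that can be vectorized over the sample axis.
        batch_out = embed_batch([seqs[i] for i in idxs])
        for i, out in zip(idxs, batch_out):
            outputs[i] = out
    return outputs


# Toy stand-in for the model: "embeds" a sequence by lower-casing it.
print(length_batched_apply(["MKV", "MK", "MKVL", "AC"], lambda b: [s.lower() for s in b]))
# ['mkv', 'mk', 'mkvl', 'ac']
```

No padding is ever needed, which is exactly why the resulting batch sizes follow the length distribution of the dataset.
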
ericmjl commented 4 years ago

Hey @ivanjayapurna, on that last point: yes, you're right that the batches are weirdly distributed because of the batching by length. We did this explicitly to avoid padding sequences, as we weren't really sure what would happen to the reps reimplementation if we did padding. (Plus, the attraction of an RNN is that it avoids the need for any multiple sequence alignment.)

In terms of re-doing batching to fit with our size paradigm, I was thinking of the following:

  1. Randomly sample a length from the sequence length distribution.
  2. Randomly sample K sequences of that length, capping K at the number available, i.e. K = min(K, len(batch)).
  3. Perform gradient update.
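
The three steps above could look like this in plain Python. This is a hedged sketch: `bucket_by_length` and `sample_length_batch` are hypothetical helpers, the gradient update itself is left out, and "sample a length from the sequence length distribution" is taken to mean picking a length with probability proportional to how many sequences have it.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence


def bucket_by_length(seqs: Sequence[str]) -> Dict[int, List[str]]:
    """Group sequences into buckets keyed by their length."""
    buckets: Dict[int, List[str]] = defaultdict(list)
    for seq in seqs:
        buckets[len(seq)].append(seq)
    return dict(buckets)


def sample_length_batch(
    buckets: Dict[int, List[str]], k: int, rng: random.Random
) -> List[str]:
    """Steps 1-2: sample a length, then up to k sequences of that length."""
    lengths = list(buckets)
    counts = [len(buckets[n]) for n in lengths]
    # Step 1: draw a length with probability proportional to its count,
    # i.e. from the empirical sequence-length distribution.
    length = rng.choices(lengths, weights=counts, k=1)[0]
    bucket = buckets[length]
    # Step 2: draw min(k, len(bucket)) sequences without replacement.
    return rng.sample(bucket, min(k, len(bucket)))


rng = random.Random(0)
buckets = bucket_by_length(["MKV", "MKL", "MK", "MKVLA"])
batch = sample_length_batch(buckets, k=2, rng=rng)
# Every sequence in `batch` has the same length; step 3 would use it for one update.
```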

Doing this would be a kind of pseudo-stochastic gradient descent, as we wouldn't be sampling uniformly from the dataset.
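
Under one reading of step 1 (a length $\ell$ is drawn with probability proportional to its count), the non-uniformity can be made concrete. With $N$ sequences in total, $n_\ell$ of them of length $\ell$, and batch cap $K$, the chance that a particular sequence $s$ of length $\ell$ is in a given step's batch is:

$$
P(s \in \text{batch}) = \frac{n_\ell}{N} \cdot \frac{\min(K, n_\ell)}{n_\ell} = \frac{\min(K, n_\ell)}{N}
$$

Sequences in buckets with $n_\ell \ge K$ are drawn with probability $K/N$ per step, the same as uniform random batching of size $K$, but sequences in smaller buckets get only $n_\ell/N$, so rare lengths are under-sampled.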

From what I know empirically, purely random batching is supposed to help with convergence speed, and others at work whose experience I trust have told me that it also helps with generalization, because it encourages stochastic exploration of the loss landscape (a pseudo-Bayesian way of operating, I guess).

Definitely agree with you, though, on having an end-user control epochs, especially if in the docstring we make clear that "epochs" = "number of complete passes over the dataset".
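
Under that docstring definition, an epoch can stay a literal complete pass even with length batching, using a common bucketing trick (not something proposed in this thread): shuffle within each length bucket, chunk each bucket into batches of at most K, then shuffle the order of the batches. A minimal sketch; `epoch_batches` is a hypothetical helper, not part of jax-unirep.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence


def epoch_batches(seqs: Sequence[str], k: int, rng: random.Random) -> List[List[str]]:
    """One epoch of same-length batches that covers every sequence exactly once."""
    buckets: Dict[int, List[str]] = defaultdict(list)
    for seq in seqs:
        buckets[len(seq)].append(seq)

    batches: List[List[str]] = []
    for bucket in buckets.values():
        rng.shuffle(bucket)  # randomize which sequences end up sharing a batch
        # Chunk the shuffled bucket into consecutive batches of at most k sequences.
        batches.extend(bucket[i : i + k] for i in range(0, len(bucket), k))
    rng.shuffle(batches)  # randomize the order in which lengths are visited
    return batches


# Calling it once per epoch with the same rng gives a fresh shuffle each time.
rng = random.Random(0)
for epoch in range(2):
    for batch in epoch_batches(["MKV", "MKL", "MAV", "MK", "MKVLA"], k=2, rng=rng):
        pass  # one gradient update per batch would go here
```

Every sequence is seen exactly once per epoch; the trade-off is that the last chunk of each bucket, and every bucket smaller than K, yields a batch smaller than K.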

I think it might be doable to let end users control the "number of epochs", while under the hood we run length-batched stochastic gradient descent (for the benefits of stochastic gradient descent) and estimate the number of epochs completed.
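
A minimal sketch of that estimate, assuming the length-proportional sampling from the numbered steps above (pick a length in proportion to its bucket size, then take min(K, bucket size) sequences): the expected batch size determines how many steps correspond to a requested number of epochs, and a running count of sequences seen gives an epoch estimate for a progress bar. All three helpers are hypothetical names.

```python
import math
from typing import Sequence


def expected_batch_size(bucket_sizes: Sequence[int], k: int) -> float:
    """Expected batch size when a length is drawn with probability
    proportional to its bucket size and min(k, bucket size) sequences are taken."""
    total = sum(bucket_sizes)
    return sum((n / total) * min(k, n) for n in bucket_sizes)


def steps_for_epochs(bucket_sizes: Sequence[int], k: int, epochs: float) -> int:
    """Number of steps whose expected sequence count reaches epochs * N."""
    total = sum(bucket_sizes)
    return math.ceil(epochs * total / expected_batch_size(bucket_sizes, k))


def epochs_completed(sequences_seen: int, total: int) -> float:
    """Running epoch estimate: sequences processed so far / dataset size."""
    return sequences_seen / total


# Example: two length buckets of 10 sequences each, K = 5.
sizes = [10, 10]
print(expected_batch_size(sizes, k=5))  # 5.0
print(steps_for_epochs(sizes, k=5, epochs=2))  # 8
```

Under this scheme "one epoch" means N sequences processed in expectation, not that every sequence was seen once.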

ivanjayapurna commented 4 years ago

Hi @ericmjl, I like specifying epochs and then calculating batches under the hood, as that's a calculation people would do themselves to set num_batches anyway. I see 2 potential issues with your proposed solution:

1.) If we're randomly sampling, that means we're not really doing a "full pass" over all the sequences anymore, right? The concept of an "epoch" is no longer relevant in this case, if I'm understanding correctly.

2.) In a real dataset, the distribution of lengths is most likely going to be uneven (probably roughly Gaussian, actually). Most of the batches, other than the ones near the mean length, are likely to contain very few sequences, so unless K is very small, our batches will end up being uneven anyway. Here is an example of a real dataset I've been playing with:

Number of batches: 566, Average batch length: 113.55123901367188, Batch lengths: [ 1 1 2 3 3 4 1 5 4 3 6 6 1 10 2 7 5 6 5 5 2 6 8 5 5 9 11 7 9 11 10 6 12 10 16 11 9 13 10 11 13 19 13 21 13 16 14 17 17 18 14 10 15 20 19 20 21 12 21 18 16 19 19 19 10 19 19 18 20 21 24 18 31 17 20 24 17 19 24 25 23 26 17 25 27 25 32 24 35 25 24 24 22 31 24 40 22 24 36 20 15 27 27 25 24 28 13 18 22 25 30 25 31 25 25 56 24 22 27 19 33 22 26 28 33 40 22 36 44 41 33 26 29 33 30 41 48 26 32 29 36 45 30 40 51 36 39 40 35 46 66 38 65 56 50 101 95 113 333 128 146 263 246 146 139 183 234 167 160 323 390 377 278 244 201 270 258 191 242 214 276 224 290 262 292 272 312 346 390 502 506 518 581 728 529 424 522 550 611 653 621 826 477 484 501 532 526 524 633 555 684 796 665 707 1264 951 803 593 663 591 853 492 393 429 431 331 404 281 294 352 430 504 608 344 459 351 297 382 339 406 376 254 227 338 396 311 332 341 312 287 359 286 366 293 360 255 306 262 280 334 613 407 383 485 1436 513 315 323 356 281 346 272 317 282 230 215 212 248 178 186 169 152 126 122 131 156 137 130 160 128 173 134 97 124 123 113 104 117 139 132 136 92 115 122 145 217 151 112 122 184 119 273 98 74 98 71 76 75 73 88 89 83 84 82 74 108 69 53 53 54 38 45 48 40 47 59 54 44 56 56 55 40 52 45 38 43 55 36 44 41 34 33 35 31 48 33 47 33 36 36 31 28 13 50 29 19 25 20 18 16 15 23 21 19 21 24 64 33 71 60 91 78 52 66 91 94 72 29 34 31 22 22 29 21 21 35 8 17 16 21 31 24 35 24 21 30 17 40 27 28 31 21 16 33 27 29 14 24 33 25 52 36 21 34 29 63 28 19 25 20 18 29 23 22 56 27 49 34 49 24 18 25 24 27 29 22 13 26 11 15 8 10 13 10 12 17 14 9 14 9 9 13 12 21 18 22 17 28 43 55 25 14 11 20 12 10 25 10 16 19 26 14 8 19 16 10 15 7 8 10 6 4 4 8 5 9 7 9 4 3 6 2 7 7 11 6 9 6 3 10 3 7 10 10 7 3 3 8 6 1 5 8 6 3 4 5 6 6 6 7 4 4 18 5 10 2 6 2 2 6 6 6 4 8 4 4 5 4 2 5 2 3 1 2 1 2 1 1 3 2 5 3 3 6 2 1]
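
To quantify this unevenness for a given K, one could count the length buckets holding fewer than K sequences and the share of the data they contain. A minimal sketch with a hypothetical helper, run here only on the first ten bucket sizes from the output above:

```python
from typing import Sequence, Tuple


def undersized_buckets(bucket_sizes: Sequence[int], k: int) -> Tuple[int, float]:
    """Return (number of buckets with fewer than k sequences,
    fraction of all sequences that sit in those buckets)."""
    small = [n for n in bucket_sizes if n < k]
    return len(small), sum(small) / sum(bucket_sizes)


first_ten = [1, 1, 2, 3, 3, 4, 1, 5, 4, 3]  # first ten entries of the list above
count, frac = undersized_buckets(first_ten, k=4)
# count == 7; frac == 14 / 27 (14 of these 27 sequences are in undersized buckets)
```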

ElArkk commented 4 years ago

If we still want to preserve the concept of epochs, while sampling randomly from the input sequences and doing length batching, I can see two ways:

I don't really have experience here, though! I don't have an intuition for whether, e.g., padding sequences to allow easier fixed-batch-size sampling would actually lead to faster convergence in the end.
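
For comparison, fixed-size random batching with padding could look roughly like this: sample K sequences uniformly regardless of length, right-pad them to the batch's maximum length, and build a mask so padded positions are excluded from the loss. This is a sketch under stated assumptions: the `-` pad character and all helper names are hypothetical, and it does not answer how padding would affect the UniRep representations themselves.

```python
import random
from typing import List, Sequence, Tuple

PAD = "-"  # hypothetical padding character, assumed absent from real sequences


def sample_fixed_batch(seqs: Sequence[str], k: int, rng: random.Random) -> List[str]:
    """Uniform random batch of exactly k sequences, regardless of length."""
    return rng.sample(list(seqs), k)


def pad_batch(batch: Sequence[str]) -> Tuple[List[str], List[List[int]]]:
    """Right-pad sequences to the batch's max length; mask is 1 for real residues."""
    max_len = max(len(s) for s in batch)
    padded = [s + PAD * (max_len - len(s)) for s in batch]
    mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in batch]
    return padded, mask


def masked_mean_loss(token_losses: List[List[float]], mask: List[List[int]]) -> float:
    """Average per-token loss over real (unpadded) positions only."""
    total = sum(
        loss * m
        for row_l, row_m in zip(token_losses, mask)
        for loss, m in zip(row_l, row_m)
    )
    return total / sum(sum(row) for row in mask)


rng = random.Random(0)
batch = sample_fixed_batch(["MKV", "M", "MKVLA", "MK"], k=3, rng=rng)
padded, mask = pad_batch(batch)  # every row of `padded` now has the same length
```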