I just tried increasing the batch size to 60, and it's now ~1.5 hours per epoch, which is comparable to Min's. I'm afraid it might run out of GPU memory when it gets to some of the longer instances, though - we'll see.
And.... it ran out of memory. Too bad. But at least that hints that if we can bring down memory usage a bit, somehow, we should have an implementation that's just as fast as Min's, but a whole lot cleaner and more extensible.
> I'm afraid it might run out of GPU memory when it gets to some of the longer instances, though - we'll see.

What max length are you using? Or are you not truncating?
Nice to see this, and those are some impressive speedups!
RE: OOM, a few easy things to do to reduce memory pressure: truncate the longest passages (the original BiDAF code truncates pretty aggressively), and make sure the embeddings live on the CPU rather than the GPU.
I wasn't truncating at all, because the code isn't set up yet to do that well with dynamic padding.
But, here's a nice idea: use adaptive batch sizes. When the instances are small, increase the batch size, and when they are large, decrease the batch size. We could do this really easily with some simple heuristics right now, but doing it well, and automatically, seems pretty challenging.
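To make that concrete, here's a rough sketch of the kind of heuristic I mean (the names and numbers are made up, not anything in the codebase): instead of fixing the number of instances per batch, cap the total number of padded tokens per batch.

```python
# Hypothetical sketch only: group length-sorted instances into batches whose
# padded size stays under a token budget, so batch size adapts to instance length.
def adaptive_batches(instances, max_tokens_per_batch=15000, max_batch_size=60):
    """`instances` is a list of (instance, num_tokens) pairs."""
    # Sort by length so every batch pads to roughly the same maximum length.
    instances = sorted(instances, key=lambda pair: pair[1])
    batches, current_batch = [], []
    for instance, num_tokens in instances:
        # Because of the sort, `num_tokens` is the longest instance so far, so
        # this is what the batch would cost (after padding) if we add it.
        projected_padded_tokens = (len(current_batch) + 1) * num_tokens
        if current_batch and (projected_padded_tokens > max_tokens_per_batch
                              or len(current_batch) >= max_batch_size):
            batches.append(current_batch)
            current_batch = []
        current_batch.append(instance)
    if current_batch:
        batches.append(current_batch)
    return batches
```

With a budget like that, short instances get grouped into big batches and the long passages that were blowing up memory get small ones, automatically.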
> I wasn't truncating at all, because the code isn't set up yet to do that well with dynamic padding.

I'd guess that this (the original BiDAF code truncates aggressively), plus making sure the embeddings are on the CPU, should at least solve our memory issues...
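To spell out the embeddings-on-CPU part, here's a rough TF 1.x-style sketch (the sizes and variable names are just for illustration, not our actual code): pin the big embedding matrix to host memory so only the rows looked up for each batch ever move to the GPU.

```python
import tensorflow as tf

vocab_size, embedding_dim = 400000, 100  # e.g. a GloVe-sized vocabulary
token_ids = tf.placeholder(tf.int32, shape=[None, None], name="token_ids")

# Keep the large embedding matrix on the CPU; the lookup only ships the rows
# needed for this batch over to wherever the rest of the graph lives.
with tf.device("/cpu:0"):
    embedding_matrix = tf.get_variable(
        "embedding_matrix",
        shape=[vocab_size, embedding_dim],
        trainable=False)  # pretrained embeddings, kept fixed
    embedded_tokens = tf.nn.embedding_lookup(embedding_matrix, token_ids)

# Downstream encoders, attention, etc. can run on the GPU as usual.
```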
I was actually surprised that this worked; I didn't think `MatrixAttention` was going to be compatible, but it worked, with some modification. This depends on a couple of previous PRs, so we should wait to review it until those other two are merged, but I was excited that I got this working, so I'm opening a PR now. This again gives a 4-5x speedup in running time, plus a much faster start-up time, because you don't have to do all of the padding up front. Before this, I was seeing ~37-38 hours per epoch, and with this, it's 7-8 hours per epoch. Still pretty slow, but much better than before. I'm hoping that with TF optimizers and some more profiling work, we can get this down lower, to be more comparable with Min's implementation. I think it's still ~3x slower than Min's, maybe as much as 5x.
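For anyone skimming this later, the core of the dynamic padding change is just that each batch gets padded to its own longest instance when it's built, instead of padding everything to the dataset-wide maximum up front. A rough sketch of the idea (not the actual code in this PR):

```python
import numpy as np

def pad_batch(token_id_sequences):
    """Pad one batch to the longest sequence in *this* batch, rather than to a
    global maximum length computed over the whole dataset ahead of time."""
    max_length = max(len(sequence) for sequence in token_id_sequences)
    padded = np.zeros((len(token_id_sequences), max_length), dtype=np.int64)
    for i, sequence in enumerate(token_id_sequences):
        padded[i, :len(sequence)] = sequence
    return padded
```

Batches of short questions pad to a few dozen tokens, and only the batches containing long passages pay for long padding, which is where both the speedup and the faster start-up come from.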