I have an observation and I'm hoping someone can advise.
I have a scenario where I maintain a large table of vectors: a basic (n, m) array of n vectors, each of size m. Some other part of the system generates indices into this table, and I want to pull out the rows at those indices. (More background: we're building a hash-table version of NeRF.)
So I have a set of indices, and I want to gather the corresponding rows from the table for use elsewhere. There are two operators in MXNet that will do the job: `gather_nd` and `take`.
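To make the trade-off concrete, here is a minimal plain-Python sketch of what both operators compute when pulling rows out of an (n, m) table. The function names `gather_rows` and `gather_rows_backward` are hypothetical, and this only shows the reference semantics, not how MXNet implements either operator. The forward pass is an independent copy per index. The backward pass is a scatter-add, because a row requested more than once must accumulate several gradient rows, and that accumulation is presumably where a backward kernel can get expensive. In MXNet terms, the same rows should come from `mx.nd.take(table, idx)` (axis 0 by default) or from `mx.nd.gather_nd(table, idx.reshape((1, -1)))`, since `gather_nd` reads the leading dimension of its index array as coordinates into the table.

```python
# Reference semantics of a row gather and its gradient, in plain Python.
# NOT MXNet's implementation; it only shows what take/gather_nd compute
# when selecting rows of an (n, m) table.

def gather_rows(table, indices):
    """Forward: out[i] = table[indices[i]] (a copy of each selected row)."""
    return [list(table[j]) for j in indices]


def gather_rows_backward(grad_out, indices, n, m):
    """Backward: scatter-add each output gradient row into its source row.

    Rows selected more than once receive the sum of their gradients,
    so this is a reduction, not a plain copy.
    """
    grad_table = [[0.0] * m for _ in range(n)]
    for i, j in enumerate(indices):
        row = grad_table[j]
        for k in range(m):
            row[k] += grad_out[i][k]
    return grad_table


if __name__ == "__main__":
    table = [[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]  # n=3 rows, m=2
    indices = [2, 0, 2]                               # row 2 requested twice
    print(gather_rows(table, indices))
    grad_out = [[1.0, 1.0]] * len(indices)            # pretend dLoss/dout = 1
    print(gather_rows_backward(grad_out, indices, n=3, m=2))
```

With row 2 requested twice, its gradient comes out as `[2.0, 2.0]`, while the unrequested row 1 stays at zero.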
I could have more than 100k indices, or even 1000k:

- At 100k indices, `take` does a forward pass in less than 1 ms, but its backward pass takes about 45 ms. Meanwhile, `gather_nd` does a forward pass in about 16 ms and a backward pass in under 1 ms.
- At 1000k indices, `take` is 4 ms forward and 400 ms backward, while `gather_nd` is 170 ms forward and about 1 ms backward.
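Adding forward and backward gives the per-step cost of each operator and shows how much a combined operator could save. This small sketch uses only the timings reported above; "less than 1 ms" is rounded up to 1 ms, so the totals are rough upper bounds rather than new measurements. The helper name `step_totals` is hypothetical.

```python
# Timings reported above, in milliseconds, as (forward, backward).
# "less than 1 ms" / "under 1 ms" is rounded up to 1 ms, so the
# totals below are rough upper bounds, not new measurements.
timings_ms = {
    100_000: {"take": (1, 45), "gather_nd": (16, 1)},
    1_000_000: {"take": (4, 400), "gather_nd": (170, 1)},
}


def step_totals(ops):
    """Return forward+backward totals for take, gather_nd, and a
    hypothetical operator with take's forward and gather_nd's backward."""
    take_fwd, take_bwd = ops["take"]
    gnd_fwd, gnd_bwd = ops["gather_nd"]
    return take_fwd + take_bwd, gnd_fwd + gnd_bwd, take_fwd + gnd_bwd


for n_indices, ops in timings_ms.items():
    take_ms, gnd_ms, combined_ms = step_totals(ops)
    print(f"{n_indices:>9,} indices: take {take_ms} ms, "
          f"gather_nd {gnd_ms} ms, best of both ~{combined_ms} ms")
```

At 1000k indices this comes to roughly 404 ms per step for `take` and 171 ms for `gather_nd`, versus about 5 ms if the two fast halves could be combined. The slow halves scale roughly linearly with the index count (45 to 400 ms and 16 to 170 ms for a 10x increase).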
So, the obvious question: is there a way to get the best of both worlds here, i.e. the fast forward pass of `take` together with the fast backward pass of `gather_nd`?
Is there a better operator for gathering rows from the table? I also tried `Embedding`: in my test it looked like the best of both worlds, but in the real app it was slow on the backward pass.