apache / mxnet


gather_nd vs. take #21195

Open mureva opened 1 year ago

mureva commented 1 year ago

I have an observation and I'm hoping someone can advise.

I have a scenario where I maintain a large table of vectors: a basic (n, m) array of n vectors of size m. Some other system generates indices into this table, and I want to pull out the rows at those indices. (For background: we're building a hash-table version of NeRF.)

So, I have a set of indices, and I want to gather the corresponding rows out of the table for use elsewhere. There are two operators in MXNet that will do the job: gather_nd and take.
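To make the row-gather concrete, here is a minimal sketch that emulates the two index layouts in NumPy rather than calling MXNet itself. It assumes take with axis=0 receives a flat (k,) index array, and gather_nd receives an index array of shape (1, k) whose leading axis lists the indexed dimensions. The table size and index values are made up for illustration.

```python
import numpy as np

# Hypothetical small table: n vectors of size m.
n, m = 6, 4
table = np.arange(n * m, dtype=np.float32).reshape(n, m)

# Hypothetical row indices; duplicates are allowed.
idx = np.array([4, 0, 4, 2])

# take-style gather: a flat (k,) index array selects rows along axis 0.
rows_take = np.take(table, idx, axis=0)

# gather_nd-style gather (emulated): indices have shape (1, k), where the
# leading axis enumerates the indexed dimensions (only the row axis here).
nd_idx = idx.reshape(1, -1)
rows_gather_nd = table[tuple(nd_idx)]

# Both layouts pull out the same (k, m) block of rows.
assert rows_take.shape == (len(idx), m)
assert np.array_equal(rows_take, rows_gather_nd)
```

For this use case the two operators differ only in how the indices are laid out, not in the result they return.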

I could have more than 100k, even 1000k indices:

So... obvious question: is there a way to get the best of both worlds here, i.e. the fast forward pass of take and the fast backward pass of gather_nd?
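For reference, whichever operator does the forward gather, the gradient with respect to the table is mathematically a scatter-add: each row of the output gradient is accumulated into the table row it was gathered from, and repeated indices are summed. The NumPy sketch below shows only that math. The helper name `gather_rows_backward` is hypothetical, and nothing here reflects how MXNet's kernels are written. It is an unverified assumption that the way this accumulation is handled is one place where backward timings could differ.

```python
import numpy as np

def gather_rows_backward(n_rows, idx, grad_out):
    """Gradient of a row gather w.r.t. the table (hypothetical helper).

    grad_out has shape (k, m): one gradient row per gathered index.
    Rows that were gathered more than once receive the sum of their
    gradients; rows that were never gathered stay zero.
    """
    grad_table = np.zeros((n_rows, grad_out.shape[1]), dtype=grad_out.dtype)
    # np.add.at performs unbuffered accumulation, so duplicate indices add up
    # instead of overwriting each other.
    np.add.at(grad_table, idx, grad_out)
    return grad_table

# Hypothetical example: row 4 is gathered twice.
idx = np.array([4, 0, 4, 2])
grad_out = np.ones((len(idx), 3), dtype=np.float32)
grad_table = gather_rows_backward(6, idx, grad_out)
```

With many duplicate indices into a large table, this accumulation step is where the work concentrates, so it may be worth timing the backward pass with a realistic index distribution rather than a synthetic one.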

Is there a better operator for gathering rows from the table? I also tried Embedding: in my test it looked like the best of both worlds, but in the real app it was slow on the backward pass.