Open pl-fuchs opened 3 years ago
The "sequential_vmap" concept seems like it would be easy to implement given a more general custom batching interface, which is something that @mattjj and I have considered before.
Short of that, it might be worth considering offering the special case, by writing a single higher-order primitive whose batching rule is essentially a call to lax.map.
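As a sketch of how such a decorator could look, the following builds on `jax.custom_batching.custom_vmap` from recent JAX releases (which did not exist when this thread started). The batching rule here makes the simplifying assumption that every input carries a leading batch axis; the function names are illustrative, not an established API:

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.custom_batching import custom_vmap

def sequential_vmap(fun):
    # Sketch: under vmap, run `fun` once per batch element via lax.map
    # instead of vectorizing its internals.
    wrapped = custom_vmap(fun)

    @wrapped.def_vmap
    def rule(axis_size, in_batched, *batched_args):
        # Simplifying assumption: all inputs are batched along axis 0.
        assert all(jax.tree_util.tree_leaves(in_batched))
        out = lax.map(lambda args: fun(*args), batched_args)
        return out, True  # a single batched output

    return wrapped

@sequential_vmap
def f(x):
    return jnp.sin(x) ** 2

xs = jnp.arange(4.0)
ys = jax.vmap(f)(xs)  # applies f one element at a time, like lax.map
```

A full batching rule would also need to handle unbatched arguments (e.g. by broadcasting them before the `lax.map`), which is omitted here for brevity.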
@froystig Would breaking a vmap into a sequential_vmap alleviate memory usage?
The original request is about modifying the behavior of a function under vmap. If you're asking about replacing the use of vmap entirely with a sequential map, then yes: we have lax.map, and using it may require less memory. Did I read your question correctly?
@froystig Apologies, I see that this was not the right place to ask that question -- you did answer it, though, so thanks!
No problem. Glad that helped!
> The "sequential_vmap" concept seems like it would be easy to implement given a more general custom batching interface, which is something that @mattjj and I have considered before.
More general "custom batching" could be interesting for use-cases like host_callback.call, where some external library may support its own parallelism strategies. E.g., I think this could be useful for @ianwilliamson's MEEP wrapper: https://github.com/NanoComp/meep/pull/1569
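To make that use case concrete, here is a sketch of calling host-side code (a stand-in for an external library) one example at a time with `lax.map`. It uses `jax.pure_callback` from the modern API rather than the now-deprecated `host_callback.call` the thread discusses, so treat the exact API as an assumption:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

def host_fn(x):
    # Runs on the host, outside JAX tracing -- imagine an external
    # library (e.g. a simulator) that handles one example per call.
    return np.asarray(np.sin(x))

def f(x):
    # Declare the callback's output shape/dtype so JAX can trace it.
    return jax.pure_callback(host_fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

xs = jnp.arange(4.0)
ys = lax.map(f, xs)  # the host function sees a single example per call
```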
> Short of that, it might be worth considering offering the special case, by writing a single higher-order primitive whose batching rule is essentially a call to lax.map.
This could definitely be a good place to start, even if only as the first step towards the general solution.
> This could definitely be a good place to start, even if only as the first step towards the general solution.
I tend to agree, like having written linear_call as a first step towards custom transposition. Perhaps worth trying to see what it surfaces.
I am writing some longer algorithms which I like to vmap. I stumbled over some problems when combining vmap with the host_callback module (in particular the host_callback.call function) and the lax.cond function.

I think a simple solution would be to implement a stop_vmap or sequential_vmap decorator. This decorator would define a batching rule such that applying vmap to the decorated function would be the same as writing a lax.map over it. The advantage of the decorator would be that some_fun could be used inside a much bigger vmapped function.
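For readers arriving later: recent JAX versions expose `jax.custom_batching.sequential_vmap`, which behaves roughly as requested here (availability depends on the JAX version; `some_fun` below is an illustrative stand-in):

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.custom_batching import sequential_vmap

@sequential_vmap
def some_fun(x):
    # stand-in for a function (e.g. one wrapping host_callback.call)
    # that should be looped over, not vectorized
    return x * 2.0

xs = jnp.arange(4.0)
# vmap of the decorated function applies it one element at a time,
# equivalent to lax.map over the undecorated computation:
assert jnp.allclose(jax.vmap(some_fun)(xs), lax.map(lambda x: x * 2.0, xs))
```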