Open zou3519 opened 2 years ago
Hello @zou3519 ,
I have a function that I would like to map over batches of tensors. I do not need to differentiate anything; I would just like to use the parallelization capabilities of vmap. However, the functorch library complains by raising:
vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257
as I have several if statements in my function (that I cannot get rid of). Do you know if it is possible to leverage this parallelized function evaluation without incurring these differentiability requirements?
@Sceki the restriction on data-dependent control flow isn't a differentiability requirement; it's a limitation of vmap, in that vmap isn't able to "see" Python if statements. (See here for more details.)
One workaround is to replace some of your if statements with torch.where. For example, we can replace the following function (which implements a relu):
def relu(x):
    if x > 0:
        return x
    return 0
with:
def relu(x):
    return torch.where(x > 0, x, 0)
What ends up happening in the rewrite is that you evaluate both the true branch and the false branch of the computation, and then use torch.where(predicate, output_of_true_branch, output_of_false_branch) to construct the output.
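As a minimal sketch of this rewrite under vmap (using torch.vmap, which in newer PyTorch wraps the functorch implementation; on older installs the equivalent is functorch.vmap):

```python
import torch

def relu(x):
    # both branches are always computed; torch.where selects elementwise
    return torch.where(x > 0, x, torch.zeros_like(x))

x = torch.tensor([[-1.0], [2.0], [-3.0]])
out = torch.vmap(relu)(x)  # no data-dependent Python branching, so vmap is happy
```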
@zou3519 Thank you for your reply!
I actually tried that, but I got the same exception:
def relu(x):
    return torch.where(((x < 0) or (x > 0)), 1, 0)
x = torch.ones((3,1))
functorch.vmap(relu)(x)
returned:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Input In [27], in <module>
2 return torch.where(((x < 0) or (x>0)), 1, 0)
4 x = torch.ones((3,1))
----> 6 functorch.vmap(relu)(x)
File ~/miniconda3/envs/orekit_env_py3.8/lib/python3.8/site-packages/functorch/_src/vmap.py:383, in vmap.<locals>.wrapped(*args, **kwargs)
381 try:
382 batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 383 batched_outputs = func(*batched_inputs, **kwargs)
384 return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
385 finally:
Input In [27], in relu(x)
1 def relu(x):
----> 2 return torch.where(((x < 0) or (x>0)), 1, 0)
RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .
UPDATE: I think I see now: the problem is the or; I should separate that into two separate torch.where calls! Thanks for the help!
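As an alternative to nesting two torch.where calls, a hedged sketch of the same fix: Python's short-circuit or forces a bool() on a batched tensor (which vmap rejects), but the elementwise | operator (torch.logical_or) stays in tensor-land and works under vmap:

```python
import torch

def indicator(x):
    # `|` is elementwise logical-or, unlike Python's short-circuit `or`
    return torch.where((x < 0) | (x > 0), torch.ones_like(x), torch.zeros_like(x))

x = torch.tensor([[-2.0], [0.0], [3.0]])
out = torch.vmap(indicator)(x)
```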
@zou3519 thanks again for the support! I have another question: are there workarounds for while loops? I have read here that you are working on a PyTorch implementation of while, in a similar fashion to jax, but I was wondering if I could work around it somehow in the meantime?
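One common workaround (a hedged sketch, not an official functorch feature) is to replace the data-dependent while loop with a fixed maximum trip count and a per-element "done" mask, freezing converged elements with torch.where. Here this is illustrated with a hypothetical Newton iteration for square roots:

```python
import torch

def sqrt_newton(a, max_iters=20, tol=1e-6):
    x = torch.ones_like(a)
    for _ in range(max_iters):            # fixed trip count: vmap-friendly
        done = (x * x - a).abs() < tol    # per-element convergence mask
        step = 0.5 * (x + a / x)          # Newton update for sqrt
        x = torch.where(done, x, step)    # converged elements stay frozen
    return x

a = torch.tensor([4.0, 9.0, 2.0])
out = torch.vmap(sqrt_newton)(a)
```

The trade-off is that every element always pays for max_iters iterations, but there is no Python branching on tensor values, so vmap can batch it.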
Hi, while doing some tests on my own, I encountered the "(...) looks like you're attempting to use a Tensor in some data-dependent control flow" runtime error. And indeed, I was doing that: the function that I wanted to use with vmap has if and else branches, hence the said runtime error. It would be nice to have this functionality in vmap.
+1 the equivalent of jax.lax.cond would be amazing. I'm currently using something equivalent to torch.where, but ultimately my torch.vmap solution is about twice as slow as the original vectorized solution. I have a lot of these conditionals, so I suspect that computing all of them and choosing the right one is the reason. The problem is, I want to use torch.vmap upstream, which I can't do with the faster function.
Would love to use this as well -- it would enable, for instance, running the _strong_wolfe line search from torch.optim.lbfgs on a batch of inputs.
Hi, how did you implement your PyTorch equivalent to torch.where?
Hi All,
I have been using the Python Optimal Transport library. I want to define a loss function that iterates over every sample in my batch and calculates the sinkhorn distance for that sample and its ground-truth value. What I was using before was a for loop:
for i in range(len(P_batch)):
    if i == 0:
        loss = ot.sinkhorn2(P_batch[i].view(-1, 1), Q_batch[i].view(-1, 1), C, epsilon)
    else:
        loss += ot.sinkhorn2(P_batch[i].view(-1, 1), Q_batch[i].view(-1, 1), C, epsilon)
but this is way too slow for my application. I was reading through functorch, and apparently I should have been able to use the vmap functionality. But after wrapping my function in vmap, I get this error that everyone else is talking about:
RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257.
Does anyone have a workaround?
Cheers.
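One possible workaround (a hedged sketch only; this is not the POT API, and ot.sinkhorn2 may differ in normalization and stopping behavior): reimplement a fixed-iteration Sinkhorn in pure torch. Replacing the convergence-based while loop with a fixed trip count removes the data-dependent control flow, so the per-sample loss can be vmapped over the batch:

```python
import torch

def sinkhorn2_fixed(p, q, C, eps, n_iters=200):
    K = torch.exp(-C / eps)        # Gibbs kernel
    u = torch.ones_like(p)
    for _ in range(n_iters):       # fixed trip count instead of a while loop
        v = q / (K.t() @ u)
        u = p / (K @ v)
    P = u.unsqueeze(1) * K * v.unsqueeze(0)  # transport plan
    return (P * C).sum()           # entropic OT cost <P, C>

# usage on a batch; shapes are assumptions: P_batch, Q_batch are (B, n)
p = torch.full((2, 3), 1.0 / 3)
q = torch.full((2, 3), 1.0 / 3)
C = (torch.arange(3.0).unsqueeze(1) - torch.arange(3.0).unsqueeze(0)) ** 2
losses = torch.vmap(sinkhorn2_fixed, in_dims=(0, 0, None, None))(p, q, C, 0.1)
```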
+1
A "prototype" torch.cond exists today: https://pytorch.org/docs/main/generated/torch.cond.html#torch.cond. Try it out and please let us know how it goes by opening issues over at github.com/pytorch/pytorch. cc @ydwu4
Hi, can I avoid this by writing the control flow in C++?
🚀 Feature
Explore the existence of control flow operators, e.g. functorch.cond, functorch.while_loop, etc.
Motivation
(1) vmap over data-dependent control flow; (2) lowering data-dependent control flow to backends.
Pitch
Prototyping can happen in three stages:
Prototyping, especially for use with vmap, is not too difficult. For e.g. functorch.cond, we'll want to add a new operator to PyTorch. PyTorch doesn't support lambdas in its dispatcher, but we can easily hack around this by hiding the PyObject* lambda in an int64_t:
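The vmap-facing semantics such a functorch.cond could have can be sketched in plain Python (this is a hypothetical illustration, not the dispatcher-level int64_t hack described above): with a batched predicate, evaluate both branches and select elementwise.

```python
import torch

def cond_sketch(pred, true_fn, false_fn, operands):
    if pred.dim() == 0:  # unbatched predicate: ordinary Python control flow
        return true_fn(*operands) if pred else false_fn(*operands)
    # batched predicate (as seen under vmap): run both branches, select elementwise
    return torch.where(pred, true_fn(*operands), false_fn(*operands))

x = torch.tensor([-1.0, 2.0, -3.0])
out = cond_sketch(x > 0, lambda t: t, lambda t: torch.zeros_like(t), (x,))
```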
Alternatives
We do something with TorchScript instead.