pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Data-dependent control flow exploration #257

Open zou3519 opened 2 years ago

zou3519 commented 2 years ago

🚀 Feature

Explore adding control flow operators, e.g. functorch.cond, functorch.while_loop, etc.

Motivation

(1) vmap over data-dependent control flow (2) lowering data-dependent control flow to backends.

Pitch

Prototyping can happen in three stages:

Prototyping, especially for use with vmap, is not too difficult. For functorch.cond, for example, we'll want to add a new operator to PyTorch. PyTorch's dispatcher doesn't support lambdas, but we can easily hack around this by hiding the PyObject* for the lambda in an int64_t:

cond(Tensor pred, int true_fn, int false_fn, Tensor arg)
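To make the proposed schema concrete, here is a minimal Python sketch of the semantics it implies, plus one way a batching rule could handle a batched pred. Everything in it is hypothetical: cond_reference and cond_batch_rule are illustrative names, not PyTorch or functorch APIs, and the branch functions are assumed to already accept the batch dimension (elementwise ops such as torch.sin), which a real vmap implementation would arrange itself. When pred differs across the batch, a batching rule cannot run just one branch, so this sketch runs both and selects per example with torch.where.

```python
import torch

def cond_reference(pred, true_fn, false_fn, arg):
    # Unbatched semantics: pred holds one boolean, so only one branch runs.
    return true_fn(arg) if bool(pred) else false_fn(arg)

def cond_batch_rule(pred, true_fn, false_fn, arg):
    # Hypothetical batching rule: pred has shape (B,), arg has shape (B, ...).
    # Examples may disagree on pred, so evaluate both branches for the batch...
    true_out = true_fn(arg)
    false_out = false_fn(arg)
    # ...then reshape pred to (B, 1, ..., 1) so it broadcasts, and pick per example.
    mask = pred.reshape(pred.shape + (1,) * (true_out.dim() - 1))
    return torch.where(mask, true_out, false_out)

xs = torch.tensor([[1.0, 2.0], [-3.0, 4.0]])
preds = xs[:, 0] > 0  # one predicate per example: [True, False]

batched = cond_batch_rule(preds, torch.sin, torch.cos, xs)
looped = torch.stack([cond_reference(p, torch.sin, torch.cos, x) for p, x in zip(preds, xs)])
```

The batched result matches running cond_reference example by example, at the price of always evaluating both branches.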

Alternatives

We do something with TorchScript instead.

Sceki commented 2 years ago

Hello @zou3519 ,

I have a function that I would like to map over batches of tensors. I do not need to differentiate anything; I just want to use the parallelization capabilities of vmap. However, functorch raises the following error:

vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257

because I have several if statements in my function (which I cannot get rid of). Is it possible to leverage this parallelized function evaluation without incurring these differentiability requirements?

zou3519 commented 2 years ago

@Sceki the requirement about data-dependent control flow isn't a differentiability requirement; it's a limitation of vmap: vmap isn't able to "see" Python if statements. (See here for more details.)
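A small sketch of the failure being described, assuming PyTorch 2.x, where functorch.vmap is available as torch.vmap. The if statement asks Python for a single bool, but inside vmap x stands for a whole batch whose elements may disagree, so vmap raises a RuntimeError instead of guessing. The helper name relu_with_if is made up for illustration.

```python
import torch

def relu_with_if(x):
    # `x > 0` is a 0-dim tensor; `if` forces it into one Python bool,
    # which vmap cannot provide because each example may take a different branch.
    if x > 0:
        return x
    return torch.zeros_like(x)

xs = torch.tensor([-1.0, 2.0, -3.0])

# A plain Python loop works: each x is a concrete scalar tensor.
looped = torch.stack([relu_with_if(x) for x in xs])

try:
    torch.vmap(relu_with_if)(xs)
    vmap_failed = False
except RuntimeError as err:
    # Expected: the "data-dependent control flow" error quoted in this thread.
    vmap_failed = True
    print(err)
```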

One workaround is to replace some of your if statements with torch.where. For example, we can replace the following function (which implements a relu):

def relu(x):
  if x > 0:
    return x
  return 0

with:

def relu(x):
  return torch.where(x > 0, x, 0)

In general, the rewrite means evaluating both the true branch and the false branch of the computation, then using torch.where(predicate, output_of_true_branch, output_of_false_branch) to construct the output.
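A slightly larger sketch of that rewrite, again assuming torch.vmap from PyTorch 2.x; the functions f_if and f_where are illustrative only. The two branches do different work, so both are computed for every example and torch.where keeps the one selected by each example's predicate. This is also why the rewrite can cost more than the original: the untaken branch is still evaluated.

```python
import torch

def f_if(x):
    # Original per-example version with a Python `if` (fails under vmap).
    if x.sum() > 0:
        return x * 2
    return x.exp()

def f_where(x):
    pred = x.sum() > 0    # 0-dim bool tensor; stays a tensor under vmap
    true_out = x * 2      # evaluate the true branch...
    false_out = x.exp()   # ...and the false branch, unconditionally
    # pred broadcasts against the branch outputs, selecting one per example.
    return torch.where(pred, true_out, false_out)

xs = torch.tensor([[1.0, 2.0], [-3.0, 0.5]])
batched = torch.vmap(f_where)(xs)
looped = torch.stack([f_if(x) for x in xs])
```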

Sceki commented 2 years ago

@zou3519 Thank you for your reply!

I actually tried that, but I got the same exception:

def relu(x):
    return   torch.where(((x < 0) or (x>0)), 1, 0)

x = torch.ones((3,1))

functorch.vmap(relu)(x)

returned:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [27], in <module>
      2     return   torch.where(((x < 0) or (x>0)), 1, 0)
      4 x = torch.ones((3,1))
----> 6 functorch.vmap(relu)(x)

File ~/miniconda3/envs/orekit_env_py3.8/lib/python3.8/site-packages/functorch/_src/vmap.py:383, in vmap.<locals>.wrapped(*args, **kwargs)
    381 try:
    382     batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 383     batched_outputs = func(*batched_inputs, **kwargs)
    384     return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
    385 finally:

Input In [27], in relu(x)
      1 def relu(x):
----> 2     return   torch.where(((x < 0) or (x>0)), 1, 0)

RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .

UPDATE:

I think I see it now: the problem is the "or", which needs to turn the tensor into a Python bool. I should split that into two separate torch.where calls! Thanks for the help!
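For reference, a sketch of a single-torch.where alternative (assuming the function is meant to return 1 where x is nonzero and 0 elsewhere): Python's "or" calls bool() on a tensor, whereas the elementwise | operator (equivalently torch.logical_or) keeps the result a tensor, so vmap can batch it. The name nonzero_indicator is hypothetical.

```python
import torch

def nonzero_indicator(x):
    # `|` combines the two masks elementwise instead of asking for a Python bool.
    mask = (x < 0) | (x > 0)  # same as torch.logical_or(x < 0, x > 0)
    return torch.where(mask, torch.ones_like(x), torch.zeros_like(x))

xs = torch.tensor([[-2.0], [0.0], [3.0]])
out = torch.vmap(nonzero_indicator)(xs)
```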

Sceki commented 2 years ago

@zou3519 thanks again for the support!

I have another question: are there workarounds for while loops? I have read here that you are working on a PyTorch implementation of while, similar to JAX's, but is there a way to work around it in the meantime?
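One common workaround, sketched here under assumptions rather than as an official functorch feature, is to bound the loop by a fixed maximum number of iterations and use torch.where to freeze examples whose condition has already become False. The example loop (counting how many halvings bring x to 1 or below), the helper names, and max_iters=32 are all illustrative; the result is only correct if every example finishes within max_iters, and every example pays for all iterations.

```python
import torch

def halvings_while(x):
    # Reference version with a data-dependent while loop (fails under vmap).
    count = torch.zeros_like(x)
    while x > 1:
        x = x / 2
        count = count + 1
    return count

def halvings_masked(x, max_iters=32):
    # vmap-friendly version: a fixed trip count, with torch.where leaving
    # x and count unchanged once an example's loop condition is False.
    count = torch.zeros_like(x)
    for _ in range(max_iters):
        active = x > 1
        x = torch.where(active, x / 2, x)
        count = torch.where(active, count + 1, count)
    return count

xs = torch.tensor([1.0, 8.0, 5.0])
batched = torch.vmap(halvings_masked)(xs)
looped = torch.stack([halvings_while(x) for x in xs])
```

An early exit such as "if not active.any(): break" would reintroduce data-dependent control flow, so it cannot be used inside the vmapped function.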

maulberto3 commented 1 year ago

Hi, while doing some tests of my own, I ran into the "(...) looks like you're attempting to use a Tensor in some data-dependent control flow" runtime error. And indeed, I was doing that: the function I wanted to use with vmap has if and else branches, hence the runtime error. It would be nice to have this functionality in vmap.

sergiynesterenko90 commented 1 year ago

+1, the equivalent of jax.lax.cond would be amazing. I'm currently using something equivalent to torch.where, but ultimately my torch.vmap solution is about twice as slow as the original vectorized solution. I have a lot of these conditionals, so I suspect that computing every branch and then choosing the right one is the reason. The problem is that I want to use torch.vmap upstream, which I can't do with the faster function.

cwindolf commented 1 year ago

I would love to use this as well: it would enable, for instance, running the _strong_wolfe line search from torch.optim.lbfgs on a batch of inputs.

jn-tang commented 1 year ago

Replying to @sergiynesterenko90's comment above ("I'm currently using something equivalent to torch.where, but ultimately my torch.vmap solution is about twice as slow as the original vectorized solution"):

Hi, how did you implement your PyTorch equivalent to torch.where?

hmdolatabadi commented 1 year ago

Hi All,

I have been using the Python Optimal Transport library. I want to define a loss function that iterates over every sample in my batch and calculates the sinkhorn distance for that sample and its ground-truth value. What I was using before was a for-loop:

import ot  # Python Optimal Transport (POT)

loss = 0
for i in range(len(P_batch)):
    # Sinkhorn distance between sample i and its ground truth
    loss += ot.sinkhorn2(P_batch[i].view(-1, 1), Q_batch[i].view(-1, 1), C, epsilon)

but this is way too slow for my application. From reading through functorch, it seemed I should be able to use vmap for this. But after wrapping my function in vmap, I get the same error that everyone else is talking about:

RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257.

Does anyone have a workaround?
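A possible workaround, sketched under assumptions: POT's Sinkhorn solvers stop once an error threshold is reached, which is likely the data-dependent control flow vmap rejects. Writing a plain Sinkhorn-Knopp loop with a fixed iteration count avoids it. This is not POT's implementation: sinkhorn_fixed_iters is a hypothetical helper that takes 1-D histograms a and b (rather than the (-1, 1) views above), uses no log-domain stabilization (so a very small epsilon can underflow), and returns the transport cost <P, C> of the final plan. The random inputs, the squared-distance cost matrix, and n_iters=200 are arbitrary choices for the demo.

```python
import torch

def sinkhorn_fixed_iters(a, b, C, epsilon, n_iters=200):
    # Entropic OT between histograms a (n,) and b (m,) with cost matrix C (n, m).
    # A fixed iteration count replaces an `if err < tol: break` convergence check.
    K = torch.exp(-C / epsilon)
    u = torch.ones_like(a)
    v = torch.ones_like(b)
    for _ in range(n_iters):
        u = a / (K @ v)
        v = b / (K.T @ u)
    P = u[:, None] * K * v[None, :]  # transport plan
    return (P * C).sum()             # transport cost <P, C>

torch.manual_seed(0)
P_batch = torch.softmax(torch.randn(4, 5), dim=1)
Q_batch = torch.softmax(torch.randn(4, 5), dim=1)
grid = torch.arange(5.0)
C = (grid[:, None] - grid[None, :]) ** 2

# Batch over the histograms only; C and epsilon are shared (in_dim None).
losses = torch.vmap(sinkhorn_fixed_iters, in_dims=(0, 0, None, None))(P_batch, Q_batch, C, 1.0)
loss = losses.sum()

looped = torch.stack([sinkhorn_fixed_iters(p, q, C, 1.0) for p, q in zip(P_batch, Q_batch)])
```

Because the loop is written in plain tensor ops with no data-dependent branching, the same function could also be applied to the whole batch with explicit batched matmuls; vmap just saves rewriting it.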

Cheers.

abhillman commented 3 months ago

+1

zou3519 commented 3 months ago

A "prototype" torch.cond exists today: https://pytorch.org/docs/main/generated/torch.cond.html#torch.cond. Try it out and please let us know how it goes by opening issues over at github.com/pytorch/pytorch. cc @ydwu4

HongLouyemeng commented 2 months ago

Replying to @zou3519's note about the prototype torch.cond:

Hi, can I avoid this by writing the control flow in .cpp?