pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Actually actionable list of batching rules to write #240

Open zou3519 opened 2 years ago

zou3519 commented 2 years ago

For each of the items here, we should make sure all compositions (vmap, vmap x vjp) have a batching rule. All of these items should be actionable (in that it is possible to write a batching rule and we are not blocked on functionalization, which is coming soon).

Note: you may need to write an OpInfo for the operator if it doesn't exist already or wait for one to be added. A lot of folks are adding OpInfos right now, so if the OpInfo doesn't exist please ask first to see if someone is working on it.

Note: if any of the operations decompose into in-place operations, then we need functionalization to handle them. I think I've already filtered all of those out, but please double-check that.

Parcel 1: top nn.functional. and top torch. foo

Parcel 2: new_blah

Parcel 3: linalg things

Parcel 4:

vfdev-5 commented 2 years ago

@zou3519 nn.functional.pad with the circular option requires a copy: `out[..., out_d0:out_d1] = input[..., in_d0:in_d1]`, and thus the following error is raised:

>           out[..., out_d0:out_d1] = input[..., in_d0:in_d1]                                                                                                                               
E           RuntimeError: vmap: aten::copy_(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over at level 2. Please try to use out-of-place operators instead of aten::copy_. If said operator is being called inside the PyTorch framework, please file a bug report instead. 

Can we do something about that?
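One workaround (until functionalization lands) is to implement the padding out-of-place. A minimal sketch, assuming a recent PyTorch where functorch's `vmap` is exposed as `torch.vmap`; the helper name `circular_pad_1d` is hypothetical, not part of the library:

```python
import torch
import torch.nn.functional as F

def circular_pad_1d(x, pad):
    # Hypothetical helper: out-of-place circular padding along the last
    # dim built from torch.cat instead of slice assignment (aten::copy_),
    # so it composes with vmap.  `pad` is (left, right), as in F.pad.
    left, right = pad
    return torch.cat([x[..., x.shape[-1] - left:], x, x[..., :right]], dim=-1)

x = torch.arange(10.0).reshape(1, 1, 10)  # (N, C, W): circular pad wants 3D input
expected = F.pad(x, (2, 3), mode="circular")
assert torch.equal(circular_pad_1d(x, (2, 3)), expected)

# The out-of-place version works under vmap, where the copy_-based
# implementation hits the error quoted above.
batched = torch.randn(4, 1, 1, 10)
out = torch.vmap(circular_pad_1d, in_dims=(0, None))(batched, (2, 3))
assert out.shape == (4, 1, 1, 15)
```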

zou3519 commented 2 years ago

Aha. Nope, we can't do anything about that until functionalization is in. Good catch.

Padarn commented 2 years ago

Hi @zou3519, for the "forward pass only" ops above, does that mean that vjp and the related transforms require functionalization too?

zou3519 commented 2 years ago

> Hi @zou3519, for the "forward pass only" ops above, does that mean that vjp and the related transforms require functionalization too?

Yes, "forward pass only" means we should only try to get the vmap tests passing and none of the vjp/grad/compositions of {vjp, grad} tests.
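The distinction can be sketched roughly as follows, assuming a recent PyTorch where the functorch transforms live under `torch.func` (the toy function `f` is just an illustration):

```python
import torch
from torch.func import vmap, grad  # functorch transforms, now under torch.func

def f(x):
    # Toy scalar-valued function for the grad examples below.
    return (x ** 3).sum()

x = torch.randn(5, 3)

# "Forward pass only": just the plain vmap test needs to pass.
per_sample = vmap(lambda xi: xi ** 3)(x)
assert per_sample.shape == (5, 3)

# Full coverage also exercises compositions such as vmap(grad(f)):
# per-sample gradients, here d/dx (x^3) = 3 x^2.
per_sample_grads = vmap(grad(f))(x)
assert torch.allclose(per_sample_grads, 3 * x ** 2)
```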

vfdev-5 commented 2 years ago

@kshitij12345 which tasks from Parcel 2 are you working on or planning to work on?

I can start working on _cdist_forward, _cdist_backward
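For reference, `torch.cdist` already accepts leading batch dimensions, so the `_cdist_forward` batching rule can largely reduce to moving the vmapped dim into those native batch dims. A sketch of the expected equivalence, assuming `vmap` is available under `torch.func`:

```python
import torch
from torch.func import vmap

x = torch.randn(4, 5, 3, dtype=torch.float64)  # (batch, P, M)
y = torch.randn(4, 7, 3, dtype=torch.float64)  # (batch, R, M)

# cdist natively supports leading batch dims, so vmapping over the
# leading dim should agree with the native batched call.
native = torch.cdist(x, y)
vmapped = vmap(torch.cdist)(x, y)
assert native.shape == (4, 5, 7)
assert torch.allclose(native, vmapped)
```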

kshitij12345 commented 2 years ago

@vfdev-5 I think I'll be picking diagonal_scatter next. Go ahead with _cdist_forward, _cdist_backward

vfdev-5 commented 2 years ago

I'll take torch.addr, linalg.eig, cholesky_solve, and _lu_with_info next

vfdev-5 commented 2 years ago

Let's hold off on the svd batch rule, as there is ongoing refactoring that may fix the CPU/CUDA discrepancy test issue: https://github.com/pytorch/pytorch/pull/69827

vfdev-5 commented 2 years ago

To close this issue, it remains to finalize parcel 2:

and in parcel 4:

zou3519 commented 2 years ago

There's always more batching rules to write, I'll put up a new issue for them later :)

lezcano commented 2 years ago

Note that _lu_with_info is not a thing anymore. Now we have linalg_lu_factor and linalg_lu_factor_ex; cf. https://github.com/pytorch/pytorch/pull/66933

vfdev-5 commented 2 years ago

> linalg_lu_factor

@Lezcano thanks for the update! I see that _lu_with_info is marked for deprecation, so torch.lu will be deprecated as well?

lezcano commented 2 years ago

It will indeed. And that's a good reminder for me to put up a PR doing so :D
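For anyone migrating, the replacement API looks roughly like this (a sketch based on the `torch.linalg` docs; the factorization check via `torch.lu_unpack` is just one way to validate the result):

```python
import torch

A = torch.randn(4, 4, dtype=torch.float64)

# New API: torch.linalg.lu_factor / lu_factor_ex replace _lu_with_info.
LU, pivots = torch.linalg.lu_factor(A)

# Sanity-check the factorization: P @ L @ U should reconstruct A.
P, L, U = torch.lu_unpack(LU, pivots)
assert torch.allclose(P @ L @ U, A)

# lu_factor_ex additionally returns an `info` tensor instead of raising.
LU2, pivots2, info = torch.linalg.lu_factor_ex(A)
assert info.item() == 0  # 0 signals success
```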

vfdev-5 commented 2 years ago

@zou3519 can we update the description list with what was done? I think we can remove Parcel 4 from here and create a new issue for it if needed. What remains here is to sync and merge the householder product PR (#322), cc @kshitij12345.

lezcano commented 2 years ago

Fwiw, following up on the point above on deprecating torch.lu: https://github.com/pytorch/pytorch/pull/73804 https://github.com/pytorch/pytorch/pull/73806

zou3519 commented 2 years ago

> @zou3519 can we update the description list with what was done? I think we can remove Parcel 4 from here and create a new issue for it if needed. What remains here is to sync and merge the householder product PR (#322)

Yes I'll create another issue soon