pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Make vmap tests use dtype `any_one` #1092

Closed: samdow closed this issue 1 year ago

samdow commented 1 year ago

In #1069, @kshitij12345 rightly pointed out that it's concerning that these batch rules weren't caught by test_op_has_batch_rule. Looking into it, the bitwise ops in particular aren't being tested at all, because the only allowed dtype is torch.float (float32), which those ops don't support.
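To make the failure mode concrete, here is a toy Python model of what happens when each op's supported dtypes are intersected with an allowed set that contains only float32. It is hypothetical: it does not use PyTorch's real `ops` decorator or OpInfo metadata, and the op names and dtype sets are illustrative. Ops whose supported dtypes are all integral or boolean end up with no dtype variant, so no test is generated for them.

```python
# Hypothetical, simplified model of dtype filtering in an OpInfo-style
# @ops decorator. Op names and dtype sets are illustrative only; they are
# not PyTorch's real OpInfo metadata.

SUPPORTED = {
    "add": {"float32", "float64", "int64"},
    "bitwise_and": {"bool", "uint8", "int32", "int64"},
    "bitwise_not": {"bool", "uint8", "int64"},
}


def variants_with_allowed(supported, allowed):
    """Dtype variants generated when an op's supported dtypes are
    intersected with an allowed set; an empty list means no test."""
    return sorted(supported & allowed)


# Old setup: only float32 is allowed.
old = {name: variants_with_allowed(dts, {"float32"}) for name, dts in SUPPORTED.items()}

# Ops that silently get zero test instances under the old setup.
untested = [name for name, variants in old.items() if not variants]

if __name__ == "__main__":
    print(old)
    print("never tested:", untested)
```

In this model both bitwise ops get an empty variant list, which is how a missing batch rule for them could go unnoticed.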

Steps

  1. First, please update the @ops decorator on both test_vmap and test_op_has_batch_rule to select dtypes=OpDTypes.any_one (one supported dtype per op) instead of restricting allowed_dtypes to torch.float32.
  2. We expect this to lead to new failures. Please update the corresponding xfail list for each test.
     i. For test_op_has_batch_rule, if the failure appears to occur in an in-place function, first try adding it only to the inplace_failures list. If that does not work, you can xfail it.
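
The steps above can be sketched with a similar toy model. Again, this is hypothetical: the op names, dtype sets, policy functions, and pass/fail outcomes are made up for illustration, and PyTorch's actual any_one selection may pick a different dtype. Switching to an "any one supported dtype" policy gives every op at least one variant, which is what surfaces the new failures; adding those ops to an xfail set makes the outcomes match expectations again. The inplace_failures fallback from step 2 is not modeled.

```python
# Hypothetical model: contrast a float32-only dtype filter with an
# "any one supported dtype" policy, then check outcomes against an xfail
# set. Op names, dtype sets, and pass/fail outcomes are illustrative only.

SUPPORTED = {
    "add": {"float32", "float64", "int64"},
    "bitwise_and": {"bool", "uint8", "int32", "int64"},
    "bitwise_not": {"bool", "uint8", "int64"},
}


def float32_only(supported):
    """Old policy: keep only float32 from the op's supported dtypes."""
    return sorted(supported & {"float32"})


def any_one(supported):
    """New policy: test exactly one supported dtype. This sketch takes the
    first name in sorted order; PyTorch's actual choice may differ."""
    return sorted(supported)[:1]


def plan(policy):
    """Map each op to the dtype variants a policy would instantiate."""
    return {name: policy(dtypes) for name, dtypes in SUPPORTED.items()}


# Illustrative results of running each op once under any_one: suppose the
# bitwise ops turn out to be missing batch rules.
OUTCOME = {"add": "pass", "bitwise_and": "fail", "bitwise_not": "fail"}


def mismatches(xfails):
    """Ops whose outcome disagrees with the xfail set: unlisted failures,
    and listed ops that pass (similar to an unexpected success)."""
    bad = []
    for name, variants in plan(any_one).items():
        if not variants:
            continue  # no dtype variant means no test is generated
        failed = OUTCOME[name] == "fail"
        if failed != (name in xfails):
            bad.append(name)
    return bad


if __name__ == "__main__":
    print("float32 only:", plan(float32_only))
    print("any one:     ", plan(any_one))
    print("before xfail update:", mismatches(set()))
    print("after xfail update: ", mismatches({"bitwise_and", "bitwise_not"}))
```

Before the xfail set is updated, `mismatches(set())` reports both bitwise ops; once they are listed, it reports none. Flagging a listed op that starts passing keeps the xfail list from going stale, similar in spirit to an unexpected success under `unittest.expectedFailure`.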

kshitij12345 commented 1 year ago

Fixed in https://github.com/pytorch/pytorch/pull/91355