pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Tensor.nonzero_static fails on GPU inside torch.func.vmap #1145

Open Tendocat opened 3 months ago

Tendocat commented 3 months ago
```python
import torch

def foo(a: torch.Tensor, b: torch.Tensor):
    bool_mat = a & b
    print(bool_mat)
    # each vmapped slice is 1-D, so size=bool_mat.shape[0] pads to the full length
    return bool_mat.nonzero_static(size=bool_mat.shape[0])

a = torch.zeros((5, 5), dtype=torch.bool, device="cuda")
b = a.clone()
res = torch.func.vmap(foo)(a, b)
```

Error:

```
NotImplementedError: Could not run 'aten::nonzero_static' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::nonzero_static' is only available for these backends: [CPU, Meta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].
```

nonzero_static does not work on a CUDA device at all, and the CPU implementation also emits a warning.
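For reference, a minimal check of the op outside vmap; this is a sketch assuming PyTorch >= 2.1 (where nonzero_static was added), and the exact CPU warning text varies by version:

```python
import torch

mask = torch.zeros(5, dtype=torch.bool)

# CPU: runs (possibly with a warning) and pads with the default fill_value -1
print(mask.nonzero_static(size=5))  # tensor of shape (5, 1), all -1

# CUDA: the dispatcher has no kernel registered for this backend, so the
# call fails with or without vmap
if torch.cuda.is_available():
    try:
        mask.cuda().nonzero_static(size=5)
    except NotImplementedError as err:
        print(err)
```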

It would be worthwhile to implement this operation on CUDA, since nonzero_static is the natural drop-in replacement for the classic nonzero(), which vmap cannot support because of its dynamic output shape. A possible interim emulation is sketched below.
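Until a CUDA kernel exists, the 1-D case can be emulated with ops that do have vmap batching rules. A minimal sketch, assuming a boolean 1-D mask per vmapped slice; the helper nonzero_static_1d is hypothetical, not part of the PyTorch API:

```python
import torch

def nonzero_static_1d(mask: torch.Tensor, size: int, fill_value: int = -1):
    # A stable ascending sort on the inverted mask moves the indices of True
    # entries to the front while keeping them in ascending index order.
    order = torch.argsort((~mask).to(torch.int8), stable=True)[:size]
    # Positions past the number of True entries are padded with fill_value,
    # mirroring nonzero_static's padding behaviour.
    keep = torch.arange(size, device=mask.device) < mask.sum()
    return torch.where(keep, order, torch.full_like(order, fill_value)).unsqueeze(-1)

def foo(a: torch.Tensor, b: torch.Tensor):
    bool_mat = a & b
    return nonzero_static_1d(bool_mat, size=bool_mat.shape[0])

a = torch.zeros((5, 5), dtype=torch.bool, device="cuda")
res = torch.func.vmap(foo)(a, a.clone())  # shape (5, 5, 1), all fill_value
```

Since argsort, where, and full_like are batched by vmap, this sketch should run on CUDA inside torch.func.vmap, at the cost of an O(n log n) sort per slice instead of a dedicated kernel.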

Tendocat commented 3 months ago

256