Status: closed (reported by zou3519; closed 1 year ago)
Minimum repro for max_pool2d:
import torch
import torch.nn.functional as F
from torch.utils._pytree import tree_flatten
from functorch import jvp, vjp, vmap
# Pooling configuration shared by the repro below; return_indices=True makes
# max_pool2d return a (values, indices) pair.
kwargs = {
    'kernel_size': 3,
    'stride': 2,
    'ceil_mode': True,
    'padding': 0,
    'dilation': 1,
    'return_indices': True,
}
def fn(x):
    """Max-pool `x` with the shared config and return only the pooled values."""
    pooled, _indices = F.max_pool2d(x, **kwargs)
    return pooled
# Non-contiguous primal input (deliberately strange strides to exercise batching rules).
# NOTE(review): x1 is created with torch.empty_strided and never filled, so it holds
# uninitialized memory — results involving x1 are nondeterministic run to run.
x0 = torch.empty_strided([1, 2, 3, 6], (36, 1, 12, 2)).uniform_().cuda()
x1 = torch.empty_strided([1, 2, 1, 3, 2], (12, 6, 6, 2, 1)).cuda()
# Tangent inputs; trailing dim of size 2 is the dim vmap maps over below.
x2 = torch.rand([1, 2, 3, 6, 2]).cuda()
x3 = torch.rand([1, 2, 1, 3, 2]).cuda()
def push_vjp(primals, cotangents):
    """Pull `cotangents` back through `fn` at `primals` (vector-Jacobian product)."""
    outputs_and_pullback = vjp(fn, primals)
    pullback = outputs_and_pullback[1]
    return pullback(cotangents)
def jvp_of_vjp(x0, x1, x2, x3):
    """Differentiate push_vjp forward (jvp) and return all flattened outputs.

    (x0, x1) are the primals, (x2, x3) the tangents; the result is a flat
    tuple of the primal outputs followed by the tangent outputs.
    """
    primal_out, tangent_out = jvp(push_vjp, (x0, x1), (x2, x3))
    flat = tree_flatten(primal_out)[0] + tree_flatten(tangent_out)[0]
    return tuple(flat)
# Batched run: map over the trailing dim (-1) of x1/x2/x3; x0 is broadcast (None).
_, result = vmap(jvp_of_vjp, (None, -1, -1, -1))(x0, x1, x2, x3)
# Reference: run each slice individually and stack, which should match the vmap result.
_, expected0 = jvp_of_vjp(x0, x1.select(-1, 0), x2.select(-1, 0), x3.select(-1, 0))
_, expected1 = jvp_of_vjp(x0, x1.select(-1, 1), x2.select(-1, 1), x3.select(-1, 1))
expected = torch.stack([expected0, expected1])
# The failing check this issue reports: batched and unbatched results disagree.
assert torch.allclose(result, expected)
Even smaller repro for max_pool2d:
import torch
import torch.nn.functional as F
from torch.utils._pytree import tree_flatten
from functorch import jvp, vjp, vmap
# Fixed seed so the smaller repro is reproducible.
torch.manual_seed(0)
# Non-contiguous primal input with deliberately unusual strides.
x0 = torch.empty_strided([1, 2, 3, 6], (36, 1, 12, 2)).uniform_().cuda()
# Batched inputs; leading dim of size 2 is the dim vmap maps over below.
x1 = torch.rand([2, 1, 2, 1, 3]).cuda()
x2 = torch.rand([2, 1, 2, 3, 6]).cuda()
x3 = torch.rand([2, 1, 2, 1, 3]).cuda()
def trace(x0_1, x1_1, x2_1, x3_1):
    """Replay the captured aten graph: max-pool forward for indices, then backward.

    x0_1 supplies the pooling indices, x3_1 is the incoming gradient, and x2_1
    is the `self` tensor for the backward op; x1_1 is unused (kept to match the
    traced signature).
    """
    _values, indices = torch.ops.aten.max_pool2d_with_indices.default(
        x0_1, [3, 3], [2, 2], [0, 0], [1, 1], True
    )
    grad_input = torch.ops.aten.max_pool2d_with_indices_backward.default(
        x3_1, x2_1, [3, 3], [2, 2], [0, 0], [1, 1], True, indices
    )
    return grad_input
# Batched run: map over dim 0 of x1/x2/x3; x0 is broadcast (None).
result = vmap(trace, (None, 0, 0, 0))(x0, x1, x2, x3)
# Reference: run each batch slice individually and stack.
expected0 = trace(x0, x1[0], x2[0], x3[0])
expected1 = trace(x0, x1[1], x2[1], x3[1])
expected = torch.stack([expected0, expected1])
# The failing check this issue reports: vmapped backward disagrees with per-slice runs.
assert torch.allclose(result, expected)
The Householder product problem is just that the tolerance needs to be adjusted, so closing this.