pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License
1.38k stars 102 forks source link

Silently incorrect behavior caught by test_vmapjvpvjp #1019

Closed zou3519 closed 1 year ago

zou3519 commented 1 year ago

Minimal repro for max_pool2d:

import torch
import torch.nn.functional as F
from torch.utils._pytree import tree_flatten

from functorch import jvp, vjp, vmap

kwargs={'kernel_size': 3, 'stride': 2, 'ceil_mode': True, 'padding': 0, 'dilation': 1, 'return_indices': True}

def fn(x):
  """Max-pool `x` and return only the pooled values, discarding the indices.

  (`kwargs` sets return_indices=True, so max_pool2d yields a 2-tuple.)
  """
  pooled, _indices = F.max_pool2d(x, **kwargs)
  return pooled

# Inputs for the batched repro.  x0/x1 use empty_strided to get
# non-contiguous layouts.  Fix: x1 previously had no .uniform_() fill, so it
# contained uninitialized memory and made the repro nondeterministic; fill it
# the same way x0 is filled.
x0 = torch.empty_strided([1, 2, 3, 6], (36, 1, 12, 2)).uniform_().cuda()
x1 = torch.empty_strided([1, 2, 1, 3, 2], (12, 6, 6, 2, 1)).uniform_().cuda()
x2 = torch.rand([1, 2, 3, 6, 2]).cuda()
x3 = torch.rand([1, 2, 1, 3, 2]).cuda()

def push_vjp(primals, cotangents):
  """Compute the vjp of `fn` at `primals` and apply it to `cotangents`.

  NOTE(review): when driven through jvp below, this receives a single
  primal tensor and a single cotangent tensor despite the plural names.
  """
  _out, pullback = vjp(fn, primals)
  return pullback(cotangents)

def jvp_of_vjp(x0, x1, x2, x3):
  """jvp of push_vjp: return the flattened primal and tangent outputs.

  (x0, x1) are the primals of push_vjp, (x2, x3) the tangents; the two
  pytree outputs are flattened and concatenated into one tuple.
  """
  primals_out, tangents_out = jvp(push_vjp, (x0, x1), (x2, x3))
  flat_primals = tree_flatten(primals_out)[0]
  flat_tangents = tree_flatten(tangents_out)[0]
  return tuple(flat_primals + flat_tangents)

# Compare the vmap'd run (batching over the trailing dim of x1/x2/x3)
# against computing each slice by hand and stacking.
_, result = vmap(jvp_of_vjp, (None, -1, -1, -1))(x0, x1, x2, x3)
per_slice = [
    jvp_of_vjp(x0, x1.select(-1, i), x2.select(-1, i), x3.select(-1, i))[1]
    for i in range(2)
]
expected = torch.stack(per_slice)

# The reported bug: this assert fails — vmap's result silently disagrees
# with the unbatched reference.
assert torch.allclose(result, expected)
zou3519 commented 1 year ago

Even smaller repro for max_pool2d:

import torch
import torch.nn.functional as F
from torch.utils._pytree import tree_flatten
from functorch import jvp, vjp, vmap

# Inputs for the aten-level repro; here the batch dim is dim 0 of x1/x2/x3.
# Keep the RNG calls in this exact order so the values match the original.
torch.manual_seed(0)

x0 = torch.empty_strided([1, 2, 3, 6], (36, 1, 12, 2)).uniform_().cuda()
x1 = torch.rand(2, 1, 2, 1, 3).cuda()
x2 = torch.rand(2, 1, 2, 3, 6).cuda()
x3 = torch.rand(2, 1, 2, 1, 3).cuda()

def trace(x0_1, x1_1, x2_1, x3_1):
    """aten-level trace of the failing vjp computation.

    Pools x0_1 to obtain the argmax indices, then routes the cotangent
    x3_1 back through those indices via the backward op (x2_1 supplies
    the input whose shape the gradient takes).  x1_1 is not used by the
    body; it is kept so the signature lines up with the vmap in_dims
    used by the caller.
    """
    _pooled, indices = torch.ops.aten.max_pool2d_with_indices.default(
        x0_1, [3, 3], [2, 2], [0, 0], [1, 1], True)
    grad_input = torch.ops.aten.max_pool2d_with_indices_backward.default(
        x3_1, x2_1, [3, 3], [2, 2], [0, 0], [1, 1], True, indices)
    return grad_input

# Run the trace batched via vmap, then slice-by-slice as the reference.
result = vmap(trace, (None, 0, 0, 0))(x0, x1, x2, x3)
reference = [trace(x0, x1[i], x2[i], x3[i]) for i in range(2)]
expected = torch.stack(reference)

# The reported bug: the vmap'd result does not match the reference.
assert torch.allclose(result, expected)
zou3519 commented 1 year ago

The Householder product problem is just that the tolerance needs to be adjusted, so I am closing this.