pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

grad+jvp causes INTERNAL ASSERT FAILED #939

Closed. cyyever closed this issue 1 year ago

cyyever commented 1 year ago

As the title says, I tried to combine grad and jvp, and PyTorch said:

catch exception:(new_grad_base.is_floating_point() || new_grad_base.is_complex()) && (self_base.is_floating_point() || self_base.is_complex()) INTERNAL ASSERT FAILED at "../torch/csrc/autograd/autograd_meta.cpp":132, please report a bug to PyTorch. Expected both tensor and its forward grad to be floating point or complex, but tensor is UNKNOWN_SCALAR and grad is Float

I compiled functorch and PyTorch from master yesterday, so this should be an unresolved issue. I managed to come up with a minimal script:

import collections

import torch
import torch.nn as nn
import torch.nn.functional as F

def cat_tensors_to_vector(tensors: list) -> torch.Tensor:
    return nn.utils.parameters_to_vector([t.reshape(-1) for t in tensors])

from functorch import grad, jacfwd, jacrev, jvp, make_functional

class LeNet5(nn.Module):
    """
    Input - 1x32x32
    C1 - 6@28x28 (5x5 kernel)
    tanh
    S2 - 6@14x14 (2x2 kernel, stride 2) Subsampling
    C3 - 16@10x10 (5x5 kernel, complicated connection scheme)
    tanh
    S4 - 16@5x5 (2x2 kernel, stride 2) Subsampling
    C5 - 120@1x1 (5x5 kernel)
    F6 - 84
    tanh
    F7 - 10 (Output)
    """

    input_size = (32, 32)

    def __init__(self, input_channels=1):
        super().__init__()
        self.input_channels = input_channels
        self.convnet = nn.Sequential(
            collections.OrderedDict(
                [
                    ("c1", nn.Conv2d(self.input_channels, 6, kernel_size=5)),
                    ("relu1", nn.ReLU()),
                    ("s2", nn.MaxPool2d(kernel_size=2, stride=2)),
                    ("c3", nn.Conv2d(6, 16, kernel_size=5)),
                    ("relu3", nn.ReLU()),
                    ("s4", nn.MaxPool2d(kernel_size=2, stride=2)),
                    ("c5", nn.Conv2d(16, 120, kernel_size=5)),
                    ("relu5", nn.ReLU()),
                ]
            )
        )

        self.fc = nn.Sequential(
            collections.OrderedDict(
                [
                    ("f6", nn.Linear(120, 84)),
                    ("relu6", nn.ReLU()),
                    ("f7", nn.Linear(84, 10)),
                ]
            )
        )

    def forward(self, x):
        x = x.reshape(-1, 1, 32, 32)
        output = self.convnet(x)
        output = output.view(x.size(0), -1)
        output = self.fc(output)
        return output

model = LeNet5()
fun, param = make_functional(model)

def loss_fun(param, input_tensor):
    target = torch.LongTensor([1])
    return F.cross_entropy(fun(param, input_tensor), target)

def grad_f(input_tensor):
    return grad(loss_fun)(param, input_tensor)

input_tensor = torch.rand((1, 32, 32))
vector = torch.ones_like(input_tensor)
print(jvp(grad_f, (input_tensor,), (vector,)))

But it fails with:

RuntimeError: Cannot report itemsize of Tensor that doesn't have initialized dtype (e.g., caffe2::Tensor x(CPU), prior to calling mutable_data<T>() on x)
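
A possible interim workaround (a sketch, not part of the original report, and assuming only the forward-over-reverse jvp path hits the assert): because mixed second derivatives of a smooth scalar loss commute, the tangent that jvp(grad_f, (input_tensor,), (vector,)) would produce can also be computed reverse-over-reverse with two nested grad calls, avoiding forward-mode AD entirely.

# Sketch: reverse-over-reverse version of the mixed second derivative.
# `loss_fun`, `param`, `input_tensor`, and `vector` are the ones defined above.
def directional_input_grad(param, input_tensor):
    # gradient of the loss w.r.t. the input, contracted with the direction vector
    grad_x = grad(loss_fun, argnums=1)(param, input_tensor)
    return (grad_x * vector).sum()

# gradient of that scalar w.r.t. the parameters; by symmetry of mixed partials
# this equals the tangent output of jvp(grad_f, (input_tensor,), (vector,))
mixed = grad(directional_input_grad, argnums=0)(param, input_tensor)
print(mixed)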
zou3519 commented 1 year ago

Thanks for the bug report and the repro script. I was able to reproduce this and extracted a smaller repro. Something is wrong on our side; we'll look into it.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd.forward_ad as fwAD

conv_weight = torch.randn(6, 1, 30, 30)

def model(weights, x):
    conv_weight = weights
    x = F.conv2d(x, conv_weight)
    x = x.view(x.size(0), -1)
    return x

def loss_fun(param, input_tensor):
    target = torch.LongTensor([1])
    out = model(param, input_tensor)
    return F.log_softmax(out, dim=1).sum()

input_tensor = torch.rand((1, 1, 32, 32))
vector = torch.ones_like(input_tensor)

from functorch import grad, jacfwd, jacrev, jvp, make_functional

def grad_f(input_tensor):
    return grad(loss_fun)(conv_weight, input_tensor)

print(jvp(grad_f, (input_tensor,), (vector,)))
zou3519 commented 1 year ago

I have root-caused this to https://github.com/pytorch/pytorch/issues/81111

zou3519 commented 1 year ago

@cyyever this has been fixed in PyTorch and will be in the next release. If you want to use it earlier, please try a PyTorch nightly (and build functorch from source).
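
One way to confirm that a given environment actually contains the fix (a sketch, not from the thread): check the PyTorch version and re-run the smaller repro above; once the fix is present, the jvp call should return its (primal, tangent) pair instead of raising the internal assert.

import torch
print(torch.__version__)  # should report a nightly/dev build that includes the fix, or the next release

# with the definitions from the smaller repro above in scope:
print(jvp(grad_f, (input_tensor,), (vector,)))  # expected to succeed once the fix is in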

zou3519 commented 1 year ago

Closing because this has been resolved.