pytorch / functorch

functorch is a library of JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Training with functorch causes constant to be embedded in graph and returned #1046

Closed: dellis23 closed this issue 1 year ago

dellis23 commented 1 year ago

I'm working on a simple training example with functorch, and the graph generated by `make_fx` appears to embed the argument values passed in at trace time as constants and return those, rather than computing its outputs from the arguments the caller supplies.

Here's the repro:

import functorch
import torch

def forward(w, b, X):
    return torch.matmul(X, w) + b

def loss_fn(w, b, X, y):
    err = forward(w, b, X) - y
    return torch.mean(torch.square(err))

grad_fn = functorch.grad(loss_fn, argnums=(0, 1))

def update(w, b, grad_w, grad_b):
    updated_w = torch.tensor(list(map(
        lambda x: x[0] - 0.05 * x[1], zip(w, grad_w))))
    updated_b = torch.tensor(b - 0.05 * grad_b)
    return updated_w, updated_b

def train(w, b, X, y):
    grad_w, grad_b = grad_fn(w, b, X, y)
    return update(w, b, grad_w, grad_b)

def main():
    dummy_args = (torch.tensor([1.0, 2.0, 3.0]),
            torch.tensor(1.0),
            torch.tensor([[1.0, 1.0, 1.0]]),
            torch.tensor([1.0]))
    graph = functorch.make_fx(train)(*dummy_args)

    # Call with the same args passed to `make_fx`
    print(graph(*dummy_args))

    # Call again with new weights
    print(graph(torch.tensor([10.0, 10.0, 10.0]), *dummy_args[1:]))

    # Show the graph module
    print(graph)

if __name__ == "__main__":
    main()

And here's the output:

(tensor([0.4000, 1.4000, 2.4000]), tensor(0.4000))
(tensor([0.4000, 1.4000, 2.4000]), tensor(-2.))
train()

def forward(self, w_1, b_1, X_1, y_1):
    x_1 = X_1
    mv = torch.ops.aten.mv.default(x_1, w_1)
    add = torch.ops.aten.add.Tensor(mv, b_1);  mv = None
    sub = torch.ops.aten.sub.Tensor(add, y_1);  add = y_1 = None
    pow_1 = torch.ops.aten.pow.Tensor_Scalar(sub, 2)
    mean = torch.ops.aten.mean.default(pow_1);  pow_1 = None
    ones_like = torch.ops.aten.ones_like.default(mean, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False, memory_format = torch.preserve_format);  mean = None
    expand = torch.ops.aten.expand.default(ones_like, [1]);  ones_like = None
    div = torch.ops.aten.div.Scalar(expand, 1);  expand = None
    pow_2 = torch.ops.aten.pow.Tensor_Scalar(sub, 1.0);  sub = None
    mul = torch.ops.aten.mul.Scalar(pow_2, 2.0);  pow_2 = None
    mul_1 = torch.ops.aten.mul.Tensor(div, mul);  div = mul = None
    sum_1 = torch.ops.aten.sum.default(mul_1)
    t = torch.ops.aten.t.default(x_1);  x_1 = None
    mv_1 = torch.ops.aten.mv.default(t, mul_1);  t = mul_1 = None
    unbind = torch.ops.aten.unbind.int(w_1);  w_1 = None
    getitem = unbind[0]
    getitem_1 = unbind[1]
    getitem_2 = unbind[2];  unbind = None
    unbind_1 = torch.ops.aten.unbind.int(mv_1);  mv_1 = None
    getitem_3 = unbind_1[0]
    getitem_4 = unbind_1[1]
    getitem_5 = unbind_1[2];  unbind_1 = None
    mul_2 = torch.ops.aten.mul.Tensor(getitem_3, 0.05);  getitem_3 = None
    sub_1 = torch.ops.aten.sub.Tensor(getitem, mul_2);  getitem = mul_2 = None
    mul_3 = torch.ops.aten.mul.Tensor(getitem_4, 0.05);  getitem_4 = None
    sub_2 = torch.ops.aten.sub.Tensor(getitem_1, mul_3);  getitem_1 = mul_3 = None
    mul_4 = torch.ops.aten.mul.Tensor(getitem_5, 0.05);  getitem_5 = None
    sub_3 = torch.ops.aten.sub.Tensor(getitem_2, mul_4);  getitem_2 = mul_4 = None
    _tensor_constant0 = self._tensor_constant0
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
    mul_5 = torch.ops.aten.mul.Tensor(sum_1, 0.05);  sum_1 = None
    sub_4 = torch.ops.aten.sub.Tensor(b_1, mul_5);  b_1 = mul_5 = None
    detach = torch.ops.aten.detach.default(sub_4);  sub_4 = None
    detach_1 = torch.ops.aten.detach.default(detach);  detach = None
    _to_copy = torch.ops.aten._to_copy.default(detach_1, dtype = torch.float32, device = device(type='cpu'));  detach_1 = None
    return (lift_fresh_copy, _to_copy)

# To see more debug info, please use `graph_module.print_readable()`

Note that the same updated weights are returned regardless of what is passed in: the second call uses different values for `w`, yet the first element of the output is unchanged. Also note the embedded constant returned from the module's graph:

    _tensor_constant0 = self._tensor_constant0
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
    #...
    return (lift_fresh_copy, _to_copy)
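
If it helps confirm, the embedded constant can be read straight off the traced module (a hypothetical one-line addition to the repro above); its value matches the first element of both outputs:

print(graph._tensor_constant0)  # tensor([0.4000, 1.4000, 2.4000])
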
dellis23 commented 1 year ago

Got an answer on slack: https://pytorch.slack.com/archives/C02DVE5MAFR/p1665578913746529

For posterity, this seems to have been caused by my call to `torch.tensor` inside my `update` function: `torch.tensor` builds a fresh tensor from concrete values, so the tracer records its result as an embedded constant instead of as a computation on the graph's inputs.
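
A minimal sketch of the fix, assuming the goal is a plain SGD step: keep `update` as pure tensor arithmetic, so the tracer sees tensor ops instead of a round-trip through Python lists and `torch.tensor`:

def update(w, b, grad_w, grad_b):
    # Plain tensor ops stay in the traced graph, so nothing is baked
    # in as a constant and the outputs depend on the graph's inputs.
    updated_w = w - 0.05 * grad_w
    updated_b = b - 0.05 * grad_b
    return updated_w, updated_b

With this version, re-running the repro should print different updated weights for the second call, since the graph computes them from `w_1` instead of returning `self._tensor_constant0`.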