pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License
1.38k stars 102 forks source link

Default Partition: Backward module does not contain computation of the parameters' gradients. #1014

Open ConnollyLeon opened 1 year ago

ConnollyLeon commented 1 year ago

I am trying to use functorch to train a model in a more JAX-like way. I use aot_function to get a forward graph module and a backward graph module, but I found that the backward module does not contain the computation of the parameters' gradients.

After reading the source code, I think the below function erases the corresponding computation part, as the parameter gradient calculation is irrelevant to the output of the backward module.

https://github.com/pytorch/functorch/blob/6c3b57f3a3fd54a2f3e3db12c2059669112bed6c/functorch/_src/partitioners.py#L94

I think it would be better to offer an API that can capture these computations, which are essential when training a neural network. Do you plan to develop this in the future?

Here is the backward module of AlexNet that I generated. As you can see, it does not involve the computation of the gradients of the parameters.


import torch
from torch.nn import *
class alexnet_backward(torch.nn.Module):
    """Generated backward graph module from the report (AlexNet traced with aot_function).

    The weights appear as module state (three buffers and five parameters)
    rather than as inputs of ``forward``, and ``forward`` returns a single
    value: element 0 of the last ``convolution_backward`` call.
    """

    def __init__(self):
        """Allocate placeholder tensors for the captured weights and load their values."""
        super().__init__()
        # 2-D matrices captured as buffers; each is used as the right-hand
        # operand of one aten.mm call in forward().
        self.register_buffer('_tensor_constant3', torch.empty([1000, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant4', torch.empty([4096, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant5', torch.empty([4096, 9216], dtype=torch.float32))
        # 4-D tensors captured as parameters; each is passed as the third
        # argument of one aten.convolution_backward call in forward().
        self._param_constant8 = torch.nn.Parameter(torch.empty([256, 256, 3, 3], dtype=torch.float32))
        self._param_constant6 = torch.nn.Parameter(torch.empty([256, 384, 3, 3], dtype=torch.float32))
        self._param_constant4 = torch.nn.Parameter(torch.empty([384, 192, 3, 3], dtype=torch.float32))
        self._param_constant2 = torch.nn.Parameter(torch.empty([192, 64, 5, 5], dtype=torch.float32))
        self._param_constant0 = torch.nn.Parameter(torch.empty([64, 3, 11, 11], dtype=torch.float32))
        # The torch.empty placeholders above are overwritten from the saved state dict.
        self.load_state_dict(torch.load(r'alexnet_backward/state_dict.pt'))

    def forward(self, primals_1, relu_, getitem, relu__1, getitem_2, relu__2, relu__3, relu__4, getitem_4, div_, relu__5, div__1, relu__6, tangents_1):
        """Run the generated backward graph.

        Args:
            primals_1: passed as the input argument of the final
                convolution_backward call.
            relu_, getitem, relu__1, getitem_2, relu__2, relu__3, relu__4,
            getitem_4, div_, relu__5, div__1, relu__6: intermediate values
                consumed by the backward ops below (presumably saved by the
                forward graph; their producers are not shown here).
            tangents_1: incoming gradient; it is the left operand of the
                first aten.mm call.

        Returns:
            A one-element list holding element 0 of the last
            convolution_backward result. No other value is returned.
        """
        # Max-pool indices are recomputed from the saved activations; only
        # element 1 of each max_pool2d_with_indices result is kept.
        detach = torch.ops.aten.detach(relu_)
        max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices(relu_, [3, 3], [2, 2])
        getitem_1 = max_pool2d_with_indices[1];  max_pool2d_with_indices = None
        detach_1 = torch.ops.aten.detach(relu__1)
        max_pool2d_with_indices_1 = torch.ops.aten.max_pool2d_with_indices(relu__1, [3, 3], [2, 2])
        getitem_3 = max_pool2d_with_indices_1[1];  max_pool2d_with_indices_1 = None
        detach_2 = torch.ops.aten.detach(relu__2)
        detach_3 = torch.ops.aten.detach(relu__3)
        detach_4 = torch.ops.aten.detach(relu__4)
        max_pool2d_with_indices_2 = torch.ops.aten.max_pool2d_with_indices(relu__4, [3, 3], [2, 2])
        getitem_5 = max_pool2d_with_indices_2[1];  max_pool2d_with_indices_2 = None
        detach_5 = torch.ops.aten.detach(relu__5);  relu__5 = None
        detach_6 = torch.ops.aten.detach(relu__6);  relu__6 = None
        # Chain of mm -> threshold_backward (-> mul) steps. Each mm multiplies
        # the running gradient by one buffer matrix; there is no mm that
        # produces a gradient for the buffer matrices themselves.
        _tensor_constant3 = self._tensor_constant3
        mm = torch.ops.aten.mm(tangents_1, _tensor_constant3);  tangents_1 = _tensor_constant3 = None
        detach_7 = torch.ops.aten.detach(detach_6);  detach_6 = None
        threshold_backward = torch.ops.aten.threshold_backward(mm, detach_7, 0);  mm = detach_7 = None
        _tensor_constant4 = self._tensor_constant4
        mm_2 = torch.ops.aten.mm(threshold_backward, _tensor_constant4);  threshold_backward = _tensor_constant4 = None
        # NOTE(review): div__1 / div_ look like saved dropout scaling masks — confirm against the forward graph.
        mul_2 = torch.ops.aten.mul(mm_2, div__1);  mm_2 = div__1 = None
        detach_8 = torch.ops.aten.detach(detach_5);  detach_5 = None
        threshold_backward_1 = torch.ops.aten.threshold_backward(mul_2, detach_8, 0);  mul_2 = detach_8 = None
        _tensor_constant5 = self._tensor_constant5
        mm_4 = torch.ops.aten.mm(threshold_backward_1, _tensor_constant5);  threshold_backward_1 = _tensor_constant5 = None
        mul_3 = torch.ops.aten.mul(mm_4, div_);  mm_4 = div_ = None
        # Reshape the flat gradient to [1, 256, 6, 6] before the pooling backward ops.
        view_4 = torch.ops.aten.view(mul_3, [1, 256, 6, 6]);  mul_3 = None
        _adaptive_avg_pool2d_backward = torch.ops.aten._adaptive_avg_pool2d_backward(view_4, getitem_4);  view_4 = getitem_4 = None
        max_pool2d_with_indices_backward = torch.ops.aten.max_pool2d_with_indices_backward(_adaptive_avg_pool2d_backward, relu__4, [3, 3], [2, 2], [0, 0], [1, 1], False, getitem_5);  _adaptive_avg_pool2d_backward = relu__4 = getitem_5 = None
        detach_9 = torch.ops.aten.detach(detach_4);  detach_4 = None
        threshold_backward_2 = torch.ops.aten.threshold_backward(max_pool2d_with_indices_backward, detach_9, 0);  max_pool2d_with_indices_backward = detach_9 = None
        # Every convolution_backward below is called with the mask
        # [True, True, True], but only element 0 of its result is read before
        # the result is dropped. Presumably elements 1 and 2 are the weight
        # and bias gradients (confirm against the aten schema); they are
        # never extracted or returned here.
        _param_constant8_2 = self._param_constant8
        convolution_backward = torch.ops.aten.convolution_backward(threshold_backward_2, relu__3, _param_constant8_2, [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]);  threshold_backward_2 = relu__3 = _param_constant8_2 = None
        getitem_6 = convolution_backward[0];  convolution_backward = None
        detach_10 = torch.ops.aten.detach(detach_3);  detach_3 = None
        threshold_backward_3 = torch.ops.aten.threshold_backward(getitem_6, detach_10, 0);  getitem_6 = detach_10 = None
        _param_constant6_2 = self._param_constant6
        convolution_backward_1 = torch.ops.aten.convolution_backward(threshold_backward_3, relu__2, _param_constant6_2, [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]);  threshold_backward_3 = relu__2 = _param_constant6_2 = None
        getitem_9 = convolution_backward_1[0];  convolution_backward_1 = None
        detach_11 = torch.ops.aten.detach(detach_2);  detach_2 = None
        threshold_backward_4 = torch.ops.aten.threshold_backward(getitem_9, detach_11, 0);  getitem_9 = detach_11 = None
        _param_constant4_2 = self._param_constant4
        convolution_backward_2 = torch.ops.aten.convolution_backward(threshold_backward_4, getitem_2, _param_constant4_2, [384], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]);  threshold_backward_4 = getitem_2 = _param_constant4_2 = None
        getitem_12 = convolution_backward_2[0];  convolution_backward_2 = None
        max_pool2d_with_indices_backward_1 = torch.ops.aten.max_pool2d_with_indices_backward(getitem_12, relu__1, [3, 3], [2, 2], [0, 0], [1, 1], False, getitem_3);  getitem_12 = relu__1 = getitem_3 = None
        detach_12 = torch.ops.aten.detach(detach_1);  detach_1 = None
        threshold_backward_5 = torch.ops.aten.threshold_backward(max_pool2d_with_indices_backward_1, detach_12, 0);  max_pool2d_with_indices_backward_1 = detach_12 = None
        _param_constant2_2 = self._param_constant2
        convolution_backward_3 = torch.ops.aten.convolution_backward(threshold_backward_5, getitem, _param_constant2_2, [192], [1, 1], [2, 2], [1, 1], False, [0, 0], 1, [True, True, True]);  threshold_backward_5 = getitem = _param_constant2_2 = None
        getitem_15 = convolution_backward_3[0];  convolution_backward_3 = None
        max_pool2d_with_indices_backward_2 = torch.ops.aten.max_pool2d_with_indices_backward(getitem_15, relu_, [3, 3], [2, 2], [0, 0], [1, 1], False, getitem_1);  getitem_15 = relu_ = getitem_1 = None
        detach_13 = torch.ops.aten.detach(detach);  detach = None
        threshold_backward_6 = torch.ops.aten.threshold_backward(max_pool2d_with_indices_backward_2, detach_13, 0);  max_pool2d_with_indices_backward_2 = detach_13 = None
        _param_constant0_2 = self._param_constant0
        convolution_backward_4 = torch.ops.aten.convolution_backward(threshold_backward_6, primals_1, _param_constant0_2, [64], [4, 4], [2, 2], [1, 1], False, [0, 0], 1, [True, True, True]);  threshold_backward_6 = primals_1 = _param_constant0_2 = None
        getitem_18 = convolution_backward_4[0];  convolution_backward_4 = None
        # Sole output of the graph: the gradient computed with primals_1 as the convolution input.
        return [getitem_18]

Another question: I also tried to print the code of joint_forward_backward. The joint_forward_backward GraphModule does contain the computation of the weight gradients, but I found that the parameters of the Linear layers are defined twice in the __init__ method, as you can see in the code below. _tensor_constant0 and _tensor_constant5 are supposed to be one and the same parameter, but they have different shapes: _tensor_constant5 is the transpose of _tensor_constant0. Yet two separate buffers are registered for it. Any suggestions for avoiding this?

class FxModule(torch.nn.Module):
    """Generated joint forward/backward graph module from the report.

    Only ``__init__`` was included in the report; the ``forward`` method is
    not shown.
    """

    def __init__(self):
        """Allocate placeholder tensors for the captured weights and load their values."""
        super().__init__()
        # Six 2-D buffers. The shapes of _tensor_constant0/1/2 are the
        # reverses of _tensor_constant5/4/3 respectively ([9216, 4096] vs
        # [4096, 9216], etc.); per the report, _tensor_constant5 is the
        # transpose of _tensor_constant0, registered as a separate buffer.
        self.register_buffer('_tensor_constant0', torch.empty([9216, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant1', torch.empty([4096, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant2', torch.empty([4096, 1000], dtype=torch.float32))
        self.register_buffer('_tensor_constant3', torch.empty([1000, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant4', torch.empty([4096, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant5', torch.empty([4096, 9216], dtype=torch.float32))
        # Even-numbered parameters 0-8 are 4-D tensors, each followed by a
        # 1-D tensor whose length equals the 4-D tensor's first dimension
        # (presumably conv weight / bias pairs — confirm against the model).
        self._param_constant0 = torch.nn.Parameter(torch.empty([64, 3, 11, 11], dtype=torch.float32))
        self._param_constant1 = torch.nn.Parameter(torch.empty([64], dtype=torch.float32))
        self._param_constant2 = torch.nn.Parameter(torch.empty([192, 64, 5, 5], dtype=torch.float32))
        self._param_constant3 = torch.nn.Parameter(torch.empty([192], dtype=torch.float32))
        self._param_constant4 = torch.nn.Parameter(torch.empty([384, 192, 3, 3], dtype=torch.float32))
        self._param_constant5 = torch.nn.Parameter(torch.empty([384], dtype=torch.float32))
        self._param_constant6 = torch.nn.Parameter(torch.empty([256, 384, 3, 3], dtype=torch.float32))
        self._param_constant7 = torch.nn.Parameter(torch.empty([256], dtype=torch.float32))
        self._param_constant8 = torch.nn.Parameter(torch.empty([256, 256, 3, 3], dtype=torch.float32))
        self._param_constant9 = torch.nn.Parameter(torch.empty([256], dtype=torch.float32))
        # 1-D parameters of length 4096, 4096 and 1000
        # (presumably the Linear-layer biases — confirm against the model).
        self._param_constant10 = torch.nn.Parameter(torch.empty([4096], dtype=torch.float32))
        self._param_constant11 = torch.nn.Parameter(torch.empty([4096], dtype=torch.float32))
        self._param_constant12 = torch.nn.Parameter(torch.empty([1000], dtype=torch.float32))
        # The torch.empty placeholders above are overwritten from the saved state dict.
        self.load_state_dict(torch.load(r'forward_backward/state_dict.pt'))
Chillee commented 1 year ago

@ConnollyLeon If you want to compile it, use aot_module, which will lift up the parameters to inputs of the function.

If you're just trying to accelerate it, you can use memory_efficient_fusion, which has some preconfigured settings that should work well for acceleration on CUDA.

ConnollyLeon commented 1 year ago

@Chillee Thanks for your reply. But why did the weight parameters of the Linear layers become tensor_constant buffers in the FxModule? Could you please help explain this?

Chillee commented 1 year ago

@ConnollyLeon If you trace with aot_function, then it'll only treat the inputs to that function as "changeable values", and it'll assume everything else is constant (including parameters!).