ConnollyLeon opened this issue 1 year ago
@ConnollyLeon If you want to compile it, use aot_module, which will lift the parameters up to inputs of the function. If you're just trying to accelerate it, you can use memory_efficient_fusion, which has some preconfigured settings that should work well for acceleration on CUDA.
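For reference, a minimal sketch of both options (my own illustration, not from this reply; it assumes the functorch.compile API of this functorch version and uses torchvision's AlexNet as a stand-in model):

```python
import torch
import torchvision
from functorch.compile import aot_module, memory_efficient_fusion

model = torchvision.models.alexnet().cuda()
x = torch.randn(8, 3, 224, 224, device="cuda")

# A trivial "compiler" callback: print the traced FX graph and return it unchanged.
def print_compile(fx_module, example_inputs):
    fx_module.graph.print_tabular()
    return fx_module

# Option 1: aot_module lifts the module's parameters to explicit graph inputs.
compiled_model = aot_module(model, fw_compiler=print_compile)
out = compiled_model(x)

# Option 2: memory_efficient_fusion wraps the module with preconfigured
# settings aimed at acceleration on CUDA.
fused_model = memory_efficient_fusion(model)
out = fused_model(x)
```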
@Chillee Thanks for your reply. But why do the weight parameters of the Linear layers end up as tensor_constant entries in the FX GraphModule? Could you please help explain this?
@ConnollyLeon If you trace with aot_function, then it'll only treat the inputs to that function as "changeable values", and it'll assume everything else is constant (including parameters!).
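A concrete sketch of that behavior (my own illustration, not from the original reply; it assumes functorch's make_functional and that aot_function accepts the tuple of parameters through its pytree input handling):

```python
import torch
from functorch import make_functional
from functorch.compile import aot_function

lin = torch.nn.Linear(4, 4)

def print_compile(fx_module, example_inputs):
    print(fx_module.code)  # inspect how the weights show up in the trace
    return fx_module

# Parameters reached through the closure are traced as constants
# (e.g. _tensor_constant0), not as graph inputs.
def fwd_closed(x):
    return lin(x)

aot_function(fwd_closed, fw_compiler=print_compile)(torch.randn(2, 4))

# Parameters passed explicitly become inputs of the traced graph,
# so the backward graph can also produce their gradients.
func_lin, params = make_functional(lin)

def fwd_functional(params, x):
    return func_lin(params, x)

aot_function(fwd_functional, fw_compiler=print_compile)(params, torch.randn(2, 4))
```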
I am trying to use functorch to train a model in a more JAX-like way. I use aot_function to get a forward graph module and a backward graph module, but I find that the backward module does not contain the computation of the parameters' gradients.
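For reference, this is roughly the setup I mean (a minimal sketch; the graph-saving compiler callbacks are my own illustration, not an official API):

```python
import torch
import torchvision
from functorch.compile import aot_function

fw_graphs, bw_graphs = [], []

def save_fw(fx_module, example_inputs):
    fw_graphs.append(fx_module)  # forward GraphModule
    return fx_module

def save_bw(fx_module, example_inputs):
    bw_graphs.append(fx_module)  # backward GraphModule
    return fx_module

model = torchvision.models.alexnet()
x = torch.randn(2, 3, 224, 224, requires_grad=True)

def fwd(x):
    return model(x)

compiled = aot_function(fwd, fw_compiler=save_fw, bw_compiler=save_bw)
compiled(x).sum().backward()  # run forward + backward so both graphs are traced

# The captured backward graph only computes the gradient of the input x;
# the parameters' gradients are missing because they were traced as constants.
print(bw_graphs[0].code)
```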
After reading the source code, I think the function below erases that part of the computation, since the parameter gradient calculation is irrelevant to the outputs of the backward module:
https://github.com/pytorch/functorch/blob/6c3b57f3a3fd54a2f3e3db12c2059669112bed6c/functorch/_src/partitioners.py#L94
I think it would be better to offer an API that can capture these important computations when training a neural network. Would you consider developing this in the future?
Here is the backward module of AlexNet that I generated. As you can see, it does not include the computation of the parameters' gradients.
Another question: I also tried to output the code of joint_forward_backward. The joint_forward_backward GraphModule does contain the computation of the weight gradients, but I find that the parameters of the Linear layer are redefined in the __init__ method. As you can see in the code below, _tensor_constant0 and _tensor_constant5 are supposed to be one and the same parameter, but they have different shapes: _tensor_constant5 is the transpose of _tensor_constant0. It seems that two buffers are registered for it. Any suggestion on how to avoid this?