In 2.0 you can access the APIs from both torch.func and functorch. This will be true for the foreseeable future (i.e., we will preserve BC for a few releases of PyTorch). However, there will be differences between the torch.func.* and functorch.* APIs.
In general, we're deprecating the functorch.* APIs in favor of the torch.func.* APIs. As a part of this deprecation, we're moving away from functorch.make_functional and consolidating on PyTorch's NN stateless API. More details over at https://github.com/pytorch/pytorch/pull/91811
@zou3519 Thanks for the comment.
One more question about the memory usage of PyTorch's NN stateless API. We will prefer torch.func.functional_call over functorch.make_functional in the future:
import functorch
import torch
model = ... # build NN module
# functional_call
params_and_buffers_dict = ... # extract parameters or user-defined tensor dicts
output = torch.func.functional_call(model, params_and_buffers_dict, args=args, kwargs=kwargs)
# make_functional
fmodel, params, buffers = functorch.make_functional_with_buffers(model)
output = fmodel(params, buffers, *args, **kwargs)
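For concreteness, the params_and_buffers_dict in the functional_call branch can be built from the module's own named tensors. A minimal sketch, with a toy nn.Linear standing in for the real module (user-defined or ensembled tensor dicts work the same way):

import torch

model = torch.nn.Linear(4, 4)  # toy stand-in for the real module
params_and_buffers_dict = {
    **dict(model.named_parameters()),
    **dict(model.named_buffers()),
}
output = torch.func.functional_call(model, params_and_buffers_dict, args=(torch.randn(2, 4),))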
In functorch.make_functional{,_with_buffers}, the copied stateless module converts its tensors to the meta device, which does not hold data storage. This makes the fmodel use significantly less memory than the original module. Now, the nn.utils.stateless.functional_call API requires the user to pass a full model and a new copy of the parameters. That is twice the memory usage. This memory problem may be exacerbated when multi-process communication is required.
For example, stateless functional call over RPC:
import torch
import torch.distributed.rpc as rpc
model = ... # build NN module
# functional_call
params_and_buffers_dict = ... # extract parameters or user-defined tensor dicts
output = rpc.rpc_sync(
    'worker1',
    torch.func.functional_call,
    args=(
        model,  # the original parameters also need communication
        params_and_buffers_dict,
    ),
    kwargs=dict(args=args, kwargs=kwargs),
)
# make_functional
fmodel, params, buffers = functorch.make_functional_with_buffers(model)
output = rpc.rpc_sync(
    'worker1',
    fmodel,  # small serialization and communication overhead
    args=(params, buffers, *args),
    kwargs=kwargs,
)
To reduce the communication cost, users need to explicitly convert the tensors to the meta device before making the stateless functional call on remote workers:
import torch
import torch.distributed.rpc as rpc
model = ... # build NN module
# functional_call
params_and_buffers_dict = ... # extract parameters or user-defined tensor dicts
output = rpc.rpc_sync(
    'worker1',
    torch.func.functional_call,
    args=(
        model.to('meta'),  # convert to meta device
        params_and_buffers_dict,
    ),
    kwargs=dict(args=args, kwargs=kwargs),
)
Yes, to avoid using twice the amount of memory, users need to explicitly convert tensors to the meta device. Concretely, it depends on where the user wants to store their parameters: the parameters left on the module itself can be converted to the meta device, since those are not going to be used.

Does that alleviate the concern? If so, I'll update the migration guide to reflect this -- thank you for your feedback.
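A minimal sketch of that pattern (toy nn.Linear assumed; the deepcopy keeps the real tensors in the dict untouched while the module copy becomes a storage-free meta skeleton):

import copy
import torch

model = torch.nn.Linear(8, 8)  # toy stand-in for the real module

# the only "real" copy of the state lives in the dict
params_and_buffers_dict = {
    **dict(model.named_parameters()),
    **dict(model.named_buffers()),
}

# the module itself is reduced to a storage-free skeleton on the meta device
meta_model = copy.deepcopy(model).to('meta')

output = torch.func.functional_call(
    meta_model,
    params_and_buffers_dict,
    args=(torch.randn(2, 8),),
)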
Thanks for the comments. Now I have no more questions about migration.
In PyTorch pre-1.13, functorch is a separate PyPI package; users need to manually install the package to make the module available.
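Presumably that step was the usual install-and-import pair (a sketch, assuming installation from PyPI via pip):

# $ pip install functorch    # standalone package on PyPI (pre-1.13)
import functorch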
Now, as of PyTorch 1.13.x, the functorch module is available out of the box when the user installs torch >= 1.13.
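Since 1.13 the module ships inside the torch wheel, so installing torch alone is enough (a sketch, assuming pip):

# $ pip install "torch>=1.13"    # functorch is bundled with the torch wheel
import functorch                  # works without installing the separate package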
The source code of functorch has also been migrated into the pytorch/pytorch repo under pytorch/pytorch/torch/_functorch@master.
I found that in the master branch of the pytorch/pytorch repo, a new subpackage torch.func has been added, and some of the eager transformations in functorch are populated there. If the user installs the develop version of torch, then both the functorch and torch.func namespaces are available.
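For example, on a development build both namespaces resolve (a minimal check; grad is used here just as one transform that exists in both):

import functorch
import torch.func

print(functorch.grad)   # legacy namespace
print(torch.func.grad)  # new namespace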
However, some of the functions in functorch are not populated in torch.func yet, so they are not currently mutually interchangeable.
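One concrete example from this thread: the make_functional family exists in functorch but has no direct counterpart in torch.func (a minimal check; torch.func.functional_call is the intended replacement):

import functorch
import torch.func

print(hasattr(functorch, 'make_functional_with_buffers'))   # True
print(hasattr(torch.func, 'make_functional_with_buffers'))  # False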
I wonder what the packaging policy will be for functorch and torch.func in the next release of PyTorch (torch 2.0, or 1.14). Would the user have a single import statement, import torch, to use all methods in functorch as torch.func?