pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

[Question] Packaging policy for `functorch` and `torch.func` #1102

Closed · XuehaiPan closed this issue 1 year ago

XuehaiPan commented 1 year ago

Before PyTorch 1.13, functorch was a separate PyPI package that users needed to install manually:

pip3 install functorch

to make the module available:

import functorch

Now, as of PyTorch 1.13.x, the functorch module is available out of the box when the user installs torch >= 1.13:

pip3 install 'torch>=1.13'
import functorch

The source code of functorch has also been migrated into the pytorch/pytorch repo under pytorch/pytorch/torch/_functorch@master.


I found that in the master branch of the pytorch/pytorch repo, a new subpackage torch.func has been added, and some of the eager transformations from functorch are exposed there.

If the user installs the development version of torch:

pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu117

then both of the following are available:

import functorch
import torch

grad_fn = functorch.grad(stateless_objective)
grad_fn = torch.func.grad(stateless_objective)
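
As a concrete usage example (stateless_objective is left undefined above; the quadratic below is a hypothetical stand-in):

import torch

def stateless_objective(x):
    return (x ** 2).sum()

grad_fn = torch.func.grad(stateless_objective)
grad_fn(torch.tensor([1.0, 2.0, 3.0]))  # tensor([2., 4., 6.]), i.e. 2 * x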

However, some of the functions in functorch have not been ported to torch.func yet:

functorch.make_functional(model)   # OK
torch.func.make_functional(model)  # AttributeError: no such member

So the two namespaces are not currently interchangeable.

I wonder what the packaging policy for functorch and torch.func will be in the next release of PyTorch (torch) 2.0 (or 1.14). Will the user be able to use a single import statement, import torch, to access all of the functorch methods as torch.func?

zou3519 commented 1 year ago

In 2.0 you can access the APIs from both torch.func and functorch. This will be true for the foreseeable future (i.e., we will preserve BC for a few releases of PyTorch). However, there will be differences between the torch.func.* and functorch.* APIs.

In general, we're deprecating the functorch.* APIs in favor of the torch.func.* APIs. As part of this deprecation, we're moving away from functorch.make_functional and consolidating on PyTorch's NN stateless API. More details over at https://github.com/pytorch/pytorch/pull/91811
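
As a rough sketch of the consolidated style (the Linear model and loss below are illustrative placeholders, not taken from this thread):

import torch
from torch.func import functional_call, grad

model = torch.nn.Linear(3, 1)
params = dict(model.named_parameters())

# the params dict, not the module, carries the differentiable state
def compute_loss(params, x):
    return functional_call(model, params, (x,)).sum()

x = torch.randn(2, 3)
grads = grad(compute_loss)(params, x)  # dict of gradients, keyed like params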

XuehaiPan commented 1 year ago

In general, we're deprecating the functorch.* APIs in favor of the torch.func.* APIs.

@zou3519 Thanks for the comment.

As part of this deprecation, we're moving away from functorch.make_functional and consolidating on PyTorch's NN stateless API. More details over at https://github.com/pytorch/pytorch/pull/91811

One more question about the memory usage of PyTorch's NN stateless API.

We will prefer torch.func.functional_call over functorch.make_functional in the future:

import functorch
import torch

model = ...  # build NN module

# functional_call
params_and_buffers_dict = ...  # extract parameters or user-defined tensor dicts
output = torch.func.functional_call(model, params_and_buffers_dict, args=args, kwargs=kwargs)

# make_functional
fmodel, params, buffers = functorch.make_functional_with_buffers(model)
output = fmodel(params, buffers, *args, **kwargs)
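
For illustration, one way to build the elided params_and_buffers_dict is from the module's own named tensors (this particular construction is an assumption, not part of the original):

params_and_buffers_dict = {
    **dict(model.named_parameters()),  # shares storage with the module
    **dict(model.named_buffers()),
}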

In functorch.make_functional{,_with_buffers}, the copied stateless module converts its tensors to the meta device, which holds no data storage. This makes fmodel use significantly less memory than the original module. The nn.utils.stateless.functional_call API, by contrast, requires the user to pass the full model plus a new copy of the parameters, which is twice the memory usage. This memory problem may be exacerbated when multi-process communication is required.

For example, a stateless functional call over RPC:

import torch
import torch.distributed.rpc as rpc

model = ...  # build NN module

# functional_call
params_and_buffers_dict = ...  # extract parameters or user-defined tensor dicts
output = rpc.rpc_sync(
    'worker1',
    torch.func.functional_call,
    args=(
        model,  # the full model, parameters included, must be serialized and sent
        params_and_buffers_dict,
    ),
    kwargs=dict(args=args, kwargs=kwargs),
)

# make_functional
fmodel, params, buffers = functorch.make_functional_with_buffers(model)
output = rpc.rpc_sync(
    'worker1',
    fmodel,  # small serialization and communication overhead
    args=(params, buffers, *args),
    kwargs=kwargs,
)

To reduce the communication cost, users need to explicitly convert the tensors to the meta device before making a stateless functional call on a remote worker:

import torch
import torch.distributed.rpc as rpc

model = ...  # build NN module

# functional_call
params_and_buffers_dict = ...  # extract parameters or user-defined tensor dicts
output = rpc.rpc_sync(
    'worker1',
    torch.func.functional_call,
    args=(
        model.to('meta'),  # convert to meta device (note: Module.to converts in place)
        params_and_buffers_dict,
    ),
    kwargs=dict(args=args, kwargs=kwargs),
)

zou3519 commented 1 year ago

In functorch.make_functional{,_with_buffers}, the copied stateless module converts its tensors to the meta device, which holds no data storage. This makes fmodel use significantly less memory than the original module. The nn.utils.stateless.functional_call API, by contrast, requires the user to pass the full model plus a new copy of the parameters.

Yes, to avoid using twice the amount of memory, users need to explicitly convert tensors to the meta device.
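
A minimal sketch of one way to do that without wiping the live module's parameters (the deepcopy-then-convert pattern is illustrative, not prescribed by the thread):

import copy
import torch

model = torch.nn.Linear(1024, 1024)
params = dict(model.named_parameters())  # real tensors, shared with model

# deep-copy first: Module.to('meta') converts in place and would
# otherwise discard the live model's parameter storage
meta_model = copy.deepcopy(model).to('meta')

# functional_call substitutes the real tensors for the meta parameters
out = torch.func.functional_call(meta_model, params, (torch.randn(8, 1024),))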

zou3519 commented 1 year ago

Concretely, it depends on where the user wants to store their parameters:

- If the parameters live on the module, passing dict(model.named_parameters()) to functional_call adds no extra memory, because the dict shares storage with the module.
- If the parameters are stored separately from the module, the module itself can be converted to the meta device, so only one real copy of the parameters exists.

Does that alleviate the concern? If so, I'll update the migration guide to reflect this -- thank you for your feedback.

XuehaiPan commented 1 year ago

Thanks for the comments. Now I have no more questions about migration.