pytorch / tensorpipe

A tensor-aware point-to-point communication primitive for machine learning

Meta device #444

Closed: pbelevich closed this issue 2 years ago

pbelevich commented 2 years ago

We want to support the meta device in the PyTorch RPC framework. One possible use case is creating a large model on the meta device for subsequent FX tracing:

  1. Create a large model on the meta device on the master node, which doesn't have enough memory to materialize the model.
  2. Trace it and split it into submodules using PiPPy.
  3. Send each submodule (still on the meta device) over RPC to the corresponding worker.
  4. Move those submodules to a CUDA device (materializing/initializing them, probably using torchdistx).
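The local parts of this workflow (steps 1, 2 and 4) can be sketched with public PyTorch APIs. This is a minimal sketch, assuming a recent PyTorch with factory `device=` keyword arguments and `nn.Module.to_empty`. A toy `nn.Sequential` stands in for the large model, and the helper names `build_on_meta` and `materialize` are hypothetical. PiPPy splitting, the RPC transfer in step 3 (the subject of this issue) and torchdistx are not used here; re-running each module's `reset_parameters()` is swapped in as a simple stand-in for torchdistx-style deferred initialization.

```python
import torch
import torch.fx
import torch.nn as nn


def build_on_meta() -> nn.Module:
    # Step 1 (hypothetical helper): construct the model directly on the
    # "meta" device. Parameters record shape and dtype but allocate no
    # storage, so a model too large for host memory can still be built.
    return nn.Sequential(
        nn.Linear(1024, 4096, device="meta"),
        nn.ReLU(),
        nn.Linear(4096, 1024, device="meta"),
    )


def materialize(module: nn.Module, device: torch.device) -> nn.Module:
    # Step 4 (hypothetical helper): allocate real but uninitialized storage
    # on the target device, then re-run each submodule's own initializer.
    # This stands in for torchdistx; it is not the torchdistx API.
    module = module.to_empty(device=device)
    for sub in module.modules():
        if hasattr(sub, "reset_parameters"):
            sub.reset_parameters()
    return module


model = build_on_meta()
assert all(p.is_meta for p in model.parameters())

# Step 2: FX symbolic tracing records operations on proxy values and never
# touches parameter storage, so it works on a meta-device model.
# Splitting the traced graph into submodules (e.g. with PiPPy) is not shown.
traced = torch.fx.symbolic_trace(model)
node_ops = [n.op for n in traced.graph.nodes]
print(node_ops)

# Step 3 (sending each meta submodule over RPC) is what this issue asks
# tensorpipe to support, so it is not shown here.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = materialize(model, device)
out = model(torch.randn(2, 1024, device=device))
print(out.shape)
```

In this sketch the weights are freshly initialized on the receiving side rather than copied, so only the module structure and tensor metadata would need to cross the wire in step 3.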

This is the backbone for https://github.com/pytorch/pytorch/pull/76882.