
[jit] need a better way to handle mixed CPU/GPU (inference/training) for tracing #43134

Open wanchaol opened 4 years ago

wanchaol commented 4 years ago

Tracing only supports the device specialized from the traced inputs. In production, model training often happens on GPU while inference happens on CPU.

Taking a simple example:

import torch
import torch.nn as nn

class MyScriptModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        # Create a tensor on the same device as the input.
        blob = torch.zeros(input.shape[0], device=input.device)
        print("input: ", input.device)
        print("blob: ", blob.device)
        return blob + input

class MyTraceModule(nn.Module):
    def __init__(self):
        super().__init__()
        # A scripted submodule embedded in a module that will be traced.
        self.mod = torch.jit.script(MyScriptModule())

    def forward(self, input):
        return self.mod(input)

input = torch.zeros(10)
input_gpu = input.to("cuda:0")

# Trace with a CPU input, then call the trace with a GPU input.
myModule = MyTraceModule()
traced = torch.jit.trace(myModule, (input,))
print(traced(input_gpu))

This will crash with the following error message:

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/private/home/wanchaol/test_device.py", line 12, in forward
        print("input: ", input.device)
        print("blob: ", blob.device)
        return blob + input
               ~~~~~~~~~~~~ <--- HERE
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

This is because tracing specializes the input TensorType with the device of the example input. The graph for the program above looks like this:

graph(%self : __torch__.MyTraceModule,
      %4 : Float(10:1, requires_grad=0, device=cpu)):
  %6 : int = prim::Constant[value=0]() # /private/home/wanchaol/test_device.py:9:39
  %7 : None = prim::Constant()
  %8 : str = prim::Constant[value="input: "]() # /private/home/wanchaol/test_device.py:10:14
  %9 : str = prim::Constant[value="blob: "]() # /private/home/wanchaol/test_device.py:11:14
  %10 : int = prim::Constant[value=1]()
  %11 : int[] = aten::size(%4) # <string>:7:9
  %12 : int = aten::__getitem__(%11, %6) # /private/home/wanchaol/test_device.py:9:27
  %13 : Device = prim::device(%4)
  %14 : int[] = prim::ListConstruct(%12)
  %blob.1 : Tensor = aten::zeros(%14, %7, %7, %13, %7) # /private/home/wanchaol/test_device.py:9:15
  %16 : Device = prim::device(%4)
   = prim::Print(%8, %16) # /private/home/wanchaol/test_device.py:10:8
  %17 : Device = prim::device(%blob.1)
   = prim::Print(%9, %17) # /private/home/wanchaol/test_device.py:11:8
  %18 : Tensor = aten::add(%blob.1, %4, %10) # /private/home/wanchaol/test_device.py:12:15
  return (%18)

We can see that %4 is specialized to the CPU device because that is what was passed when we traced. The trace therefore fails to generalize to GPU, since the device information is hard-coded into the input type.

Our current way to allow training on GPU and inference on CPU is to trace the model twice, with inputs on different devices; see https://pytorch.org/docs/stable/jit.html#frequently-asked-questions
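
For reference, a minimal sketch of that workaround, reusing MyTraceModule from the example above (this assumes a CUDA device is available at trace time):

import torch

# Trace once per target device; each trace bakes in the device of its
# example input.
myModule = MyTraceModule()
traced_cpu = torch.jit.trace(myModule, (torch.zeros(10),))
traced_gpu = torch.jit.trace(myModule, (torch.zeros(10, device="cuda:0"),))

# At call time, dispatch to the trace that matches the input's device.
print(traced_cpu(torch.zeros(10)))
print(traced_gpu(torch.zeros(10, device="cuda:0")))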

But in reality this might not be feasible, as tracing twice comes with several constraints.

I think we need a better way of handling the GPU-training/CPU-inference scenario for tracing, as it's one of the most common use cases.

Proposal:

Remove specialization during tracing: tracing would only add plain (unspecialized) nodes, and it would be up to the executor to specialize on the first run of the trace.
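
For comparison, this is roughly how scripting already behaves: torch.jit.script does not specialize inputs to a device, so a single scripted module runs on either device. A minimal sketch, reusing MyScriptModule from above:

import torch

scripted = torch.jit.script(MyScriptModule())

# prim::device(%input) stays symbolic in the scripted graph, so the
# zeros tensor is created on whatever device the input lives on.
print(scripted(torch.zeros(10)))                   # CPU
print(scripted(torch.zeros(10, device="cuda:0")))  # GPU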

cc @suo @gmagogsfm

gqchen commented 4 years ago

I tried

traced_gpu = torch.jit.trace(myModule, (input,)).to("cuda:0")
print(traced_gpu(input_gpu))

This also doesn't work (same error).

wanchaol commented 4 years ago

I tried

traced_gpu = torch.jit.trace(myModule, (input,)).to("cuda:0")
print(traced_gpu(input_gpu))

This also doesn't work (same error).

We should do traced_gpu = torch.jit.trace(myModule, (input.to("cuda:0"),)) instead; that should work. Calling .to("cuda:0") on an already-traced model only moves its params/buffers to CUDA; it does not change the device baked into the trace.
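
In full, a minimal sketch of that working GPU path, using the variables from the original example:

# Trace with a GPU example input so the trace itself is specialized
# for CUDA, then call it with GPU inputs.
traced_gpu = torch.jit.trace(myModule, (input.to("cuda:0"),))
print(traced_gpu(input_gpu))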

The problem in the example above also seems to trigger a bug in the runtime, where prim::device doesn't return the correct device at runtime; we need to figure out the exact cause.