Open wanchaol opened 4 years ago
I tried

```python
traced_gpu = torch.jit.trace(myModule, (input)).to("cuda:0")
print(traced_gpu(input_gpu))
```

This also doesn't work (same error).
We should do `traced_gpu = torch.jit.trace(myModule, (input.to("cuda:0")))` instead; that should work. Calling `.to("cuda:0")` on an already-traced module only moves its params/buffers to CUDA, not the device information recorded in the trace.
The example above also seems to trigger a bug in the runtime where `prim::device` doesn't return the correct device at run time; we still need to figure out the exact cause.
Tracing only supports the device that was specialized from the traced inputs. In production, many models are trained on GPU while inference happens on CPU.
Taking a simple example:
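A minimal sketch of such an example, assuming a module whose `forward` depends on the input's device (the module and variable names here are illustrative, not the original snippet):

```python
import torch

class MyModule(torch.nn.Module):
    def forward(self, x):
        # x.device is evaluated eagerly in Python during tracing, so the
        # device of the example input gets baked into the graph as a constant
        return x + torch.ones(x.shape, device=x.device)

# Trace with a CPU input: the graph now hard-codes device=cpu
traced = torch.jit.trace(MyModule(), torch.rand(3))

traced(torch.rand(3))                       # works: matches the traced device
if torch.cuda.is_available():
    traced(torch.rand(3, device="cuda:0"))  # fails: graph still creates the
                                            # ones tensor on cpu
```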
This will crash with the following error message:
This is because we specialize the input `TensorType` with the device of the input used during tracing. In the graph for the above program, we can see that `%4` is specialized to the CPU device, because that's the device we saw when tracing. This fails to generalize to GPU since the device information is hard-coded.

The current way to allow training on GPU and inference on CPU is to trace the model twice, with inputs on different devices; see https://pytorch.org/docs/stable/jit.html#frequently-asked-questions
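Concretely, the twice-tracing workaround looks something like this (a sketch; module and variable names are illustrative):

```python
import torch

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + torch.ones(x.shape, device=x.device)

m = MyModule()

# One trace per device the model will run on, each with a matching input:
traced_cpu = torch.jit.trace(m, torch.rand(3))
if torch.cuda.is_available():
    traced_gpu = torch.jit.trace(m, torch.rand(3, device="cuda:0"))

# Each traced module must then be served only on the device it was traced for.
```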
But in reality this might not be feasible; it has several constraints, to name a few:
I think we need a better way of handling the GPU-training/CPU-inference scenario for tracing, as it's one of the most common use cases.
Proposal:
Remove specialization during tracing: tracing only adds plain nodes, and it is up to the executor to specialize on the first run of the trace.
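A toy illustration of the proposal (not actual PyTorch internals): the trace stores only a plain, device-agnostic graph, and the executor specializes it the first time it runs with concrete inputs on a given device.

```python
from dataclasses import dataclass

@dataclass
class FakeTensor:        # stand-in for a runtime tensor
    device: str

class PlainTrace:
    """Device-agnostic trace: the device node reads the runtime input."""
    def device_of(self, x):
        return x.device  # resolved at run time, never baked into the trace

class Executor:
    def __init__(self, trace):
        self.trace = trace
        self.plans = {}  # device -> specialized execution plan

    def run(self, x):
        dev = self.trace.device_of(x)
        if dev not in self.plans:
            # specialize lazily, on the first run for this device
            self.plans[dev] = f"specialized-for-{dev}"
        return self.plans[dev]

ex = Executor(PlainTrace())
ex.run(FakeTensor("cpu"))     # builds a cpu-specialized plan
ex.run(FakeTensor("cuda:0"))  # builds a second, cuda-specialized plan
```

The same trace serves both devices; specialization is a property of the executor's cache, not of the recorded graph.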
cc @suo @gmagogsfm