Closed zdevito closed 7 years ago
Good point about the dual fuser. We could natively track function -> derivative edges in the tracer (like we do with Handle edges, but now for all nodes), and slice off the backward fusion groups that way. On the other hand, backward fusion groups could also include simple maps appearing in neighboring ops.
Some numbers now that the fusion in the backward pass works (caveat: because of #230 I cannot verify correctness, though it doesn't crash immediately...):
This still needs to be rebased onto the jit branch before it is ready to merge. I was waiting because of the hypothesis that there were pybind issues. I'll see what happens tomorrow.
Force pushed rebase
Some timings for the kernels themselves. These are not precisely the same because the boundaries are not precisely the same, but they are pretty close:
```
HAND FUSED:
Time(%)  Time      Calls  Avg       Min       Max       Name
1.56%    187.63ms  22330  8.4020us  7.7760us  15.521us  void THNN_CudaLSTMForward<float, unsigned int, int=-2>(TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, unsigned int, unsigned int)
1.08%    129.24ms  22322  5.7900us  5.3760us  13.472us  void THNN_CudaLSTMBackward<float, unsigned int, int=-2>(TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, unsigned int, unsigned int)

OURS:
Time(%)  Time      Calls  Avg       Min       Max       Name
1.16%    204.53ms  33460  6.1120us  5.6000us  15.456us  kernel_0  // roughly equivalent to forward
0.92%    161.96ms  31548  5.1330us  4.8320us  16.992us  kernel_5  // roughly equivalent to backward
```
Also, I don't feel like committing that graph visualizer into PyTorch 😕 Can we just have a mode that emits JSON and then a separate script (you can put it in a gist) that has all HTML strings and can produce the visualization?
I'm happy to hide the visualizer more (it shouldn't be in `onnx.py`), but I'd like it to be in the code so it is easily accessible to me. I needed it to figure out how to make our code go fast, and I will need it in the future.
Is there a problem with doing what I suggested? Let's just have a JSON printer, and then we can develop a whole toolkit for visualizing/inspecting this data on the side.
I don't want to have to dig up code that isn't in the repo just because I want to write a file to disk to debug the IR. I guess I don't see a serious cost to keeping this code in the repo, but I do see a serious one to leaving it out. Longer-term we can invest in better debugging, but right now that doesn't exist and I need something to use.
I really think this code doesn't belong in the core repo, and I can't see any problem with keeping it on the side. It's not like you have a high overhead for using it. Just paste it into the place where you want to use it (probably the word language model) and you're done.
Well, I can't put it in the application, because I can't get to the code that dumps the IR between optimization passes. It is inside the Traceable class. I also want the benefits of source control for debugging code, so that I don't have to dig up the right version of it every time I want to do something with it.
The visualizer code can go into torch.utils. Adam, it's pretty dumb to have to dig up this code every time you have to visualize JIT traces (which we'll be doing pretty often).
Having the core `jit.py` code only have a JSON dumper, with the utils code (or the HTML directly) consuming that JSON, seems cleaner. Though at this particular point, speed of development is more important than stability, so don't put huge emphasis on such decisions.
If the visualizer code can go into torch.contrib.utils, that'd be nicer than torch.utils. We should actively create and start using contrib for such things.
I looked into what the fuser does with the backward pass and have a plan to fix it. This PR so far just adds code to visualize the graph and demonstrates how the problem can be fixed. I will later add an update with actual code support.
For the word language model the full graph looks like this: https://rawgit.com/zdevito/3d619ef61f698815fe80525f0ad42e97/raw/8430fab99cbfedaba6e0fb43996db75b2a82e69d/before.html
The forward pass cell, FusionGroup_0, is fully fused but the reverse is still in 4 separate fusion groups (6,7,8,9). This is a byproduct of the heuristic that tries to fuse many producers into a single consumer. In the backward case, Concat is not a simple map and hence is not fusable. Because of this, there is not a single output of the backward pass LSTM cell, but rather 4 outputs each starting a new 'seed' fusion group which cannot merge with the other fusion groups.
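The seeding behavior can be illustrated with a runnable toy version of the heuristic (all names, and the single-consumer simplification, are mine for illustration, not the actual fuser code): four simple-map outputs all feeding one non-fusable Concat each end up seeding their own group.

```python
# Toy model of "fuse producers into a single consumer": scan in reverse,
# fusing each simple-map node into the group of its (simple-map) consumer.
# A non-simple-map node like Concat never joins a group, so each of its
# inputs starts a fresh 'seed' group instead.
def fuse(nodes, consumers, simple_map):
    """nodes: topological order; consumers[n]: the node consuming n (toy: one)."""
    groups = {}           # node -> fusion group id
    next_id = 0
    for n in reversed(nodes):
        if not simple_map[n]:
            continue                      # Concat cannot be fused at all here
        c = consumers.get(n)
        if c is not None and c in groups and simple_map[c]:
            groups[n] = groups[c]         # fuse producer into consumer's group
        else:
            groups[n] = next_id           # start a new 'seed' fusion group
            next_id += 1
    return groups

# Four simple maps feeding one Concat, like the 4 backward outputs above:
nodes = ["a", "b", "c", "d", "cat"]
consumers = {"a": "cat", "b": "cat", "c": "cat", "d": "cat"}
simple_map = {"a": True, "b": True, "c": True, "d": True, "cat": False}
print(fuse(nodes, consumers, simple_map))  # four distinct group ids
```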
One approach would be to have better handling for merging adjacent fusion groups together. This can get tricky - for instance, 7 can merge with 9, but only if you observe that 8 can go after 9. We should do this eventually, but we don't need to do it now.
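The legality check such merging would need can be sketched as follows, under the simplifying assumption that merging two groups is illegal exactly when an outside node lies on a dependency path between them (the graph encoding and function names are illustrative, not PR code):

```python
# Merging group a into group b is only legal if no third node both depends
# on a and feeds b -- otherwise the merged group would create a cycle.
# Here 8 does not feed 9, so 8 can be scheduled after 9 and 7+9 can merge.
def reachable(graph, src, dst):
    """True if dst is reachable from src following dependency edges."""
    stack, seen = [src], set()
    while stack:
        n = stack.pop()
        if n == dst:
            return True
        if n not in seen:
            seen.add(n)
            stack.extend(graph.get(n, []))
    return False

def can_merge(graph, a, b):
    # Illegal iff some outside node n sits on a path a -> n -> b.
    return not any(
        reachable(graph, a, n) and reachable(graph, n, b)
        for n in graph if n not in (a, b))

graph = {7: [8, 9], 8: [], 9: []}   # 7 feeds both 8 and 9; 8 and 9 independent
print(can_merge(graph, 7, 9))       # True: 8 can be moved after 9
```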
The approach I want to go with is to allow Concat nodes to be an exit node of a FusionGroup. This fixes the issue above. Unlike simply merging the fusion groups, it also makes sure the Concat is not done in a separate pass (which would add kernel launches and use more memory).
If we allow this then the trace is what we want:
https://rawgit.com/zdevito/104ea16a7234e5688fc62b87cc4da711/raw/4e330d888dad4fef8ae949ffe6e3856dd5ba3faf/after.html
It is valid to fuse a Concat into a group as long as the output of concat (which is no longer the simple map size) is not used in the group, and each individual element being concatenated is the same size (which will be true in this case). The implementation strategy is pretty easy as well: allocate the Concat output before the fusion group runs. Narrow the tensors that form the body of the Concat and then pass those into the fusion kernel as normal outputs.
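A hedged sketch of that allocate-then-narrow strategy in Python, using `torch.narrow` views into a preallocated output; `run_group_with_fused_concat` and the toy group body are hypothetical names standing in for the real fusion machinery:

```python
import torch

def run_group_with_fused_concat(group_body, chunk_sizes, dim=0):
    # Allocate the Concat's output before the fusion group runs.
    out = torch.empty(sum(chunk_sizes))
    # Narrowed views into that buffer stand in for the tensors that would
    # have formed the inputs of the Concat node.
    views, offset = [], 0
    for size in chunk_sizes:
        views.append(out.narrow(dim, offset, size))
        offset += size
    # The fused kernel writes each result directly into its view, so no
    # separate concatenation pass (and no extra kernel launch) is needed.
    group_body(views)
    return out

# Toy "fusion group body": each output chunk is a simple elementwise map.
inputs = [torch.randn(3) for _ in range(4)]

def body(views):
    for v, x in zip(views, inputs):
        v.copy_(torch.tanh(x))

fused = run_group_with_fused_concat(body, [3, 3, 3, 3])
reference = torch.cat([torch.tanh(x) for x in inputs])
assert torch.equal(fused, reference)
```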
Finally, a thought about fusions: if we have a valid fusion in the forward pass, then there is always a corresponding fusion for the backward. The gradient of a simple map is still a simple map. This suggests that if we find a forward fusion we like, we should be able to find the fusion for its gradient even without adding new fusion heuristics. Or equivalently, there exists a dual of our fusion engine that works by fusing consumers into producers, scanning in the opposite direction from our current pass.
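A quick runnable check of the claim that the gradient of a simple map is itself a simple map, on a small elementwise expression (the expression and variable names are just an illustration):

```python
import torch

x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=True)
out = torch.tanh(x) * torch.sigmoid(y)   # an elementwise "simple map"
grad_out = torch.ones(4)

gx, gy = torch.autograd.grad(out, (x, y), grad_out)

# The backward is again elementwise in x, y, and grad_out:
#   d(out)/dx = (1 - tanh(x)^2) * sigmoid(y)
expected_gx = grad_out * (1 - torch.tanh(x) ** 2) * torch.sigmoid(y)
assert torch.allclose(gx, expected_gx)
```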