ezyang / pytorch-unattached

Tensors and Dynamic neural networks in Python with strong GPU acceleration
http://pytorch.org

Fuser Improvements #224

Closed · zdevito closed this 7 years ago

zdevito commented 7 years ago

I looked into what the fuser does with the backwards pass and have a plan to fix it. This PR so far just adds code to visualize the graph and demonstrates how to fix the problem; I will later update it with the actual code support.

For the word language model the full graph looks like this: https://rawgit.com/zdevito/3d619ef61f698815fe80525f0ad42e97/raw/8430fab99cbfedaba6e0fb43996db75b2a82e69d/before.html

The forward pass cell, FusionGroup_0, is fully fused but the reverse is still in 4 separate fusion groups (6,7,8,9). This is a byproduct of the heuristic that tries to fuse many producers into a single consumer. In the backward case, Concat is not a simple map and hence is not fusable. Because of this, there is not a single output of the backward pass LSTM cell, but rather 4 outputs each starting a new 'seed' fusion group which cannot merge with the other fusion groups.

One approach would be to have better handling for merging adjacent fusion groups together. This can get tricky - for instance, 7 can merge with 9, but only if you observe that 8 can go after 9. We should do this eventually, but we don't need to do it now.
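
A minimal sketch of what such a merge check could look like (all names here are hypothetical; `depends_on` is assumed to answer transitive data-dependence queries over the graph):

```python
# Hypothetical sketch: fusion groups a and b, with a earlier in topological
# order, can merge only if no node scheduled between them lies on a
# dependency path a -> n -> b; any other in-between node can simply be
# rescheduled after b.
def can_merge(a, b, nodes_between, depends_on):
    # depends_on(x, y): True if node x transitively consumes an output of y.
    return not any(depends_on(n, a) and depends_on(b, n)
                   for n in nodes_between)
```

In the graph above, merging 7 and 9 amounts to `can_merge(g7, g9, [g8], depends_on)`: it holds exactly when 8 does not sit on a dependency path between them and can therefore be moved after 9.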

The approach I want to go with is to allow Concat nodes to be an exit node of a FusionGroup. This fixes the issue above. Unlike simply merging the fusion groups, it also makes sure the Concat is not done in a separate pass (which would add kernel launches and use more memory).

If we allow this then the trace is what we want:

https://rawgit.com/zdevito/104ea16a7234e5688fc62b87cc4da711/raw/4e330d888dad4fef8ae949ffe6e3856dd5ba3faf/after.html

It is valid to fuse a Concat into a group as long as the output of the Concat (which no longer has the simple-map size) is not used inside the group, and each individual element being concatenated is the same size (which is true in this case). The implementation strategy is pretty easy as well: allocate the Concat output before the fusion group runs, narrow the tensors that form the body of the Concat, and then pass those narrowed views into the fusion kernel as normal outputs.
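
As a rough standalone illustration of that strategy (not the fuser code itself; `torch.sigmoid` stands in for the fused kernel body, and all chunks are assumed equal-sized as in the LSTM case):

```python
import torch

def concat_exit_sketch(xs, dim=0):
    n, size = len(xs), xs[0].size(dim)
    shape = list(xs[0].shape)
    shape[dim] *= n
    # Allocate the Concat output before the "fusion group" runs.
    out = xs[0].new_empty(shape)
    for i, x in enumerate(xs):
        # Narrow a view for chunk i and have the kernel write into it
        # directly, so no separate concat pass (or extra buffer) is needed.
        torch.sigmoid(x, out=out.narrow(dim, i * size, size))
    return out

xs = [torch.randn(4, 8) for _ in range(3)]
assert torch.allclose(concat_exit_sketch(xs), torch.sigmoid(torch.cat(xs)))
```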

Finally, a thought about fusions: if we have a valid fusion in the forward pass, then there is always a corresponding fusion for the backward, because the gradient of a simple map is still a simple map. This suggests that if we find a forward fusion we like, we should be able to find the fusion for its gradient even without adding new fusion heuristics. Or equivalently, there exists a dual of our fusion engine that works by fusing consumers into producers, scanning in the opposite direction from our current pass.
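
A quick way to see the "gradient of a simple map is a simple map" claim (a minimal autograd check using sigmoid as the map; modern-API sketch):

```python
import torch

# y = sigmoid(x) is an elementwise (simple) map; its backward,
# grad_x = grad_y * y * (1 - y), is again elementwise over the same shape,
# so a fusable forward region has a correspondingly fusable backward region.
x = torch.randn(5, requires_grad=True)
y = torch.sigmoid(x)
grad_y = torch.ones_like(y)
y.backward(grad_y)
assert torch.allclose(x.grad, grad_y * y * (1 - y))
```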

apaszke commented 7 years ago

Good point with the dual fuser. We could add tracking of function -> derivative edges to the tracer natively (like we do with Handle edges, but now for all nodes), and slice off the backward fusion groups in this way. On the other hand backward fusion groups could also include simple maps appearing in neighboring ops.

zdevito commented 7 years ago

Some numbers now that the fusion in the backward pass works (caveat: because of #230 I cannot verify correctness, though it doesn't crash immediately...):

zdevito commented 7 years ago

This still needs to be rebased onto the jit branch before it is ready to merge. I was waiting because of the hypothesis that there were pybind issues. I'll see what happens tomorrow.

ezyang commented 7 years ago

Force-pushed the rebase.

zdevito commented 7 years ago

Some timings for the kernels themselves. These are not exactly comparable because the kernel boundaries are not precisely the same, but they are pretty close:

```
HAND FUSED:
Time(%)      Time     Calls       Avg       Min       Max  Name
 1.56%  187.63ms     22330  8.4020us  7.7760us  15.521us  void THNN_CudaLSTMForward<float, unsigned int, int=-2>(TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, unsigned int, unsigned int)
 1.08%  129.24ms     22322  5.7900us  5.3760us  13.472us  void THNN_CudaLSTMBackward<float, unsigned int, int=-2>(TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, unsigned int, unsigned int)

OURS:
Time(%)      Time     Calls       Avg       Min       Max  Name
 1.16%  204.53ms     33460  6.1120us  5.6000us  15.456us  kernel_0  // roughly equivalent to forward
 0.92%  161.96ms     31548  5.1330us  4.8320us  16.992us  kernel_5  // roughly equivalent to backward
```

apaszke commented 7 years ago

Also, I don't feel like committing that graph visualizer into PyTorch 😕 Can we just have a mode that emits JSON and then a separate script (you can put it in a gist) that has all HTML strings and can produce the visualization?

zdevito commented 7 years ago

I'm happy to hide the visualizer more (it shouldn't be in onnx.py), but I'd like it to be in the codebase so it is easily accessible to me. I needed it to figure out how to make our code go fast, and I will need it in the future.

apaszke commented 7 years ago

Is there a problem with doing what I suggested? Let's just have a JSON printer, and then we can develop a whole toolkit for visualizing/inspecting this data on the side.
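
For concreteness, the split could look something like this (entirely hypothetical shapes; `Node` is a stand-in for the real IR node type):

```python
import json
from dataclasses import dataclass, field
from typing import List

@dataclass
class Node:                      # stand-in for the real IR node
    id: int
    kind: str                    # e.g. "FusionGroup", "Concat"
    inputs: List[int] = field(default_factory=list)

def dump_graph_json(nodes: List[Node], path: str) -> None:
    # The only piece that lives in core: serialize the IR to plain JSON.
    # HTML generation lives in a separate script/gist that reads this file.
    with open(path, "w") as f:
        json.dump([vars(n) for n in nodes], f, indent=2)
```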

zdevito commented 7 years ago

I don't want to have to dig up code that isn't in the repo just because I want to write a file to disk to debug the IR. I don't see a serious cost to keeping this code in the repo, but I do see a serious one to keeping it out. Longer term we can invest in better debugging, but right now that doesn't exist and I need something to use.

apaszke commented 7 years ago

I really think this code doesn't belong in the core repo, and I can't see any problem with keeping it on the side. It's not like using it has high overhead: just paste it into the place where you want to use it (probably the word language model) and you're done.

zdevito commented 7 years ago

Well, I can't put it in the application, because I can't get to the code that dumps the IR between optimization passes; it is inside the Traceable class. I also want the benefits of source control for debugging code, so that I don't have to dig up the right version of it every time I want to do something with it.

soumith commented 7 years ago
1. The visualizer code goes into torch.utils. Adam, it's pretty dumb to have to dig up this code every time you have to visualize JIT traces (which we'll be doing pretty often).

2. Having the core jit.py code contain only a JSON dumper, with the utils code consuming the JSON and producing the HTML, seems cleaner. Though at this particular point, speed of development is more important than stability, so don't put huge emphasis on such decisions.

soumith commented 7 years ago

If the visualizer code can go into torch.contrib.utils, that'd be nicer than torch.utils. We should actively create and start using contrib for such things.