HEmile / storchastic

Stochastic Automatic Differentiation library for PyTorch.
GNU General Public License v3.0

Recursion limit reached with many cost nodes #87

Closed. csmith49 closed this issue 3 years ago.

csmith49 commented 3 years ago

Running backward() when there are a lot of cost nodes causes a recursion-limit-reached exception to be thrown.

I'm able to consistently generate this exception running the following:

import storch
import torch
from torch.distributions import Bernoulli
from storch.method import GumbelSoftmax

torch.manual_seed(0)

p = torch.tensor(0.5, requires_grad=True)

for i in range(1000):
    sample = GumbelSoftmax(f"sample_{i}")(Bernoulli(p))
    storch.add_cost(sample, f"cost_{i}")

storch.backward()

which produces the following trace:

Traceback (most recent call last):
  File "...", line 15, in <module>
    storch.backward()
  File ".../lib/python3.8/site-packages/storch/inference.py", line 217, in backward
    accum_loss._clean()
  File ".../lib/python3.8/site-packages/storch/tensor.py", line 401, in _clean
    node._clean()
  File ".../lib/python3.8/site-packages/storch/tensor.py", line 401, in _clean
    node._clean()
  File ".../lib/python3.8/site-packages/storch/tensor.py", line 401, in _clean
    node._clean()
  [Previous line repeated 994 more times]
  File ".../lib/python3.8/site-packages/storch/tensor.py", line 399, in _clean
    node._clean()
RecursionError: maximum recursion depth exceeded
HEmile commented 3 years ago

Thanks for reporting this; it's something subtle I hadn't taken into account yet. It just exceeds Python's default recursion limit of 1000.

Do you have a particular use case where graphs this deep are necessary? For instance, the above code could easily be parallelized by sampling multiple independent Bernoullis and then summing them, which only creates 3 tensors.
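For concreteness, a minimal sketch of that batched alternative (the variable names are illustrative, and it assumes the summed cost can be reduced with an ordinary PyTorch call on the sampled tensor):

import storch
import torch
from torch.distributions import Bernoulli
from storch.method import GumbelSoftmax

torch.manual_seed(0)

p = torch.tensor(0.5, requires_grad=True)

# One sampling node over 1000 independent Bernoullis instead of 1000 separate nodes.
samples = GumbelSoftmax("samples")(Bernoulli(p.expand(1000)))
# A single cost node: the sum over all 1000 components.
storch.add_cost(samples.sum(-1), "total_cost")

storch.backward()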

It is fixable, though, by reimplementing the _clean function so that it doesn't use recursion.
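A minimal sketch of what a non-recursive _clean could look like, using an explicit stack instead of recursion (the clean_fields and _parents names are stand-ins, not Storchastic's actual internals):

def _clean(root):
    # Walk the graph with an explicit stack so traversal depth is bounded
    # by memory, not by Python's recursion limit.
    stack = [root]
    seen = set()
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        node.clean_fields()          # hypothetical per-node cleanup
        stack.extend(node._parents)  # hypothetical parent links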

csmith49 commented 3 years ago

We did run into this issue in our implementation, but I wouldn't say such graphs are necessary. We refactored a bit and did the summing trick to fix the issue.

Is "deep" the right word to describe the SCG encoded in the example above? Each path only has one edge. Our implementation had a similar structure: short paths from parameters to cost nodes, but lots of cost nodes. "Wide" feels like a better descriptor.

The above language perhaps explains why the issue was surprising. Increasing the width of such a graph (by increasing the limit on the range, for example) doesn't obviously increase the recursive complexity of a traversal.

HEmile commented 3 years ago

You're right that in theory it's a wide graph. It's an implementation issue, then: behind the scenes, Storchastic creates a very 'deep' computation graph to build the surrogate loss it differentiates (it uses a reduction for the sum over all costs, which effectively creates an O(1000)-depth computation graph).

That can probably be much more efficient!
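To illustrate the depth difference with plain PyTorch (this is not Storchastic's actual code, only the shape of the autograd graphs involved):

import torch

costs = [torch.rand((), requires_grad=True) for _ in range(1000)]

# Sequential accumulation: each addition chains onto the previous one,
# so the autograd graph is roughly 1000 nodes deep.
total_chain = torch.zeros(())
for c in costs:
    total_chain = total_chain + c

# Stacked reduction: a single sum node fans out to all 1000 leaves,
# i.e. a wide but shallow graph.
total_stack = torch.stack(costs).sum()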

HEmile commented 3 years ago

Hello Calvin, I've fixed this by collecting the different costs more cleverly, and optimized some code so it runs in decent time. There's still a significant overhead with Storchastic, but it runs your example with 10,000 costs in a couple of seconds.