google / objax

Apache License 2.0
772 stars 77 forks

Improve error when recursively calling Parallel #157

Open carlini opened 4 years ago

carlini commented 4 years ago

Currently, if a Parallel function recursively calls itself, you can get some incomprehensible error messages.

This is very low priority.

    import objax
    import jax.numpy as jn
    import numpy as np

    mod = objax.nn.Conv2D(2, 4, 3)

    with mod.vars().replicate():
        def ell(x):
            return p(x)  # recursive call into the Parallel module p defined below
        m = objax.Grad(ell, {}, (0,))
        p = objax.Parallel(m, mod.vars(), reduce=lambda x: x)
        print(p(np.ones((8*8,2,10,10))))

The error for this is

  File "/opt/conda/lib/python3.7/threading.py", line 890, in _bootstrap
    self._bootstrap_inner()
  File "/opt/conda/lib/python3.7/threading.py", line 926, in _bootstrap_inner
    self.run()
  File "/opt/conda/lib/python3.7/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.7/socketserver.py", line 650, in process_request_thread
    self.finish_request(request, client_address)
  File "/opt/conda/lib/python3.7/socketserver.py", line 360, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "/opt/conda/lib/python3.7/socketserver.py", line 720, in __init__
    self.handle()
  File "/opt/conda/lib/python3.7/site-packages/werkzeug/serving.py", line 345, in handle
    BaseHTTPRequestHandler.handle(self)
  File "/opt/conda/lib/python3.7/http/server.py", line 426, in handle
    self.handle_one_request()
  File "/opt/conda/lib/python3.7/site-packages/werkzeug/serving.py", line 379, in handle_one_request
    return self.run_wsgi()
  File "/opt/conda/lib/python3.7/site-packages/werkzeug/serving.py", line 323, in run_wsgi
    execute(self.server.app)
  File "/opt/conda/lib/python3.7/site-packages/werkzeug/serving.py", line 312, in execute
    application_iter = app(environ, start_response)
  File "/opt/conda/lib/python3.7/site-packages/flask/app.py", line 2464, in __call__
    return self.wsgi_app(environ, start_response)
  File "/opt/conda/lib/python3.7/site-packages/flask/app.py", line 2447, in wsgi_app
    response = self.full_dispatch_request()
  File "/opt/conda/lib/python3.7/site-packages/flask/app.py", line 1950, in full_dispatch_request
    rv = self.dispatch_request()
  File "/opt/conda/lib/python3.7/site-packages/flask/app.py", line 1936, in dispatch_request
    return self.view_functions[rule.endpoint](**req.view_args)
  File "background.py", line 50, in do
    return do_it(name, function, args)
  File "background.py", line 38, in do_it
    res = getattr(module, which)(*args, **kwargs)
  File "/home/ncarlini/diagnosing-failures/debug.py", line 132, in run
    print(p(np.ones((8*8,2,10,10))))
  File "/opt/conda/lib/python3.7/site-packages/objax/module.py", line 282, in __call__
    output, changes = self._call(self.vc.tensors(), self.vc.subset(RandomState).tensors(), *args)
  File "/opt/conda/lib/python3.7/site-packages/objax/module.py", line 251, in pmap
    return f(*args), self.vc.tensors()
  File "/opt/conda/lib/python3.7/site-packages/objax/gradient.py", line 112, in __call__
    return super().__call__(*args, **kwargs)[0]
  File "/opt/conda/lib/python3.7/site-packages/objax/gradient.py", line 79, in __call__
    list(args), kwargs)
  File "/opt/conda/lib/python3.7/site-packages/objax/gradient.py", line 56, in f_func
    outputs = f(*list_args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/objax/module.py", line 165, in __call__
    return self.__wrapped__(*args, **kwargs)
  File "/home/ncarlini/diagnosing-failures/debug.py", line 126, in ell
    return p(x)
  File "/opt/conda/lib/python3.7/site-packages/objax/module.py", line 282, in __call__
    output, changes = self._call(self.vc.tensors(), self.vc.subset(RandomState).tensors(), *args)
jax.traceback_util.FilteredStackTrace: ValueError: pmap got inconsistent sizes for array axes to be mapped:
the tree of axis sizes is:
(([4, 3], [], 8), {})
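The "inconsistent sizes" part of the message comes from JAX itself: pmap (like vmap) requires every mapped argument to share the same leading-axis size, and the cryptic "tree of axis sizes" lists the mismatched sizes it found. A minimal sketch of the same failure mode, using vmap so it runs on a single device (the shapes here are illustrative, not the ones from the report):

```python
import jax
import jax.numpy as jnp

# vmap/pmap map over the leading axis of each argument, so all mapped
# arguments must agree on that axis size. Here one input has leading
# axis 4 and the other 8, which triggers the same class of error.
try:
    jax.vmap(lambda a, b: a + b)(jnp.ones((4, 3)), jnp.ones((8, 3)))
except ValueError as e:
    print(type(e).__name__, '-', e)
```

In the original traceback the sizes `[4, 3]` and `8` come from module variables and inputs being mixed into the mapped arguments by the recursive call, which is why the message is so hard to connect back to user code.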

And figuring out what this means is more or less impossible.

david-berthelot commented 4 years ago

Indeed, there are multiple errors interacting here; I'm not sure what to catch or how.

For your particular example:

  1. Modules should be created before variable replication.
  2. Calling Grad on a Parallel module is undefined (but that should be catchable).
    import objax
    import jax.numpy as jn
    import numpy as np

    # 1. Create modules
    mod = objax.nn.Conv2D(2, 4, 3)

    def ell(x):
        return mod(x)  # before it was using p, I assume you meant using mod.

    m = objax.Grad(ell, objax.VarCollection(), (0,))
    p = objax.Parallel(m, mod.vars(), reduce=lambda x: x)

    # 2. Replicate vars before using modules.
    with mod.vars().replicate():
        print(p(np.ones((8*8,2,10,10))))
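Point 2 above (Grad of a Parallel module being undefined but catchable) could in principle be rejected up front, before anything reaches pmap. A hypothetical sketch of that kind of check; the class names here are stand-ins, not objax's actual code:

```python
# Stand-in wrapper types, for illustration only.
class FakeParallel:
    """Stands in for a Parallel-like module wrapping a function."""
    def __init__(self, f):
        self.f = f

class FakeGrad:
    """Stands in for a Grad-like wrapper that rejects Parallel inputs."""
    def __init__(self, f):
        if isinstance(f, FakeParallel):
            raise TypeError(
                'Grad of a Parallel module is undefined; wrap the '
                'gradient in Parallel instead, e.g. Parallel(Grad(f), ...).')
        self.f = f

# Rejected immediately with a readable message:
try:
    FakeGrad(FakeParallel(lambda x: x))
except TypeError as e:
    print('caught:', e)
```

An isinstance check like this is cheap and turns a late, cryptic pmap failure into an immediate, explicit constructor error.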
carlini commented 4 years ago

No, that code is exactly what I meant. It's insane code, but it was minified from a bug I actually had. I did mean to call p() inside ell; that's what makes things crash so badly.

david-berthelot commented 4 years ago

Okay, so I can suggest a few things we could catch in your example. Ideally, what should the error message(s) say in this case?

carlini commented 4 years ago

Yeah, that's the issue: I'm not sure yet. This code is obviously wrong and stupid, but I don't know the "right" way to say that something has gone wrong with it. Maybe the recursive call into Parallel is where things go bad? That should probably never happen. But it seems unfortunate to make the codebase uglier if we have to explicitly check for loops.
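One lightweight way to "check for loops" without restructuring the codebase is a re-entrancy guard on the parallel wrapper's `__call__`: a flag set while a call is in flight, raising a clear error if the wrapper is entered again. A minimal, dependency-free sketch; the names are hypothetical, not objax's API:

```python
class ReentrancyError(RuntimeError):
    """Raised when a parallel-style wrapper is invoked from inside itself."""

class GuardedParallel:
    """Wraps a function and refuses recursive (re-entrant) invocation."""

    def __init__(self, f):
        self.f = f
        self._in_call = False  # True while a call is in flight

    def __call__(self, *args, **kwargs):
        if self._in_call:
            raise ReentrancyError(
                'Parallel module called recursively; a parallel call '
                'cannot be nested inside itself.')
        self._in_call = True
        try:
            return self.f(*args, **kwargs)
        finally:
            self._in_call = False

# A function that (incorrectly) calls the wrapper from inside itself,
# mirroring the p(x)-inside-ell pattern from the original report:
p = GuardedParallel(lambda x: p(x))

try:
    p(1)
except ReentrancyError as e:
    print('caught:', e)
```

The cost is one boolean per wrapper and two lines in `__call__`, so the check stays out of the way of correct code while failing loudly and early on the recursive case.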