Open carlini opened 4 years ago
Indeed, there are multiple errors interacting here, and I'm not sure how to catch which.
For your particular example:
import objax
import jax.numpy as jn
import numpy as np
# 1. Create modules
mod = objax.nn.Conv2D(2, 4, 3)
def ell(x):
    return mod(x) # before it was using p, I assume you meant using mod.
m = objax.Grad(ell, objax.VarCollection(), (0,))
p = objax.Parallel(m, mod.vars(), reduce=lambda x: x)
# 2. Replicate vars before using modules.
with mod.vars().replicate():
    print(p(np.ones((8*8,2,10,10))))
No, that code is exactly what I meant. It's insane code, but it was minified from a bug I actually had. I meant calling p() in ell. That's what makes things crash so badly.
Okay, so I can suggest a few things we could catch in your example. Ideally, what should the error message(s) say in this case?
Yeah, the issue is that I'm not sure yet. This code is obviously wrong and stupid, but I don't know the "right" way to say that something has gone wrong with it. Maybe the recursive call into parallel is where things go bad? That should probably never happen. But it seems unfortunate to have to make the codebase uglier just to explicitly check for loops.
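An explicit loop check might not need to be very invasive. Here is a minimal sketch in plain Python, assuming the check can live in a small decorator. The names guard_reentry and ReentrantCallError are hypothetical and not part of Objax, and the toy p and ell only mimic the call pattern from the example, not objax.Parallel or objax.Grad themselves.

```python
import functools
import threading


class ReentrantCallError(RuntimeError):
    """Raised when a guarded callable is entered again before it returns."""


def guard_reentry(fn):
    """Wrap fn so that a nested call to it raises a readable error.

    Hypothetical helper for illustration only; not an Objax API.
    """
    state = threading.local()  # per-thread flag, so threads do not interfere

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if getattr(state, "active", False):
            raise ReentrantCallError(
                f"{fn.__name__} was called from inside itself; "
                "recursive calls into a parallel function are not supported."
            )
        state.active = True
        try:
            return fn(*args, **kwargs)
        finally:
            state.active = False  # always reset, even if fn raises

    return wrapper


# Toy stand-ins for p and ell (plain Python, no Objax involved).
@guard_reentry
def p(x):
    return ell(x)


def ell(x):
    return p(x)  # the accidental recursive call
```

With this in place, calling p(1) raises ReentrantCallError with a message naming p, instead of failing somewhere deep inside the nested call. The check only runs at the Python level, so whether it would fire at the right moment inside Objax's tracing is an open question.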
Currently, if a parallel function recursively calls itself, you can get some incomprehensible error messages.
This is very low priority.
The error for this is
And figuring out what this means is more or less impossible.