google / objax

Apache License 2.0

Need better error message for max_pool_2d #154

Status: Open. david-berthelot opened this issue 3 years ago.

```python
import objax

x = objax.random.uniform((10, 3, 4, 4))
objax.functional.max_pool_2d(x, x)
# Error: I passed a tensor as the second argument, which expects a size (an int or a tuple of ints).
```
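One way to get a clearer error is to validate `size` before it ever reaches `lax.reduce_window`. The sketch below is hypothetical: `normalize_pool_size` is not part of objax, and the accepted forms (a positive int, or a tuple/list of two positive ints) are an assumption based on the comment in the snippet above.

```python
from numbers import Integral
from typing import Tuple

import numpy as np


def normalize_pool_size(size, n: int = 2, name: str = "size") -> Tuple[int, ...]:
    """Hypothetical helper (not objax API): return `size` as an n-tuple of positive ints.

    Raises a TypeError that names the argument and, for array inputs, their shape.
    """

    def is_int(v) -> bool:
        # numbers.Integral covers Python ints and NumPy integer scalars; bool is excluded.
        return isinstance(v, Integral) and not isinstance(v, bool)

    if is_int(size):
        result = (int(size),) * n
    elif isinstance(size, (tuple, list)) and len(size) == n and all(is_int(s) for s in size):
        result = tuple(int(s) for s in size)
    else:
        # Mention the shape when an array-like was passed, since that is the usual mistake.
        detail = f" with shape {tuple(size.shape)}" if hasattr(size, "shape") else ""
        raise TypeError(
            f"{name} must be an int or a tuple of {n} ints, got {type(size).__name__}{detail}"
        )
    if any(s <= 0 for s in result):
        raise ValueError(f"{name} must contain positive values, got {result}")
    return result


print(normalize_pool_size(2))       # (2, 2)
print(normalize_pool_size((3, 2)))  # (3, 2)

# The mistake from the issue: passing the input array as the window size.
x = np.random.uniform(size=(10, 3, 4, 4))
try:
    normalize_pool_size(x)
except TypeError as err:
    print(err)  # size must be an int or a tuple of 2 ints, got ndarray with shape (10, 3, 4, 4)
```

Called at the top of `max_pool_2d` on both `size` and `strides` (passing `name="strides"` for the latter), a check like this would name the offending argument and its shape at the call site instead of failing deep inside JAX.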

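The resulting error message and stack trace are obscure: the failure surfaces deep inside JAX as `TypeError: unhashable type: 'numpy.ndarray'`, and the message itself gives no hint that the `size` argument is the problem: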

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/user/jax3/lib/python3.8/site-packages/objax/functional/core/pooling.py", line 105, in max_pool_2d
    return lax.reduce_window(x, -jn.inf, lax.max, (1, 1) + size, (1, 1) + strides, padding=padding)
  File "/home/user/jax3/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 1218, in reduce_window
    return monoid_reducer(operand, window_dimensions, window_strides, padding,
  File "/home/user/jax3/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 1282, in _reduce_window_max
    return reduce_window_max_p.bind(
  File "/home/user/jax3/lib/python3.8/site-packages/jax/core.py", line 270, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/user/jax3/lib/python3.8/site-packages/jax/core.py", line 580, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/user/jax3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 235, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
TypeError: unhashable type: 'numpy.ndarray'
```
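Why an "unhashable type" error rather than a type or shape error? The last frame is `xla_primitive_callable`, which receives the primitive's parameters (including the window dimensions) as keyword arguments. Our reading, which is an assumption rather than something verified against this JAX version's source, is that this function is memoized, so its arguments get hashed, and an array hidden inside the window dimensions cannot be. The sketch below reproduces the same message with a hypothetical stand-in, `compile_window_op`, memoized via `functools.lru_cache`; it is not JAX code.

```python
import functools

import numpy as np


@functools.lru_cache(maxsize=None)
def compile_window_op(window_dimensions):
    # Hypothetical stand-in for a memoized compile step: the cache key is built
    # by hashing the arguments, so every argument must be hashable.
    return f"compiled for window {window_dimensions}"


# A tuple of ints is hashable, so the cache lookup succeeds.
print(compile_window_op((1, 1, 2, 2)))  # compiled for window (1, 1, 2, 2)

# A window spec that accidentally contains an array cannot be hashed.
x = np.random.uniform(size=(10, 3, 4, 4))
try:
    compile_window_op((1, 1, x))
except TypeError as err:
    print(err)  # unhashable type: 'numpy.ndarray'
```

If this reading is right, it explains why the message names `numpy.ndarray` instead of `size`: the bad value is only detected when the parameters are hashed, several frames away from the user's call.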