If you pass a tensor with the wrong shape to a conv layer, you get this completely incomprehensible error message:
File "/home/ncarlini/.local/lib/python3.6/site-packages/jax/lax/lax.py", line 542, in conv_general_dilated
precision=_canonicalize_precision(precision))
File "/home/ncarlini/.local/lib/python3.6/site-packages/jax/core.py", line 202, in bind
out_tracer = top_trace.process_primitive(self, tracers, kwargs)
File "/home/ncarlini/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 133, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File "/home/ncarlini/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 141, in default_process_primitive
out_aval = primitive.abstract_eval(*avals, **params)
File "/home/ncarlini/.local/lib/python3.6/site-packages/jax/lax/lax.py", line 1672, in standard_abstract_eval
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
File "/home/ncarlini/.local/lib/python3.6/site-packages/jax/lax/lax.py", line 2167, in _conv_general_dilated_shape_rule
rhs.shape[dimension_numbers.rhs_spec[1]]))
ValueError: conv_general_dilated lhs feature dimension size divided by feature_group_count must equal the rhs input feature dimension size, but 64 // 1 != 6.
This PR changes it to say something sane:
File "/home/ncarlini/objax/objax/nn/layers.py", line 345, in __call__
args = f(*args, **util.local_kwargs(kwargs, f))
File "/home/ncarlini/objax/objax/nn/layers.py", line 183, in __call__
self.w.value.shape, x.shape)
AssertionError: Attempting to convolve an input with 64 input channels when the convolution expects 6 channels. For reference, self.w.shape=(3, 3, 6, 16) and x.shape=(5000, 64, 32, 3).
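Below is a minimal sketch of the kind of check that produces the message above. It is not the code from this patch. It assumes inputs laid out as (batch, channels, height, width) and weights as (kernel_h, kernel_w, in_channels, out_channels), which is consistent with the shapes in the message: x.shape[1] is 64 and w.shape[2] is 6. The helper name `check_conv_input_channels` is hypothetical.

```python
from typing import Tuple


def check_conv_input_channels(w_shape: Tuple[int, ...], x_shape: Tuple[int, ...]) -> None:
    """Hypothetical helper: fail early with a readable message on a channel mismatch.

    Assumes x is (batch, channels, height, width) and the kernel is
    (kernel_h, kernel_w, in_channels, out_channels).
    """
    in_channels = x_shape[1]  # channel axis of the NCHW input
    expected = w_shape[2]     # input-channel axis of the kernel
    assert in_channels == expected, (
        f"Attempting to convolve an input with {in_channels} input channels "
        f"when the convolution expects {expected} channels. For reference, "
        f"self.w.shape={w_shape} and x.shape={x_shape}."
    )


if __name__ == "__main__":
    try:
        # Same shapes as in the error above: 64 channels in, kernel expects 6.
        check_conv_input_channels((3, 3, 6, 16), (5000, 64, 32, 3))
    except AssertionError as e:
        print(e)
```

Running it prints the same sentence as the AssertionError above. Because it uses `assert`, the check is skipped when Python runs with `-O`. Raising a `ValueError` instead would keep the message in optimized runs.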