google / objax

Apache License 2.0

Give better error when passing wrong input shape to convolution #99

Closed: carlini closed this 4 years ago

carlini commented 4 years ago

If you pass a tensor with the wrong shape to a convolution layer, you get this completely incomprehensible error message:

  File "/home/ncarlini/.local/lib/python3.6/site-packages/jax/lax/lax.py", line 542, in conv_general_dilated
    precision=_canonicalize_precision(precision))
  File "/home/ncarlini/.local/lib/python3.6/site-packages/jax/core.py", line 202, in bind
    out_tracer = top_trace.process_primitive(self, tracers, kwargs)
  File "/home/ncarlini/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 133, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)
  File "/home/ncarlini/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 141, in default_process_primitive
    out_aval = primitive.abstract_eval(*avals, **params)
  File "/home/ncarlini/.local/lib/python3.6/site-packages/jax/lax/lax.py", line 1672, in standard_abstract_eval
    return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
  File "/home/ncarlini/.local/lib/python3.6/site-packages/jax/lax/lax.py", line 2167, in _conv_general_dilated_shape_rule
    rhs.shape[dimension_numbers.rhs_spec[1]]))
ValueError: conv_general_dilated lhs feature dimension size divided by feature_group_count must equal the rhs input feature dimension size, but 64 // 1 != 6.
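The rule JAX is enforcing here is that the input's feature (channel) size, divided by `feature_group_count`, must equal the kernel's input-feature size. Below is a rough pure-Python sketch of just that equality (the real JAX shape rule checks more than this). It assumes an NCHW input (channels on axis 1) and an HWIO kernel (input channels on axis 2), which is what the shapes in the friendlier message further down suggest. The helper name `conv_feature_check` is hypothetical and is not part of JAX or Objax.

```python
def conv_feature_check(lhs_shape, rhs_shape, feature_group_count=1,
                       lhs_feature_dim=1, rhs_in_dim=2):
    """Hypothetical helper: the feature-size equality behind the ValueError.

    Assumes an NCHW input (feature axis 1) and an HWIO kernel
    (input-feature axis 2); override the axis arguments for other layouts.
    """
    lhs_features = lhs_shape[lhs_feature_dim]  # channels in the input tensor
    rhs_in = rhs_shape[rhs_in_dim]             # channels the kernel expects
    return lhs_features // feature_group_count == rhs_in


# The shapes from this issue: 64 // 1 != 6, so the check fails.
print(conv_feature_check((5000, 64, 32, 3), (3, 3, 6, 16)))
```

With those shapes the function returns `False`, the same "64 // 1 != 6" mismatch JAX reports, only without the layer-level context that would tell you which tensor is wrong.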

This change makes it say something sane:

  File "/home/ncarlini/objax/objax/nn/layers.py", line 345, in __call__
    args = f(*args, **util.local_kwargs(kwargs, f))
  File "/home/ncarlini/objax/objax/nn/layers.py", line 183, in __call__
    self.w.value.shape, x.shape)
AssertionError: Attempting to convolve an input with 64 input channels when the convolution expects 6 channels. For reference, self.w.shape=(3, 3, 6, 16) and x.shape=(5000, 64, 32, 3).
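A minimal sketch of a layer-side check that produces this message, assuming an NCHW input, an HWIO weight shape, and no channel groups. `ShapeCheckedConv` and `check_input` are hypothetical stand-ins written for illustration. They are not the code in `objax/nn/layers.py`, and the convolution itself is omitted.

```python
class ShapeCheckedConv:
    """Hypothetical stand-in for a conv layer; only the input check is implemented."""

    def __init__(self, w_shape):
        # Assumed HWIO layout: (kernel_h, kernel_w, in_channels, out_channels).
        self.w_shape = tuple(w_shape)

    def check_input(self, x_shape):
        expected = self.w_shape[2]  # input channels the kernel was built for
        got = x_shape[1]            # assumed NCHW: channels on axis 1
        # Fail early, at the layer, with both shapes in the message.
        assert got == expected, (
            f'Attempting to convolve an input with {got} input channels when '
            f'the convolution expects {expected} channels. For reference, '
            f'self.w.shape={self.w_shape} and x.shape={tuple(x_shape)}.')


layer = ShapeCheckedConv((3, 3, 6, 16))
layer.check_input((5000, 6, 32, 32))  # OK: 6 channels match the kernel
```

Calling `layer.check_input((5000, 64, 32, 3))` raises an `AssertionError` carrying the same text as above. One design note: `assert` statements are stripped when Python runs with `-O`, so a check that must always fire would use an explicit `raise` instead.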