google / objax

Apache License 2.0

Improve error when calling replicated function without Parallel #155

Open carlini opened 3 years ago

carlini commented 3 years ago

Currently, if you call replicate() and then call a function without wrapping it in Parallel(), you get an unhelpful error message. Try this code:

import objax
import numpy as np

mod = objax.nn.Conv2D(2, 4, 3)

with mod.vars().replicate():
    print(mod(np.ones((8,2,10,10))))

The error says:

Traceback (most recent call last):
  File "b.py", line 9, in <module>
    print(mod(np.ones((8,2,10,10))))
  File "/opt/conda/lib/python3.7/site-packages/objax/nn/layers.py", line 185, in __call__
    dimension_numbers=('NCHW', 'HWIO', 'NCHW'))
  File "/opt/conda/lib/python3.7/site-packages/jax/_src/lax/lax.py", line 555, in conv_general_dilated
    dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
  File "/opt/conda/lib/python3.7/site-packages/jax/_src/lax/lax.py", line 5955, in conv_dimension_numbers
    raise TypeError(msg.format(len(lhs_shape), len(rhs_shape)))
TypeError: convolution requires lhs and rhs ndim to be equal, got 4 and 5.

which obviously means that you are accidentally evaluating a replicated function without wrapping it in Parallel.

AlexeyKurakin commented 3 years ago

One possible solution: if a function is called with a ShardedDeviceArray, then it must be replicated. In the example above, mod.__call__ would check whether its input is a ShardedDeviceArray. If it is, an exception is thrown telling the user to call objax.Parallel.
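
A rough sketch of what such a check could look like, written without objax or JAX so it runs anywhere. ShardedStandIn, guard_against_replicated and conv_like are hypothetical names used only for illustration; a real fix would presumably test against JAX's sharded array type inside the layer's __call__, and the exact wording of the message is an assumption.

```python
import numpy as np


class ShardedStandIn:
    """Hypothetical stand-in for a replicated (sharded) array."""

    def __init__(self, value, num_devices):
        # Replication adds a leading device axis, which is why the
        # convolution in the traceback sees a 5-D kernel instead of 4-D.
        self.value = np.broadcast_to(value, (num_devices,) + value.shape)

    @property
    def shape(self):
        return self.value.shape


def guard_against_replicated(name, *arrays):
    """Raise a readable error if any argument is still replicated."""
    for a in arrays:
        if isinstance(a, ShardedStandIn):
            raise TypeError(
                f"{name} was called on replicated arrays (shape {a.shape}). "
                "Wrap the call in objax.Parallel.")


def conv_like(x, w):
    """Hypothetical layer call that runs the guard before computing."""
    guard_against_replicated("conv_like", x, w)
    return x.shape, w.shape


# Kernel laid out as HWIO, matching Conv2D(2, 4, 3) in the report above.
w = np.ones((3, 3, 2, 4))
replicated_w = ShardedStandIn(w, num_devices=8)

try:
    conv_like(np.ones((8, 2, 10, 10)), replicated_w)
except TypeError as e:
    message = str(e)
    print(message)
```

The point of the sketch is only that the guard fires before the shape mismatch reaches the convolution, so the user sees a message that names objax.Parallel rather than an ndim error from lax.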