Loss functions (https://objax.readthedocs.io/en/latest/objax/functional.html?highlight=loss#objax-functional-loss) take two arguments, which should have the same dimensions. However, the functions do not check that the dimensions match, and when they do not, they compute something other than what the user expects.
For example, consider the following function:
    def loss(x, label):
        y_hat = model(x)[:, 0]  # This is the right one: shape (1000,), same as label
        # y_hat = model(x)      # This will not work: shape (1000, 1) does not match label
        return objax.functional.loss.mean_squared_error(y_hat, label, keep_axis=None)
When called with:
    label.shape: (1000,)
    model(x).shape: (1000, 1)
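mean_squared_error will not calculate what the user expects unless the user explicitly reshapes the input arguments. Under the usual NumPy/JAX broadcasting rules, subtracting a (1000,) array from a (1000, 1) array gives a (1000, 1000) array, so the result is averaged over every prediction/label pair instead of over matching elements.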
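To make the problem concrete, here is a minimal sketch that reproduces the effect on a small array. It uses plain NumPy rather than Objax, and `mse` is a hypothetical stand-in that assumes the loss reduces to the mean of elementwise squared differences. `checked_mse` is likewise hypothetical and only illustrates one possible shape check, not an existing Objax function.

```python
import numpy as np


def mse(y_hat, label):
    # Hypothetical stand-in for the library loss (assumption):
    # mean of elementwise squared differences, with broadcasting.
    return np.mean((y_hat - label) ** 2)


def checked_mse(y_hat, label):
    # Hypothetical wrapper: reject mismatched shapes instead of broadcasting.
    if y_hat.shape != label.shape:
        raise ValueError(f"shape mismatch: {y_hat.shape} vs {label.shape}")
    return mse(y_hat, label)


label = np.arange(4, dtype=np.float64)  # shape (4,)
pred = label.reshape(-1, 1)             # shape (4, 1), same values as label

diff = pred - label             # broadcasts to shape (4, 4)
wrong = mse(pred, label)        # mean over all 16 pairwise squared differences
right = mse(pred[:, 0], label)  # shapes match; predictions equal labels

print(diff.shape, wrong, right)
```

Here the predictions equal the labels exactly, so the expected loss is zero, yet the mismatched call returns a nonzero value. A check like the one in `checked_mse` would turn this silent error into an explicit exception.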