google / objax


Raise warning when arguments to loss functions are not of the same dimensions #190

Closed. aterzis-google closed this issue 3 years ago.

aterzis-google commented 3 years ago

Loss functions (https://objax.readthedocs.io/en/latest/objax/functional.html?highlight=loss#objax-functional-loss) take two arguments that should have the same dimensions. However, the functions do not check that the dimensions match, and when they differ, the functions calculate something other than what the user expects.

For example, consider the following function:

    def loss(x, label):
        y_hat = model(x)[:, 0]  # This is the right one
        # y_hat = model(x)      # This will not work
        return objax.functional.loss.mean_squared_error(y_hat, label, keep_axis=None)

When called with:

    label.shape:    (1000,)
    model(x).shape: (1000, 1)

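mean_squared_error will not calculate what the user expects unless the user explicitly reshapes the input arguments. Under NumPy-style broadcasting, subtracting a (1000,) array from a (1000, 1) array yields a (1000, 1000) array of pairwise differences, so the mean is presumably taken over all pairs rather than over the 1000 matching elements. A warning when the two shapes differ would make this mistake visible.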
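To illustrate the problem, here is a minimal sketch using plain NumPy, whose broadcasting rules jax.numpy follows. The helpers mse and checked_mse are hypothetical stand-ins written for this example; they are not Objax's implementation, and the warning text is only an assumption about what such a check might look like.

```python
import warnings

import numpy as np


def mse(x, y):
    """Hypothetical stand-in for a mean-squared-error loss (not Objax's code)."""
    return ((x - y) ** 2).mean()


def checked_mse(x, y):
    """Hypothetical wrapper: warn when the shapes differ, then compute the loss."""
    if x.shape != y.shape:
        warnings.warn(
            f"loss arguments have different shapes: {x.shape} vs {y.shape}"
        )
    return mse(x, y)


label = np.arange(1000, dtype=np.float64)  # shape (1000,)
pred_2d = label.reshape(1000, 1)           # shape (1000, 1), same values as label

# The subtraction broadcasts to all pairwise differences.
diff_shape = (pred_2d - label).shape       # (1000, 1000)

# Predictions equal the labels, yet the broadcast loss is not zero.
loss_wrong = mse(pred_2d, label)

# Dropping the trailing axis makes the shapes match, and the loss is zero.
loss_right = mse(pred_2d[:, 0], label)
```

Calling checked_mse(pred_2d, label) returns the same value as mse but emits a UserWarning first, which is roughly the behavior this issue asks for.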