google / objax

TrainVar with single scalar causing error #164

Closed: matpalm closed this issue 1 year ago

matpalm commented 3 years ago

hi!

doing some calibration experiments and playing around with a simple Temperature scaling layer that just does elementwise scaling.

import jax.numpy as jnp
import objax


class Temperature(objax.module.Module):
    """Single trainable scalar used to rescale inputs elementwise."""

    def __init__(self):
        super().__init__()
        self.temperature = objax.variable.TrainVar(jnp.array([1.0]))

    def __call__(self, x):
        return x / self.temperature.value
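
just to show what it does in isolation (the example values here are made up): the forward pass divides the input elementwise by the single trainable scalar.

temp = Temperature()
logits = jnp.array([[2.0, -1.0, 0.5]])
scaled = temp(logits)  # elementwise division by the trainable temperature, initially 1.0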

the idea being that you can just append it to a pretrained resnet and fit only the single temperature scaling; something like...

model = ResNet18(in_channels=3, num_classes=10)
objax.io.load_var_collection('model.npz', model.vars())

temperature_layer = Temperature()
model.append(temperature_layer)

optimiser = objax.optimizer.Adam(temperature_layer.vars())
etc

when i do this though, my training loop works, but the model then gets into a strange state where the forward pass fails. things work ok when i run the optimisation loop against all model.vars() though :/
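
for concreteness, the gradient setup looks roughly like this (a sketch; the exact loss and helper names are assumed):

def loss_fn(x, y_true):
    logits = model(x, training=True)
    return objax.functional.loss.cross_entropy_logits_sparse(logits, y_true).mean()

# gradients are taken only with respect to the temperature layer's variables
grad_values = objax.GradValues(loss_fn, temperature_layer.vars())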

a minimal reproduction is provided in this colab

this looks like some kind of bug, but maybe there's a simpler way to express training only this logit rescaling?

cheers, mat

david-berthelot commented 3 years ago

I couldn't access the colab. A priori, I don't see why it wouldn't work, so I need to take a look at the details to understand better what's happening.

matpalm commented 3 years ago

oh sorry, must have misclicked during sharing. can you try again please?

https://colab.research.google.com/drive/1P9cmW24V4fRDQmay3G8ZtRQtybNDH6ij?usp=sharing

david-berthelot commented 3 years ago

I think it must be because you don't pass the model vars to Jit, and yet you pass training=True, which is going to update the batch norm statistics of the ResNet. This confuses the compiled code, because you are changing the value of a variable that Jit is treating as a constant (since it wasn't provided to it as a variable).

def train_op(x, y_true, learning_rate):
    gradients, values = grad_values(x, y_true)
    optimiser(lr=learning_rate, grads=gradients)
    return values

train_op = objax.Jit(train_op, optimiser.vars() + grad_values.vars())  # note: model.vars() is not included here

One way to solve the problem would be (assuming you want to update batch norm statistics in the resnet):

train_op = objax.Jit(train_op, optimiser.vars() + model.vars())

Another way to solve it, if you don't want to update the batch norm statistics, is to simply do this:

def loss_fn(x, y_true):
    logits = model(x, training=False)
    return cross_entropy_logits_sparse(logits, y_true).mean()

Of course, you can apply both solutions at the same time. Generally speaking, passing all variables to Jit (model ones, gradient ones, optimiser ones, etc.) is probably the safest.
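
Putting the two together, a sketch (reusing the names from the snippets above) would look like:

def loss_fn(x, y_true):
    # training=False: don't update the ResNet's batch norm statistics
    logits = model(x, training=False)
    return objax.functional.loss.cross_entropy_logits_sparse(logits, y_true).mean()

grad_values = objax.GradValues(loss_fn, temperature_layer.vars())

def train_op(x, y_true, learning_rate):
    gradients, values = grad_values(x, y_true)
    optimiser(lr=learning_rate, grads=gradients)
    return values

# pass every variable the computation touches: model, gradient and optimiser state
train_op = objax.Jit(train_op, model.vars() + optimiser.vars() + grad_values.vars())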

david-berthelot commented 3 years ago

I've marked this issue for the error message improvement project. We can look into whether there's a way to detect a variable being written inside JIT when it is not passed in the VarCollection.
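
Until something like that lands, a manual check along these lines can flag the situation (just a sketch; it relies on VarCollection being dict-like, keyed by variable name):

# the collection that was actually passed to Jit in the failing setup
jit_vars = optimiser.vars() + grad_values.vars()
# any model variable missing from it is baked into the compiled function as a constant
missing = set(model.vars().keys()) - set(jit_vars.keys())
if missing:
    print('variables Jit will treat as constants:', sorted(missing))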

matpalm commented 3 years ago

ahha! i see. thanks, as always, for the quick feedback david!

AlexeyKurakin commented 3 years ago

@carlini I think you looked into similar issues; could you provide your suggestions?