I couldn't access the colab. A priori, I don't see why it wouldn't work, so I need to take a look at the details to understand better what's happening.
oh sorry, must have misclicked during sharing. can you try again please?
https://colab.research.google.com/drive/1P9cmW24V4fRDQmay3G8ZtRQtybNDH6ij?usp=sharing
I think it must be because you don't pass the model vars to Jit, and yet you pass training=True, which is going to update the batch norm statistics of the ResNet. This confuses Jit, because you change the value of a variable that Jit is treating as a constant (since it wasn't provided to it as a variable).
def train_op(x, y_true, learning_rate):
    gradients, values = grad_values(x, y_true)
    optimiser(lr=learning_rate, grads=gradients)
    return values

train_op = objax.Jit(train_op, optimiser.vars() + grad_values.vars())
One way to solve the problem would be (assuming you want to update batch norm statistics in the resnet):
train_op = objax.Jit(train_op, optimiser.vars() + model.vars())
Another way to solve it, if you don't want to update the batch norm statistics, is to simply do this:
def loss_fn(x, y_true):
    logits = model(x, training=False)
    return cross_entropy_logits_sparse(logits, y_true).mean()
Of course you can do both solutions at the same time. Generally speaking, passing all variables to Jit (model ones, gradient ones, optimiser ones, etc.) is probably the safest.
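To make that concrete, here's a minimal self-contained sketch combining both fixes (the model and optimiser here are stand-ins for illustration, not the exact setup from the colab):

import objax
from objax.functional.loss import cross_entropy_logits_sparse
from objax.zoo.resnet_v2 import ResNet18

# hypothetical setup, purely for illustration
model = objax.zoo.resnet_v2.ResNet18(in_channels=3, num_classes=10)
optimiser = objax.optimizer.Momentum(model.vars())

def loss_fn(x, y_true):
    logits = model(x, training=False)  # solution 2: don't touch batch norm statistics
    return cross_entropy_logits_sparse(logits, y_true).mean()

grad_values = objax.GradValues(loss_fn, model.vars())

def train_op(x, y_true, learning_rate):
    gradients, values = grad_values(x, y_true)
    optimiser(lr=learning_rate, grads=gradients)
    return values

# solution 1: hand every variable collection to Jit
train_op = objax.Jit(train_op, model.vars() + grad_values.vars() + optimiser.vars())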
I've marked this Issue for the project on error message improvement. We can look into whether there's a way we could detect a variable being written in JIT while it is not passed in the VarCollection.
ahha! i see. thanks, as always, for the quick feedback david!
@carlini I think you looked into similar issues, could you provide your suggestions?
hi!
doing some calibration experiments and playing around with a simple temperature scaling layer that just does elementwise scaling.
the idea being that you can just append it to a pretrained resnet and fit just the single temp scaling var; something like...
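(roughly this shape; the names here are illustrative, the exact code is in the colab linked below)

import objax
import jax.numpy as jn
from objax.zoo.resnet_v2 import ResNet18

class TempScaler(objax.Module):  # illustrative name, not the exact one from the colab
    def __init__(self):
        self.temp = objax.TrainVar(jn.ones(()))  # single learnable temperature

    def __call__(self, logits):
        return logits / self.temp.value

model = ResNet18(in_channels=3, num_classes=10)  # a pretrained resnet in practice
scaler = TempScaler()

def predict(x, training):
    return scaler(model(x, training=training))

# fit only the temperature, leaving the resnet weights alone
optimiser = objax.optimizer.SGD(scaler.vars())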
when i do this though my training loop works, but the model then gets into a strange state where the forward pass fails. things work ok when i run the optimisation loop against all model.vars() though :/ a minimal reproduction is provided in this colab
this looks like some kind of bug, but maybe there's just a simpler way to represent training only this logit rescaling?
cheers, mat