Jannoshh / simple-sam

Sharpness-Aware Minimization for Efficiently Improving Generalization
MIT License
41 stars, 9 forks

Customize fit method #1

Closed: SergioG-M closed this issue 3 years ago

SergioG-M commented 3 years ago

Hi! Thanks for the code. I was waiting for a tf.keras implementation of SAM. I wonder if you have thought about how to implement this by customizing the fit method (I prefer this to a "full" custom training loop because it lets me use callbacks easily). I think that, for this to work, the SAM class would have to inherit from tf.keras.optimizers.Optimizer, but I am not sure how to make your code work in that case. Do you have any ideas?

Jannoshh commented 3 years ago

Hi! Sorry for the late response. I think what you want should be simple; just follow https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit/ . Ultimately, you only have to override the model's train_step method. I uploaded an MNIST example for you; hope it helps.
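For illustration, here is a minimal sketch of what such a train_step override could look like. It assumes TensorFlow 2.x with tf.keras and is not the MNIST example from the repository. `SAMModel`, `rho`, and `loss_fn` are hypothetical names chosen for this sketch. The override computes the gradient at the current weights w, temporarily moves the weights to w + rho * g / ||g||, computes a second gradient there, restores w, and passes that second gradient to the optimizer given to compile(). This way the optimizer can stay a standard Keras optimizer and does not need to be subclassed.

```python
import numpy as np
import tensorflow as tf


class SAMModel(tf.keras.Model):
    """Wraps a Keras model and trains it with a SAM-style two-pass update.

    Hypothetical sketch: `SAMModel`, `rho`, and `loss_fn` are illustrative
    names, not part of simple-sam or of the Keras API.
    """

    def __init__(self, base_model, loss_fn, rho=0.05):
        super().__init__()
        self.base_model = base_model
        self.loss_fn = loss_fn
        self.rho = rho  # radius of the neighborhood used for the ascent step
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

    def call(self, inputs, training=False):
        return self.base_model(inputs, training=training)

    @property
    def metrics(self):
        # Listed here so fit() resets the tracker at the start of each epoch.
        return [self.loss_tracker]

    def train_step(self, data):
        x, y = data
        params = self.trainable_variables

        # First pass: gradient of the loss at the current weights w.
        with tf.GradientTape() as tape:
            loss = self.loss_fn(y, self(x, training=True))
        grads = tape.gradient(loss, params)

        # Ascent step: move to w + e(w), with e(w) = rho * g / ||g||.
        scale = self.rho / (tf.linalg.global_norm(grads) + 1e-12)
        perturbations = []
        for g, v in zip(grads, params):
            e_w = g * scale
            v.assign_add(e_w)
            perturbations.append(e_w)

        # Second pass: gradient of the loss at the perturbed weights.
        with tf.GradientTape() as tape:
            perturbed_loss = self.loss_fn(y, self(x, training=True))
        sam_grads = tape.gradient(perturbed_loss, params)

        # Restore w, then let the compiled optimizer apply the SAM gradient.
        for v, e_w in zip(params, perturbations):
            v.assign_sub(e_w)
        self.optimizer.apply_gradients(zip(sam_grads, params))

        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}


# Usage on a toy, linearly separable problem (synthetic data, not MNIST).
tf.keras.utils.set_random_seed(0)
rng = np.random.default_rng(0)
x_train = rng.normal(size=(256, 2)).astype("float32")
y_train = (x_train[:, 0] + x_train[:, 1] > 0).astype("int64")

inputs = tf.keras.Input(shape=(2,))
hidden = tf.keras.layers.Dense(16, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(2)(hidden)
base = tf.keras.Model(inputs, outputs)

model = SAMModel(base, tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1))
history = model.fit(x_train, y_train, batch_size=32, epochs=10, verbose=0)
```

Because training still goes through model.fit(), callbacks can be passed with fit(callbacks=[...]) as usual. The cost of this design is two forward and backward passes per batch.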