ironjr / grokfast

Official repository for the paper "Grokfast: Accelerated Grokking by Amplifying Slow Gradients"
https://arxiv.org/abs/2405.20233
MIT License

gradfilter_ema for Keras/Tensorflow #13

Open · snik007 opened this issue 2 months ago

snik007 commented 2 months ago

Any chance a Keras/TensorFlow version of the code could be added?

subtotechnoblade commented 3 weeks ago

Feel free to use my implementation for TensorFlow's functional API. Just replace model = tf.keras.Model(inputs=inputs, outputs=outputs) with model = Grok_Fast_EMA_Model(alpha=0.99, lamb=5.0, inputs=inputs, outputs=outputs). If there are any implementation errors, please notify me. Thanks!
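
For context, the filter the class below implements is the EMA variant from the paper: keep a running average of past gradients and add an amplified copy of it to the current gradient before the optimizer step. In update form, with decay alpha and amplification lamb:

g_ema      = alpha * g_ema + (1 - alpha) * g    # slow (low-pass) gradient component
g_filtered = g + lamb * g_ema                   # gradient handed to the optimizer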

import tensorflow as tf


class Grok_Fast_EMA_Model(tf.keras.Model):
    def __init__(self, alpha=0.99, lamb=5.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One non-trainable EMA buffer per trainable variable.
        self.grads = [tf.Variable(tf.zeros_like(var), trainable=False) for var in self.trainable_variables]
        # Graph-safe flag: a plain Python bool would be evaluated once at trace
        # time when `fit()` compiles `train_step` into a tf.function, which
        # would bake the seeding branch into the graph and re-seed the EMA on
        # every step.
        self.grads_updated = tf.Variable(False, trainable=False)
        self.alpha = alpha  # EMA decay rate
        self.lamb = lamb    # amplification factor for the slow component

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compute_loss(y=y, y_pred=y_pred)

            if self.optimizer is not None:
                loss = self.optimizer.scale_loss(loss)  # Keras 3 optimizer method

        # gradfilter_ema from the Grokfast paper and repo:
        # https://github.com/ironjr/grokfast
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        seed = tf.logical_not(self.grads_updated)
        updated_gradients = []
        for i in range(len(trainable_vars)):
            if gradients[i] is not None:
                g = tf.convert_to_tensor(gradients[i])
                # First step: seed the EMA with the raw gradient.
                # Afterwards: g_ema <- alpha * g_ema + (1 - alpha) * g
                self.grads[i].assign(tf.where(seed, g, self.grads[i] * self.alpha + g * (1 - self.alpha)))
                # Amplify the slow component: g + lamb * g_ema
                updated_gradients.append(g + self.grads[i] * self.lamb)
            else:
                updated_gradients.append(gradients[i])
        self.grads_updated.assign(True)

        # Update weights with the filtered gradients
        self.optimizer.apply_gradients(zip(updated_gradients, trainable_vars))

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)

        return self.compute_metrics(x, y, y_pred, sample_weight=None)
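
A minimal usage sketch, assuming a toy functional model (the layer sizes, optimizer, and data below are placeholders for illustration, not from the repo):

import numpy as np
import tensorflow as tf

# Hypothetical toy functional model to illustrate the drop-in replacement.
inputs = tf.keras.Input(shape=(32,))
x = tf.keras.layers.Dense(64, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(10)(x)

# Instead of: model = tf.keras.Model(inputs=inputs, outputs=outputs)
model = Grok_Fast_EMA_Model(alpha=0.99, lamb=5.0, inputs=inputs, outputs=outputs)
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Random placeholder data, just to show the call signature.
x_train = np.random.rand(256, 32).astype("float32")
y_train = np.random.randint(0, 10, size=(256,))
model.fit(x_train, y_train, epochs=2, batch_size=32)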