ironjr / grokfast

Official repository for the paper "Grokfast: Accelerated Grokking by Amplifying Slow Gradients"
https://arxiv.org/abs/2405.20233
MIT License

Feature/kalman filter #8

Open khari998 opened 2 months ago

khari998 commented 2 months ago

Added a Kalman Filter option for the gradfilter

Motivation: If performance scales with the quality of the noise filter, a Kalman filter may help: Kalman filters generally outperform EMAs at filtering noise in data, so this option may offer better performance for the grokfast algorithm than the EMA implementation.

Note: this has not yet been tested. Also, sorry for all the annoying formatting changes; I'm using Black as my auto-formatter.

ironjr commented 2 months ago

Thank you for the interesting update! I believe this could be another publishable work by itself if the results are promising. I will hold this open and unmerged for now until there is experimental evidence of the benefits. Thanks!

khari998 commented 2 months ago

No problem 😁

Also, I made a slight judgment call for this implementation. The original Kalman filter calculation uses covariance matrices for the process noise and measurement noise, which results in matrix operations during the prediction and update steps. The standard Kalman filter equations for the prediction and update steps are as follows:

Prediction step:

x_pred = x
P_pred = P + Q

Update step:

y = z - x_pred
S = P_pred + R
K = P_pred S^(-1)
x = x_pred + K y
P = (I - K) P_pred

where:

- Q is the process noise covariance matrix
- R is the measurement noise covariance matrix
- S is the innovation covariance matrix
- K is the Kalman gain matrix
- I is the identity matrix

In other words, the filter treats each gradient as a noisy measurement z of an underlying, slowly varying state x: the state follows a random walk with process noise Q, and the observed gradient is that state plus measurement noise R. Because the state transition is the identity, the prediction step leaves x unchanged and only inflates P by Q. So the original calculation may look something like:

from typing import Dict, Optional

import torch
import torch.nn as nn


def gradfilter_kalman(
    m: nn.Module,
    grads: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
    process_noise: float = 1e-4,
    measurement_noise: float = 1e-2,
    lamb: float = 2.0,
) -> Dict[str, Dict[str, torch.Tensor]]:
    # Per-parameter state: the filtered gradient estimate x (flattened) and its
    # error covariance P, stored as a full (numel, numel) matrix.
    if grads is None:
        grads = {
            n: {
                "x": torch.zeros_like(p.grad.data).flatten(),
                "P": torch.eye(p.grad.data.numel(), device=p.grad.device) * measurement_noise,
            }
            for n, p in m.named_parameters()
            if p.requires_grad and p.grad is not None
        }

    for n, p in m.named_parameters():
        if p.requires_grad and p.grad is not None:
            z = p.grad.data.flatten()  # measurement: the raw gradient
            I = torch.eye(z.numel(), device=z.device)

            # Prediction step (identity state transition)
            x_pred = grads[n]["x"]
            P_pred = grads[n]["P"] + process_noise * I

            # Update step
            y = z - x_pred
            S = P_pred + measurement_noise * I
            K = P_pred @ torch.linalg.inv(S)
            x = x_pred + K @ y
            P = (I - K) @ P_pred

            # Store updated state
            grads[n]["x"] = x
            grads[n]["P"] = P

            # Apply the filtered gradient (grokfast-style amplification)
            p.grad.data = p.grad.data + x.view_as(p.grad.data) * lamb

    return grads

However, using full covariance matrices in this context leads to increased computational complexity and memory usage, especially for larger models. For a parameter tensor with n elements, the covariance matrix alone takes O(n^2) memory, and the update step costs up to O(n^3) time because of the matrix inversion.

To address this, I have opted to use scalar values for the process noise and measurement noise, treating them as constants across all parameters. This simplification reduces the time complexity to O(num_parameters) and avoids the need for matrix operations, making the calculation more efficient as the model size scales up. So while my simplified version may not capture the full covariance information as in the standard Kalman filter, I believe it provides a good balance between computational efficiency and the ability to filter gradients effectively. The scalar noise values still allow the filter to adapt to the characteristics of the gradients and provide smoothing.
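
Concretely, the simplified scalar version looks roughly like the sketch below (the function name and defaults here are just illustrative; the actual code in the PR may differ in its details):

from typing import Dict, Optional

import torch
import torch.nn as nn


def gradfilter_kalman_scalar(
    m: nn.Module,
    grads: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
    process_noise: float = 1e-4,
    measurement_noise: float = 1e-2,
    lamb: float = 2.0,
) -> Dict[str, Dict[str, torch.Tensor]]:
    # Per-parameter state: filtered gradient estimate x (same shape as the grad)
    # and a scalar error variance P shared by all elements of that tensor.
    if grads is None:
        grads = {
            n: {
                "x": torch.zeros_like(p.grad.data),
                "P": torch.tensor(measurement_noise, device=p.grad.device),
            }
            for n, p in m.named_parameters()
            if p.requires_grad and p.grad is not None
        }

    for n, p in m.named_parameters():
        if p.requires_grad and p.grad is not None:
            # Prediction step: identity transition, so only the variance grows.
            x_pred = grads[n]["x"]
            P_pred = grads[n]["P"] + process_noise

            # Update step with a single scalar Kalman gain per tensor.
            y = p.grad.data - x_pred
            K = P_pred / (P_pred + measurement_noise)
            x = x_pred + K * y
            P = (1.0 - K) * P_pred

            grads[n]["x"] = x
            grads[n]["P"] = P

            # Amplify the slow (filtered) component, as grokfast does with the EMA.
            p.grad.data = p.grad.data + x * lamb

    return grads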

If you notice any discrepancies with the standard Kalman filter behavior, this simplification may be the reason why. However, I believe the benefits in terms of reduced time complexity and improved scalability outweigh the potential drawbacks, especially in the context of machine learning models where efficiency is crucial.

Let me know if you have any further questions or if there's anything else I can clarify!

If you need a visual for how this compares to an EMA, here is a simulation of what it should look like over some fake gradients:

import React from 'react';
import { LineChart, Line, XAxis, YAxis, CartesianGrid, Tooltip, Legend, ResponsiveContainer } from 'recharts';

const generateData = (count, min, max, noise) => {
  const data = [];
  let value = Math.random() * (max - min) + min;

  for (let i = 0; i < count; i++) {
    value += Math.random() * (max - min) * 0.1 - (max - min) * 0.05;
    value = Math.max(min, Math.min(max, value));
    data.push({
      x: i,
      y: value + Math.random() * noise - noise / 2,
    });
  }

  return data;
};

const kalmanFilter = (data, processNoise, measurementNoise) => {
  let x = data[0].y;
  let p = measurementNoise;

  return data.map((point) => {
    const x_pred = x;
    const p_pred = p + processNoise;

    const y = point.y - x_pred;
    const k = p_pred / (p_pred + measurementNoise);
    x = x_pred + k * y;
    p = (1 - k) * p_pred;

    return { x: point.x, kalman: x };
  });
};

const emaFilter = (data, alpha) => {
  let ema = data[0].y;

  return data.map((point) => {
    ema = alpha * point.y + (1 - alpha) * ema;
    return { x: point.x, ema };
  });
};

const KalmanVsEmaPlot = () => {
  const data = generateData(100, -5, 5, 2);
  const kalmanData = kalmanFilter(data, 0.01, 0.1);
  const emaData = emaFilter(data, 0.1);

  const mergedData = data.map((point, i) => ({
    x: point.x,
    y: point.y,
    kalman: kalmanData[i].kalman,
    ema: emaData[i].ema,
  }));

  return (
    <ResponsiveContainer width="100%" height={400}>
      <LineChart data={mergedData} margin={{ top: 5, right: 30, left: 20, bottom: 5 }}>
        <CartesianGrid strokeDasharray="3 3" />
        <XAxis dataKey="x" />
        <YAxis />
        <Tooltip />
        <Legend />
        <Line type="monotone" dataKey="y" stroke="#8884d8" dot={false} name="Gradient" />
        <Line type="monotone" dataKey="kalman" stroke="#82ca9d" name="Simplified Kalman" />
        <Line type="monotone" dataKey="ema" stroke="#ff7300" name="EMA" />        
      </LineChart>
    </ResponsiveContainer>
  );
};

export default KalmanVsEmaPlot;

You should be able to see how much more quickly the Kalman filter fits the underlying data; EMAs lag much more.
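
If it's easier to reproduce outside of a React environment, here is a rough Python equivalent of the same comparison (numpy and matplotlib assumed available; the synthetic signal only approximates the generateData function above):

import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)

# Fake "gradient": a bounded random walk plus uniform observation noise,
# roughly mirroring generateData above.
steps = 100
walk = np.clip(np.cumsum(rng.uniform(-0.5, 0.5, steps)), -5.0, 5.0)
signal = walk + rng.uniform(-1.0, 1.0, steps)


def kalman_1d(z, process_noise=0.01, measurement_noise=0.1):
    x, p = z[0], measurement_noise
    out = []
    for meas in z:
        p_pred = p + process_noise                 # prediction step
        k = p_pred / (p_pred + measurement_noise)  # scalar Kalman gain
        x = x + k * (meas - x)                     # update step
        p = (1.0 - k) * p_pred
        out.append(x)
    return np.array(out)


def ema_1d(z, alpha=0.1):
    x = z[0]
    out = []
    for meas in z:
        x = alpha * meas + (1.0 - alpha) * x
        out.append(x)
    return np.array(out)


plt.plot(signal, label="Gradient")
plt.plot(kalman_1d(signal), label="Simplified Kalman")
plt.plot(ema_1d(signal), label="EMA")
plt.legend()
plt.show()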

Zhi0467 commented 1 month ago

Hi, do you have any results you can share now on whether the Kalman filter eliminates grokking better? @khari998

khari998 commented 1 month ago

> Hi, do you have any results you can share now on whether the Kalman filter eliminates grokking better? @khari998

Unfortunately no. I stated earlier that this is untested.

I submitted the feature because the reported results seemed to depend on the choice of noise filter, so I wanted to offer an option that filters noise better than an exponential moving average. I currently don't have access to compute to run the same experiments myself, so it is open for any other researchers to evaluate.

Zhi0467 commented 1 month ago

> Hi, do you have any results you can share now on whether the Kalman filter eliminates grokking better? @khari998
>
> Unfortunately no. I stated earlier that this is untested.
>
> I submitted the feature because the reported results seemed to depend on the choice of noise filter, so I wanted to offer an option that filters noise better than an exponential moving average. I currently don't have access to compute to run the same experiments myself, so it is open for any other researchers to evaluate.

Thanks for the response, I think it's an interesting idea and I have the compute to run it. Do you have any suggestions or intuitions for choosing the parameters (lamb, process_noise, measurement_noise, etc.)? Or should I do a grid search over some range to start testing the filter?

HydrogenBombaklot commented 2 weeks ago

Any update here?