luyug / GradCache

Run Effective Large Batch Contrastive Learning Beyond GPU/TPU Memory Constraint
Apache License 2.0

Implement Grokfast into GradCache #31

Open ben-walczak opened 3 months ago

ben-walczak commented 3 months ago

I would like to implement the Grokfast algorithm, which is an exponentially weighted mean of past gradients added to the current gradients, together with GradCache. I've been able to use it without GradCache, but I'm not sure where it would hook into GradCache, as I'm still learning its underlying mechanisms. Any direction on how this might be done? I'm also curious whether this would be an appropriate feature for this library.

luyug commented 3 months ago

an exponentially weighted mean of past gradients added to the current gradients

shouldn't this be tracked in the optimizer (state) and applied as a gradient transformation?
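For example, something along these lines, kept outside GradCache entirely (untested sketch; GrokfastEMA and its names are placeholders, not part of either library):

import torch


class GrokfastEMA:
    """Keep the Grokfast EMA as per-parameter state and apply it as a gradient
    transformation right before optimizer.step(), regardless of whether the
    gradients came from a plain backward or from a GradCache step."""

    def __init__(self, params, alpha: float = 0.99, lamb: float = 5.0):
        self.params = [p for p in params if p.requires_grad]
        self.alpha = alpha
        self.lamb = lamb
        self.ema = {}  # parameter -> EMA of its past gradients

    @torch.no_grad()
    def transform(self):
        for p in self.params:
            if p.grad is None:
                continue
            if p not in self.ema:
                # Seed the EMA with the first observed gradient.
                self.ema[p] = p.grad.detach().clone()
            else:
                self.ema[p].mul_(self.alpha).add_(p.grad, alpha=1 - self.alpha)
            # Add the amplified slow component back onto the current gradient.
            p.grad.add_(self.ema[p], alpha=self.lamb)

In the training loop: run the backward (or the GradCache step), call transform(), then optimizer.step().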

ben-walczak commented 2 months ago

I'm not entirely sure. Just spitballing, but I think it would be implemented somewhere within these lines of code: https://github.com/luyug/GradCache/blob/main/src/grad_cache/grad_cache.py#L193C9-L211

using logic similar to the following:

from typing import Dict, Optional

import torch
import torch.nn as nn


def gradfilter_ema(
    m: nn.Module,
    grads: Optional[Dict[str, torch.Tensor]] = None,
    alpha: float = 0.99,
    lamb: float = 5.0,
) -> Dict[str, torch.Tensor]:
    # On the first call, seed the EMA state with the current gradients.
    if grads is None:
        grads = {n: p.grad.data.detach() for n, p in m.named_parameters() if p.requires_grad}

    for n, p in m.named_parameters():
        if p.requires_grad:
            # Update the exponential moving average of past gradients.
            grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)
            # Add the amplified slow component back onto the current gradient.
            p.grad.data = p.grad.data + grads[n] * lamb

    return grads
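Or, since the cache step already leaves the accumulated gradients in p.grad, maybe it could just be applied from the training loop without changing grad_cache.py at all (untested sketch; encoder_q, encoder_k, contrastive_loss, loader, and optimizer are placeholders):

from grad_cache import GradCache

gc = GradCache(models=[encoder_q, encoder_k], chunk_sizes=16, loss_fn=contrastive_loss)

grads_q, grads_k = None, None
for xx, yy in loader:
    optimizer.zero_grad()
    gc(xx, yy)                                    # gradient-cached forward/backward
    grads_q = gradfilter_ema(encoder_q, grads_q)  # filter the cached gradients in place
    grads_k = gradfilter_ema(encoder_k, grads_k)
    optimizer.step()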

I'll test this out later when I get a chance