ben-walczak opened 3 months ago
> an exponentially weighted mean of past gradients added to the current gradients
Shouldn't this be tracked in the optimizer state and applied as a gradient transformation?
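One way to read this suggestion is to keep the EMA in optimizer-style state keyed by parameter name, and apply it as a transformation of the gradients right before the update, so neither the model nor GradCache has to know about it. The following is a minimal sketch that assumes plain floats in place of tensors. `GrokfastSGD` is a hypothetical class and is not part of GradCache or `torch.optim`.

```python
from typing import Dict

FloatDict = Dict[str, float]  # parameter name -> value (stands in for tensors)


class GrokfastSGD:
    """Hypothetical SGD variant that keeps the Grokfast EMA in its own state."""

    def __init__(self, lr: float = 0.1, alpha: float = 0.99, lamb: float = 5.0) -> None:
        self.lr, self.alpha, self.lamb = lr, alpha, lamb
        self.state: FloatDict = {}  # per-parameter EMA, analogous to optimizer.state

    def transform(self, grads: FloatDict) -> FloatDict:
        """Gradient transformation: update the EMA, then return g + lamb * ema."""
        out: FloatDict = {}
        for name, g in grads.items():
            prev = self.state.get(name, g)  # first step seeds the EMA with g
            ema = prev * self.alpha + g * (1 - self.alpha)
            self.state[name] = ema
            out[name] = g + self.lamb * ema
        return out

    def step(self, params: FloatDict, grads: FloatDict) -> FloatDict:
        """Apply the transformation, then a plain SGD update."""
        filtered = self.transform(grads)
        return {n: p - self.lr * filtered[n] for n, p in params.items()}
```

Because the EMA lives in `state`, it could be saved and restored together with the rest of the optimizer. That is one possible argument for this placement over a free function.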
I'm not entirely sure. Just spitballing, but I think it would be implemented somewhere within these lines of code: https://github.com/luyug/GradCache/blob/main/src/grad_cache/grad_cache.py#L193C9-L211
and would use logic similar to the following:
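```python
from typing import Dict, Optional

import torch
import torch.nn as nn


def gradfilter_ema(
    m: nn.Module,
    grads: Optional[Dict[str, torch.Tensor]] = None,
    alpha: float = 0.99,
    lamb: float = 5.0,
) -> Dict[str, torch.Tensor]:
    # First call: seed the EMA with the current gradients.
    if grads is None:
        grads = {n: p.grad.data.detach() for n, p in m.named_parameters()
                 if p.requires_grad and p.grad is not None}
    for n, p in m.named_parameters():
        if p.requires_grad and p.grad is not None:
            # Update the exponential moving average of past gradients...
            grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)
            # ...and add it, scaled by lamb, to the current gradient.
            p.grad.data = p.grad.data + grads[n] * lamb
    # Pass the returned dict back in on the next step to carry the EMA forward.
    return grads
```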
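If the GradCache call runs backward over every chunk before it returns, `p.grad` already holds the full-batch gradient at that point. In that case the filter might not need to go inside the linked lines at all, and could instead run once per step between the GradCache call and `optimizer.step()`. The following is a minimal sketch of that ordering under this assumption. Plain floats stand in for tensors, and `ema_filter`, `accumulate`, and `train_step` are hypothetical helpers, not GradCache APIs.

```python
from typing import Dict, List, Optional, Tuple

FloatDict = Dict[str, float]  # parameter name -> value (stands in for tensors)


def ema_filter(
    grads: FloatDict, ema: Optional[FloatDict], alpha: float = 0.99, lamb: float = 5.0
) -> Tuple[FloatDict, FloatDict]:
    """Same update as gradfilter_ema, on plain floats; returns (filtered, new_ema)."""
    prev = ema if ema is not None else grads  # first step: seed EMA with the gradient
    new_ema = {n: prev[n] * alpha + g * (1 - alpha) for n, g in grads.items()}
    filtered = {n: g + new_ema[n] * lamb for n, g in grads.items()}
    return filtered, new_ema


def accumulate(chunk_grads: List[FloatDict]) -> FloatDict:
    """Sum per-chunk contributions, as repeated backward() calls sum into p.grad."""
    total: FloatDict = {}
    for chunk in chunk_grads:
        for n, g in chunk.items():
            total[n] = total.get(n, 0.0) + g
    return total


def train_step(
    params: FloatDict, chunk_grads: List[FloatDict], ema: Optional[FloatDict], lr: float = 0.1
) -> Tuple[FloatDict, FloatDict]:
    full = accumulate(chunk_grads)         # 1. all chunks backpropagated (the GradCache call)
    filtered, ema = ema_filter(full, ema)  # 2. filter once per optimizer step, not per chunk
    new_params = {n: p - lr * filtered[n] for n, p in params.items()}  # 3. plain SGD step
    return new_params, ema
```

For example, `train_step({"w": 1.0}, [{"w": 0.2}, {"w": 0.3}], None)` sums the chunks to 0.5 and seeds the EMA with 0.5. It then filters the gradient to 0.5 + 5 * 0.5 = 3.0 and returns `w` = 0.7 (up to float rounding). If the filter ran per chunk instead, the EMA would advance several times per optimizer step, which is a different algorithm.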
I'll test this out later when I get a chance.
I would like to implement Grokfast, which adds an exponentially weighted mean of past gradients to the current gradients, together with GradCache. I've been able to use it without GradCache, but I'm still learning how GradCache works internally and am unsure where Grokfast would fit in. Any pointers on how this might be done? I'm also curious whether this would be an appropriate feature for this library.
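In equation form, let $g_t$ be the gradient at optimizer step $t$ and $\bar g_t$ its EMA. Let $\alpha$ stand for `alpha` and $\lambda$ for `lamb`. These symbols are chosen here to match the `gradfilter_ema` snippet and are not taken from the Grokfast paper. The update is then:

```math
\bar{g}_t = \alpha\,\bar{g}_{t-1} + (1-\alpha)\,g_t, \qquad \hat{g}_t = g_t + \lambda\,\bar{g}_t
```

Here $\hat g_t$ is the gradient handed to the optimizer. The snippet seeds the EMA with the first gradient ($\bar g_0 = g_1$), so on the first step the update reduces to $\hat g_1 = (1+\lambda)\,g_1$.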