davda54 / sam

SAM: Sharpness-Aware Minimization (PyTorch)

Using the step function with closure #90

Open mathuryash5 opened 11 months ago

mathuryash5 commented 11 months ago

Hello,

I am trying to use the step function (with the transformers and accelerate libraries) while passing a closure.

The step function is decorated with @torch.no_grad(), so the closure has to be wrapped with enable_grad for gradients to be computed. How does the second call to closure() work? When I try it, I get the following error, which sort of makes sense given that gradients are not being computed: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
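For reference, my understanding of the relevant part of step (paraphrased from sam.py in this repo; the exact wording may differ):

@torch.no_grad()
def step(self, closure=None):
    assert closure is not None, "SAM requires a closure, but it was not provided"
    # wrapping re-enables gradients inside the closure,
    # despite the @torch.no_grad() decorator on step itself
    closure = torch.enable_grad()(closure)

    self.first_step(zero_grad=True)
    closure()  # second forward-backward pass at the perturbed weights
    self.second_step()

So the second call to closure() should compute gradients, because the wrapped closure runs under torch.enable_grad().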

Here is the closure function I use:

def closure():
    tmp_output = model(**batch)
    tmp_loss = tmp_output.loss
    # scale the loss for gradient accumulation
    tmp_loss = tmp_loss / args.gradient_accumulation_steps
    # let accelerate handle the backward pass
    accelerator.backward(tmp_loss)
    return tmp_loss
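
For comparison, here is the closure pattern from the README (plain PyTorch, no accelerate); criterion, inputs, and targets stand in for the actual pipeline. The forward-backward pass happens once outside the closure and once inside it, and the closure returns the loss:

def closure():
    # full forward-backward pass, re-run at the perturbed weights
    loss = criterion(model(inputs), targets)
    loss.backward()
    return loss

loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step(closure)
optimizer.zero_grad()

My question is whether accelerator.backward() behaves the same way here as a plain loss.backward() when the closure is re-invoked inside step.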
stale[bot] commented 8 months ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.