EleutherAI / lm-evaluation-harness

A framework for few-shot evaluation of language models.
https://www.eleuther.ai
MIT License

Loaded AutoModelForCausalLM has no grad_fn #2523

Closed · Yufei-Gu-451 closed this 1 hour ago

Yufei-Gu-451 commented 2 hours ago

I am working on understanding model gradients during LM inference. I attempted to capture gradients in the _loglikelihood_tokens method with a modified Hugging Face model class, loading AutoModelForCausalLM from pretrained Hugging Face checkpoints.

import torch
import torch.nn.functional as F

# Compute the cross-entropy loss over the continuation tokens
logits_reshaped = logits.view(-1, logits.size(-1))  # [seq, vocab]
targets = cont_toks.view(-1)  # [seq]

loss = F.cross_entropy(logits_reshaped, targets, ignore_index=-1)
loss.requires_grad = True

# Zero the gradients of the model and enable input require_grads
self.model.zero_grad()
self.model.enable_input_require_grads()

# Enable gradient checkpointing with custom settings
self.model.gradient_checkpointing_enable()  # enables gradient checkpointing
self.model.config.gradient_checkpointing_kwargs = {"use_reentrant": False}

for name, param in self.model.named_parameters():
    print(param.grad_fn)

# Compute gradients using autograd.grad
grads = torch.autograd.grad(loss, self.model.parameters(),
                            retain_graph=False,
                            allow_unused=True)

However, all model parameters have no grad_fn (each prints None). I tried this on Llama and Mistral and got the same results.

Is there any way I can load an AutoModelForCausalLM model with grad_fn enabled?

baberabb commented 2 hours ago

Hi! Did you override the _model_call method? https://github.com/EleutherAI/lm-evaluation-harness/blob/5680a2e6b5cf1a1621d8ff68d3d0e83e8b2731d3/lm_eval/models/huggingface.py#L846

We also set the model to eval mode here: https://github.com/EleutherAI/lm-evaluation-harness/blob/5680a2e6b5cf1a1621d8ff68d3d0e83e8b2731d3/lm_eval/models/huggingface.py#L202-L204
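
For reference, a minimal sketch of how the no_grad wrapper can be bypassed by subclassing. GradHFLM is a made-up name here, and the signature follows the stock _model_call linked above:

import torch
from lm_eval.models.huggingface import HFLM

class GradHFLM(HFLM):
    def _model_call(self, inps, attn_mask=None, labels=None):
        # The stock implementation wraps the forward pass in torch.no_grad(),
        # so the returned logits have no grad_fn and autograd cannot build a
        # graph back to the parameters. Re-enabling tracking here restores it.
        with torch.enable_grad():
            if attn_mask is not None or labels is not None:
                return self.model(
                    input_ids=inps, attention_mask=attn_mask, labels=labels
                ).logits
            return self.model(inps).logits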

Yufei-Gu-451 commented 1 hour ago

Thank you for your kind response. My issue was fixed by enabling gradient tracking (torch.enable_grad()) in the _model_call method; simply commenting out lines 202 to 204 is not enough.
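
For anyone hitting the same thing: leaf parameters never expose a grad_fn even when gradients flow correctly; grad_fn only appears on non-leaf tensors such as the logits and the loss. A rough verification sketch, where lm, batch, and targets are placeholders for a wrapped model instance and a tokenized batch:

import torch
import torch.nn.functional as F

logits = lm._model_call(batch)  # forward pass now tracked by autograd
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
assert loss.grad_fn is not None  # the graph exists on the (non-leaf) loss

# Gradients are returned for every parameter that participated in the
# forward pass; unused parameters come back as None with allow_unused=True.
grads = torch.autograd.grad(loss, lm.model.parameters(), allow_unused=True)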