mingkaid / rl-prompt

Accompanying repo for the RLPrompt paper
MIT License

Questions on the Gradients of LLM #41

Open Schwartz-Zha opened 10 months ago

Schwartz-Zha commented 10 months ago

As I understand it, one of the core contributions claimed in the paper is that training never requires gradients through the LLM, which saves a lot of compute and memory.

But how is this enforced in the code?

In LMAdaptorModel, the generator model's parameters are frozen:

for param in self.generator.model.parameters():
    param.requires_grad = False
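A minimal PyTorch sketch of what freezing does, assuming two toy nn.Linear layers: f stands in for a trainable module and g for the frozen generator (both names and sizes are hypothetical, not taken from the repo). With requires_grad = False, PyTorch stops accumulating .grad for g's parameters, but the backward pass still runs through g's activations in order to reach f.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: f is trainable, g plays the role of the frozen LLM.
f = nn.Linear(4, 4)
g = nn.Linear(4, 1)

# Freeze g the same way LMAdaptorModel freezes self.generator.model.
for param in g.parameters():
    param.requires_grad = False

x = torch.randn(2, 4)
loss = g(f(x)).sum()
loss.backward()

# g's parameters get no .grad, yet backpropagation still passed through g
# (multiplying by g.weight) to deliver a gradient to f.
print(g.weight.grad)              # None
print(f.weight.grad is not None)  # True
```

Here f.weight.grad equals g.weight.T @ x.sum(0, keepdim=True), which makes it explicit that g's weights take part in the backward computation even though g itself is never updated.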

In PromptedClassificationReward, _get_logits is wrapped in a torch.no_grad() decorator:

@torch.no_grad()
def _get_logits(
    self,
    texts: List[str]
) -> torch.Tensor:
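A companion sketch with the same hypothetical toy layers, this time calling g inside a function decorated with torch.no_grad(), as _get_logits is (the helper name score is illustrative only). Unlike requires_grad = False, no_grad records no autograd graph for g at all, so g's output is detached from f: calling backward() on a loss built only from that output raises a RuntimeError.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: f is trainable, g plays the role of the LLM.
f = nn.Linear(4, 4)
g = nn.Linear(4, 1)

@torch.no_grad()
def score(h: torch.Tensor) -> torch.Tensor:
    # Mirrors the @torch.no_grad() decorator on _get_logits.
    return g(h)

x = torch.randn(2, 4)
out = score(f(x))

# No graph was recorded inside score(), so out is cut off from f.
print(out.requires_grad, out.grad_fn)  # False None

try:
    out.sum().backward()
except RuntimeError as err:
    print("backward failed:", err)
```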

But my experiments show that neither method actually prevents gradients from being computed.

Denote some network blocks as a function $g$ that is wrapped in no_grad or has requires_grad = False on its parameters. Suppose some other blocks $f$ come before $g$, so the whole network computes $$g(f(x)).$$

However, $f$ does require gradients, because $f$ needs to be updated. My experiments show that gradients through $g$ are still computed in this case, since by the chain rule there is no other way to obtain the gradients of $f$. So no_grad / requires_grad = False has no effect, and the gradients are still computed.
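The chain-rule step behind this argument, written out with $\theta_f$ for the parameters of $f$, $h = f(x)$, and $\mathcal{L}$ for a loss on $g$'s output (symbols introduced here for illustration only):

$$\frac{\partial \mathcal{L}}{\partial \theta_f} = \frac{\partial \mathcal{L}}{\partial g}\,\frac{\partial g}{\partial h}\,\frac{\partial h}{\partial \theta_f}, \qquad h = f(x).$$

Setting requires_grad = False on $g$'s parameters $\theta_g$ only skips the term $\partial \mathcal{L}/\partial \theta_g$; the factor $\partial g/\partial h$ is still evaluated during backpropagation. Under no_grad, by contrast, that factor is never recorded, so no gradient reaches $\theta_f$ through $g$ at all.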

Given this, how exactly does the code ensure that gradients are never computed through the LLM? The training runs far too fast for backpropagation through the LLM to be happening, so it must be avoided somehow.
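For contrast, a minimal sketch of a setup in which no gradient ever flows through $g$: $f$ is a policy that samples a discrete token, $g$ only consumes that sample and returns a scalar reward under torch.no_grad(), and $f$ is updated with a REINFORCE (score-function) estimator. This is an illustrative assumption, not the repo's own training loss; policy, reward_model, reward, and the toy sizes are all hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

VOCAB = 5
policy = nn.Linear(3, VOCAB)        # hypothetical trainable f: state -> token logits
reward_model = nn.Linear(VOCAB, 1)  # hypothetical g: one-hot token -> score

@torch.no_grad()
def reward(token: torch.Tensor) -> torch.Tensor:
    # g sees only the discrete sample, so there is no differentiable path into f.
    one_hot = nn.functional.one_hot(token, VOCAB).float()
    return reward_model(one_hot).squeeze(-1)

state = torch.randn(8, 3)
dist = torch.distributions.Categorical(logits=policy(state))
token = dist.sample()  # sampling is non-differentiable
r = reward(token)      # per-sample rewards, no autograd graph

# REINFORCE: the gradient reaches f only through log pi(token | state).
loss = -(r * dist.log_prob(token)).mean()
loss.backward()

print(policy.weight.grad is not None)    # True
print(reward_model.weight.grad is None)  # True
```

Because $g$ only ever receives a non-differentiable sample, skipping its backward pass loses nothing here: the reward enters the loss as a constant weight on the policy's log-probability.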