lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
MIT License

norm.gamma not used during backprop #46

Closed: conceptofmind closed this issue 1 year ago

conceptofmind commented 1 year ago

Hi @lucidrains ,

I am almost ready to deploy the distributed training run. One thing I noticed is that norm.gamma is an unused parameter.

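import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # learnable scale; this is the parameter reported as unused
        self.gamma = nn.Parameter(torch.ones(dim))
        # beta is a buffer, not a parameter, so it is never trained
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)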

Because of this, distributed training fails with the following error:

This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by making sure all `forward` function outputs participate in calculating loss.

To find the unused parameters, I checked which parameters still have no gradient after the backward pass:

    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        accelerator.backward(loss / GRADIENT_ACCUMULATE_EVERY)

    for name, param in model.named_parameters():
        if param.grad is None:
            print("NONE")
            print(name)

Output:

NONE
norm.gamma
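The same check can be wrapped in a small reusable helper. The following is a minimal sketch on a toy model, not code from the repository: `ToyModel` and `params_without_grad` are hypothetical names introduced here for illustration.

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical model with one layer that never feeds the output."""

    def __init__(self, dim=4):
        super().__init__()
        self.used = nn.Linear(dim, dim)
        self.unused = nn.Linear(dim, dim)  # never called in forward

    def forward(self, x):
        return self.used(x)


def params_without_grad(model):
    """Return names of trainable parameters whose .grad is still None."""
    return [
        name
        for name, param in model.named_parameters()
        if param.requires_grad and param.grad is None
    ]


torch.manual_seed(0)
toy = ToyModel()
loss = toy(torch.randn(2, 4)).sum()
loss.backward()

# Parameters that did not take part in producing the loss have no gradient.
print(params_without_grad(toy))  # ['unused.weight', 'unused.bias']
```

Running this once after a single backward pass, before any `zero_grad` call, should list exactly the parameters that DDP would complain about.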

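Setting find_unused_parameters=True works around the error, but it adds overhead on every iteration, because DDP then has to traverse the autograd graph after each forward pass to find the parameters that will not receive gradients.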

I was wondering if you had any idea why this may be the case or if there is a proper way to resolve this issue.

I greatly appreciate your input as always.

Thank you,

Enrico

conceptofmind commented 1 year ago

@dmahan93 noticed that embeds, the output of the final norm, is never fed to to_logits, so norm.gamma never contributes to the loss. This may be the issue.

Currently, to_logits takes in x:

        # final norm

        embeds = self.norm(x)

        if return_only_embedding:
            return embeds

        # to logits

        logits = self.to_logits(x)

Should to_logits take in embeds instead?

        # final norm

        embeds = self.norm(x)

        if return_only_embedding:
            return embeds

        # to logits

        logits = self.to_logits(embeds)

Thank you,

Enrico
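The effect of the proposed change can be checked in isolation. Below is a minimal sketch, assuming a toy stand-in for the tail of the forward pass: `TinyHead`, `gamma_has_grad`, and the `use_embeds` flag are hypothetical and exist only for this illustration, while `LayerNorm` is the class quoted at the top of the issue.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    # Same LayerNorm as quoted at the top of the issue.
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)


class TinyHead(nn.Module):
    """Hypothetical stand-in for the final norm and logits projection."""

    def __init__(self, dim=8, num_tokens=16, use_embeds=True):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.norm = LayerNorm(dim)
        self.to_logits = nn.Linear(dim, num_tokens, bias=False)
        self.use_embeds = use_embeds

    def forward(self, x):
        x = self.proj(x)
        embeds = self.norm(x)
        # Buggy path projects the un-normalized x; fixed path projects embeds.
        return self.to_logits(embeds if self.use_embeds else x)


def gamma_has_grad(use_embeds):
    """Run one forward/backward pass and report whether gamma got a gradient."""
    torch.manual_seed(0)
    head = TinyHead(use_embeds=use_embeds)
    logits = head(torch.randn(2, 3, 8))        # (batch, seq, num_tokens)
    targets = torch.randint(0, 16, (2, 3))     # (batch, seq)
    # cross_entropy expects the class dimension second: (batch, classes, seq)
    loss = F.cross_entropy(logits.transpose(1, 2), targets)
    loss.backward()
    return head.norm.gamma.grad is not None


print("logits from x:     ", gamma_has_grad(use_embeds=False))  # False
print("logits from embeds:", gamma_has_grad(use_embeds=True))   # True
```

With logits computed from x, the norm output is discarded and gamma stays out of the autograd graph, which matches the `NONE` / `norm.gamma` output reported above. Feeding embeds to the projection puts gamma back into the graph, so find_unused_parameters=True should no longer be needed for this parameter.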

lucidrains commented 1 year ago

@conceptofmind @dmahan93 oh yes, thanks for catching this! put in a quick fix