SparkJiao / llama-pipeline-parallel

A prototype repo for hybrid training with pipeline parallelism and distributed data parallelism, with comments on the core code snippets. Feel free to copy the code and open discussions about any problems you encounter.

RuntimeError: element 1 of tensors does not require grad and does not have a grad_fn #5

Closed. OAfzal closed this issue 6 months ago.

OAfzal commented 9 months ago

So, I have been trying to train the LLaMA 7B model on the Alpaca dataset. While debugging, I found that element 1 appears to be the 4D attention mask I pass to the model from the torch dataset. For reference, I create both position_ids and attention_mask in the custom torch Dataset, so no tensors are created inside forward. Do you know what might be triggering this? I did not make any changes to the model file you provided.

SparkJiao commented 8 months ago

Sorry for missing this issue. Have you solved it?

I think a 4D attention mask is normal; it should have shape [batch_size, 1, q_len, kv_len].
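To illustrate why the mask has a singleton second dimension, here is a minimal sketch in plain PyTorch (not code from this repo): an additive causal mask of shape [batch_size, 1, q_len, kv_len] broadcasts over the head dimension of attention scores shaped [batch_size, n_heads, q_len, kv_len]. Filling blocked positions with torch.finfo(dtype).min is one common convention, assumed here.

```python
import torch

bsz, n_heads, q_len, kv_len = 2, 4, 5, 5

# Attention scores as produced by Q @ K^T: one [q_len, kv_len] matrix per head.
scores = torch.randn(bsz, n_heads, q_len, kv_len)

# Additive causal mask: 0 where attention is allowed, a very large negative
# value where the key position comes after the query position.
neg = torch.finfo(scores.dtype).min
causal = torch.triu(torch.full((q_len, kv_len), neg), diagonal=1)

# Shape [batch_size, 1, q_len, kv_len]: the singleton head dimension lets a
# single mask broadcast across all attention heads.
mask = causal[None, None, :, :].expand(bsz, 1, q_len, kv_len)

probs = torch.softmax(scores + mask, dim=-1)
print(mask.shape)   # torch.Size([2, 1, 5, 5])
print(probs.shape)  # torch.Size([2, 4, 5, 5])
```

Because the blocked entries become exactly zero after the softmax, the first query position attends only to itself.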

drewanye commented 8 months ago

I ran into the same problem. It occurs when I add some operations inside forward. In my case, torch autograd failed after converting the attention mask from [bsz, seq_len] to [bsz, 1, seq_len, seq_len] inside the forward function of EmbeddingPipe:

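```python
import torch

class EmbeddingPipe(torch.nn.Embedding):
    def forward(self, args):
        input_ids, attention_mask, position_ids = args
        inputs_embeds = super().forward(input_ids)
        # This raises the error: the expanded [bsz, 1, seq_len, seq_len] mask is
        # created inside forward from tensors that do not require grad.
        # _prepare_decoder_attention_mask is defined elsewhere (not shown).
        attention_mask = _prepare_decoder_attention_mask(attention_mask, input_ids.shape, inputs_embeds, 0)
        return inputs_embeds, attention_mask, position_ids
```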
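A minimal sketch that reproduces the message from the issue title with plain PyTorch. It rests on an assumption not confirmed in this thread: that the pipeline engine calls torch.autograd.backward on every floating-point tensor in a stage's output tuple. The helper stage_backward is hypothetical and only mimics that assumed behavior.

```python
import torch

def stage_backward(outputs):
    """Hypothetical stand-in for a pipeline engine's backward step: it
    backpropagates through every floating-point tensor in the stage's
    output tuple, using all-ones upstream gradients for simplicity."""
    tensors = [t for t in outputs if t.is_floating_point()]
    grads = [torch.ones_like(t) for t in tensors]
    torch.autograd.backward(tensors=tensors, grad_tensors=grads)

emb_weight = torch.randn(2, 3, 8, requires_grad=True)

# Case 1: a float mask built from non-grad tensors (as when the mask is
# expanded inside forward). It has no grad_fn and does not require grad.
embeds = emb_weight * 2.0
float_mask = torch.zeros(2, 1, 3, 3)
try:
    stage_backward((embeds, float_mask))
except RuntimeError as err:
    print(err)  # element 1 of tensors does not require grad and does not have a grad_fn

# Case 2: a boolean mask is not floating point, so the filter above skips it
# and backward runs only through the embeddings.
embeds = emb_weight * 2.0
bool_mask = torch.ones(2, 1, 3, 3, dtype=torch.bool)
stage_backward((embeds, bool_mask))
print(emb_weight.grad.unique())  # tensor([2.])
```

Under the same assumption, a float mask that itself requires grad (for example after calling requires_grad_() on it) would also pass the check; which option fits depends on how the downstream layers consume the mask.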
SparkJiao commented 8 months ago

@drewanye Similarly, I would recommend using the other implementation: https://github.com/SparkJiao/llama-pipeline-parallel/blob/main/models/llama_ds_mp_wrap.py#L128-L132

I'm also not sure of the underlying cause yet. But I think the general principle is to avoid creating tensors that do not require grad during the forward pass. Instead, you can create them outside the model and pass them in as inputs.
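Following the suggestion to build such tensors outside the model, here is a hypothetical collate function that creates position_ids and the 4D mask on the data side, so forward never has to. The name collate_for_pipeline, the pad_id parameter, and the -100 label convention are illustrative assumptions, not code from this repo. It also assumes right padding and that the pipeline engine consumes each batch as an (inputs, labels) tuple.

```python
import torch

def collate_for_pipeline(batch, pad_id=0):
    """Hypothetical collate_fn. `batch` is a list of dicts, each holding
    "input_ids" as a list of ints. Returns ((input_ids, attention_mask,
    position_ids), labels), with the mask shaped [bsz, 1, seq, seq]."""
    bsz = len(batch)
    max_len = max(len(ex["input_ids"]) for ex in batch)

    # Right-pad the token ids and record which positions are real tokens.
    ids = torch.full((bsz, max_len), pad_id, dtype=torch.long)
    pad_mask = torch.zeros(bsz, max_len, dtype=torch.bool)
    for i, ex in enumerate(batch):
        n = len(ex["input_ids"])
        ids[i, :n] = torch.tensor(ex["input_ids"], dtype=torch.long)
        pad_mask[i, :n] = True

    # With right padding, positions are simply 0..max_len-1 for every row.
    position_ids = torch.arange(max_len).expand(bsz, max_len).clone()

    # Allowed = causal (key <= query) AND key is a real token.
    causal = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))
    allowed = causal[None, None, :, :] & pad_mask[:, None, None, :]

    # Additive float mask: 0 where allowed, dtype minimum where blocked.
    attention_mask = torch.zeros(allowed.shape, dtype=torch.float32)
    attention_mask.masked_fill_(~allowed, torch.finfo(torch.float32).min)

    # -100 is the default ignore_index of torch.nn.CrossEntropyLoss.
    labels = ids.masked_fill(~pad_mask, -100)
    return (ids, attention_mask, position_ids), labels
```

Note that the original poster reports building the mask in the dataset and still hitting the error, so treat this as a mitigation to try rather than a guaranteed fix.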

SparkJiao commented 6 months ago

Closed. Reopen it if you have further questions.