NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

NaN loss issues when I switch to the Transformer Engine TransformerLayer from a PyTorch layer #953

Open jasonkrone opened 1 week ago

jasonkrone commented 1 week ago

Summary I'm hitting a NaN loss issue when I use the TransformerLayer in place of a PyTorch transformer layer I wrote.

Details I'm using the nvcr.io/nvidia/pytorch:24.04-py3 Docker container. I train with PyTorch FSDP and bfloat16 mixed precision.
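
For context, my FSDP mixed-precision setup looks roughly like this (simplified sketch, not my exact training code; the wrapping policy details are illustrative, and TransformerLayerWithPOS is the class shown below):

import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# bfloat16 for parameters, gradient reduction, and buffers
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# shard at the granularity of individual transformer blocks
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerLayerWithPOS},  # class defined below
)

model = FSDP(model, mixed_precision=bf16_policy, auto_wrap_policy=wrap_policy)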

Question Has the TransformerEngine team trained a model with the TELlamaDecoderLayer to ensure that everything works as expected? If so, could you share that example, as my use case is very similar?

Code Here's the code I wrote to wrap the TransformerLayer so that it uses RoPE embeddings. This is the class I swapped into my model.

from transformer_engine.pytorch import TransformerLayer
from transformer_engine.pytorch.attention import RotaryPositionEmbedding


class TransformerLayerWithPOS(TransformerLayer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Precompute RoPE frequencies once for the full sequence length.
        rope = RotaryPositionEmbedding(kwargs["hidden_size"] // kwargs["num_attention_heads"])
        self.rope_freqs = rope(max_seq_len=kwargs["seq_length"]).cuda()

    def forward(self, hidden_states):
        """
        Custom forward to make sure we only pass relevant arguments to the
        forward pass of the `TransformerLayer`. Also, make sure the output
        format matches the output of HF's `LlamaDecoderLayer`.
        """
        return super().forward(hidden_states, rotary_pos_emb=self.rope_freqs)

In addition, here are the kwargs I pass to the TransformerLayer (wrapped in a small helper here just for readability; the function name is not from my actual code).

import torch

def transformer_layer_kwargs(config):
    device = "meta" if config.use_meta_device else "cuda"
    return {
        "device": device,
        "params_dtype": torch.float32,
        "hidden_size": config.d_model,
        "ffn_hidden_size": config.d_hidden,
        "num_attention_heads": config.n_heads,
        "self_attn_mask_type": "causal",
        "normalization": "RMSNorm",
        "bias": False,
        "activation": "swiglu",
        "attn_input_format": "bshd",
        "fuse_wgrad_accumulation": False,
        "seq_length": config.max_len,
        "fuse_qkv_params": True,
    }

Learning Curve See the attached learning curve, which shows the loss going to NaN around step 350.

[Screenshot 2024-06-21: training loss curve showing the loss becoming NaN around step 350]
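
To localize the failure, I stop at the first non-finite step with a simple check in my training loop (simplified sketch; dataloader, model, and optimizer are placeholders for my actual setup):

import torch

for step, batch in enumerate(dataloader):
    loss = model(batch).mean()
    loss.backward()
    if not torch.isfinite(loss):
        raise RuntimeError(f"loss became non-finite at step {step}")
    # flag the first parameter whose gradient went non-finite
    for name, p in model.named_parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            raise RuntimeError(f"non-finite grad in {name} at step {step}")
    optimizer.step()
    optimizer.zero_grad()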