Summary
I'm hitting a NaN loss issue when I use the TransformerLayer in place of a PyTorch transformer layer I wrote.
Details
I'm using the nvcr.io/nvidia/pytorch:24.04-py3 Docker container.
I train with PyTorch FSDP and bfloat16 mixed precision.
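Concretely, the mixed-precision side of the setup looks roughly like this (a simplified sketch: `wrap_model` is just an illustrative helper, and the auto-wrap policy, distributed init, and optimizer are omitted):

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

def wrap_model(model: nn.Module) -> FSDP:
    # bfloat16 for parameters, gradient reduction, and buffers
    # (the "bfloat16 mixed precision" mentioned above).
    bf16_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )
    return FSDP(model, mixed_precision=bf16_policy)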
Question
Has the TransformerEngine team trained a model with the TELlamaDecoderLayer to ensure that everything works as expected? If so, could you share that example, as my use case is very similar?
Code
Here's the code I wrote to wrap the TransformerLayer so that it applies RoPE embeddings. This is the class I swapped in for my model.
from transformer_engine.pytorch import TransformerLayer
from transformer_engine.pytorch.attention import RotaryPositionEmbedding


class TransformerLayerWithPOS(TransformerLayer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Precompute the RoPE frequency table once, sized to the
        # per-head dimension and the maximum sequence length.
        rope = RotaryPositionEmbedding(
            kwargs["hidden_size"] // kwargs["num_attention_heads"]
        )
        self.rope_freqs = rope(max_seq_len=kwargs["seq_length"]).cuda()

    def forward(self, hidden_states):
        """
        Custom forward to make sure we only pass relevant arguments to the
        forward pass of the `TransformerLayer`. Also, make sure the output
        format matches the output of HF's `LlamaDecoderLayer`.
        """
        return super().forward(hidden_states, rotary_pos_emb=self.rope_freqs)
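For reference, this is how I'd exercise the wrapper in isolation (a hypothetical smoke test, not my training loop; shapes assume TransformerLayer's default "sbhd" input layout, i.e. [seq, batch, hidden]):

import torch

# Hypothetical smoke test with made-up sizes; seq_length bounds the
# precomputed RoPE table in the wrapper above.
layer = TransformerLayerWithPOS(
    hidden_size=4096,
    ffn_hidden_size=11008,
    num_attention_heads=32,
    seq_length=4096,
).cuda()
x = torch.randn(4096, 2, 4096, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = layer(x)
print(out.shape)  # same [seq, batch, hidden] shape as the input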
In addition, here are the kwargs I send to the transformer layer.
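The exact values come from my model config, so the version below is only a placeholder sketch (illustrative, roughly Llama-7B-shaped numbers, not my actual settings):

# Placeholder values for illustration only; the real numbers come
# from my model config.
layer_kwargs = dict(
    hidden_size=4096,
    ffn_hidden_size=11008,
    num_attention_heads=32,
    num_gqa_groups=32,
    seq_length=4096,
    layernorm_epsilon=1e-5,
    hidden_dropout=0.0,
    attention_dropout=0.0,
    normalization="RMSNorm",
    activation="swiglu",
    bias=False,
    fuse_qkv_params=False,
)
layer = TransformerLayerWithPOS(**layer_kwargs)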
Learning Curve
See the attached learning curve, which shows the loss going to NaN around step 350.