huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

RecurrentGemma not compatible with autocast / AMP training #30830

Closed: xplip closed this issue 5 months ago

xplip commented 6 months ago

System Info

Who can help?

@ArthurZucker

Information

Tasks

Reproduction

import torch
from transformers import RecurrentGemmaConfig, RecurrentGemmaForCausalLM

def main():
    V = 288       # vocab size
    B = 16        # batch size
    T = 300       # sequence length (also used as the attention window size)
    device = "cuda"

    config = RecurrentGemmaConfig(
        vocab_size=V,
        num_hidden_layers=12,
        hidden_size=1024,
        num_attention_heads=8,
        intermediate_size=6144,
        attention_window_size=T,
    )
    # Build a small RecurrentGemma model in full float32 precision
    model = RecurrentGemmaForCausalLM._from_config(config, torch_dtype=torch.float32).to(device)

    # Autocast settings to test: fp16, bf16, and autocast disabled
    autocast_settings = [
        {"dtype": torch.float16, "enabled": True},
        {"dtype": torch.bfloat16, "enabled": True},
        {"enabled": False},
    ]

    for autocast_setting in autocast_settings:
        print(f"\nRunning with autocast setting: {autocast_setting}:")
        try:
            with torch.cuda.amp.autocast(**autocast_setting):
                outputs = model(input_ids=torch.randint(0, V, (B, T), device=device))
                print(outputs.logits.shape)
        except RuntimeError as e:
            print(e)

if __name__ == "__main__":
    main()

Expected behavior

The script should run without errors with autocast enabled; without working autocast, training RecurrentGemma with automatic mixed precision (AMP) is not possible.
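
For context, the forward pass above is exactly what would sit inside a standard AMP training step. A generic sketch follows (illustrative only, reusing model, V, B, T, and device from the reproduction script above; the GradScaler plus fp16 autocast pairing is the usual AMP recipe and nothing here is RecurrentGemma-specific):

import torch

scaler = torch.cuda.amp.GradScaler()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

input_ids = torch.randint(0, V, (B, T), device=device)
with torch.cuda.amp.autocast(dtype=torch.float16):
    # The dtype error reported below is raised inside this forward pass.
    loss = model(input_ids=input_ids, labels=input_ids).loss
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()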

Output of the script above:

Running with autocast setting: {'dtype': torch.float16, 'enabled': True}
Index put requires the source and destination dtypes match, got Float for the destination and Half for the source.

Running with autocast setting: {'dtype': torch.bfloat16, 'enabled': True}
Index put requires the source and destination dtypes match, got Float for the destination and BFloat16 for the source.

Running with autocast setting: {'enabled': False}
torch.Size([16, 300, 288])

Expected output:

Running with autocast setting: {'dtype': torch.float16, 'enabled': True}
torch.Size([16, 300, 288])

Running with autocast setting: {'dtype': torch.bfloat16, 'enabled': True}
torch.Size([16, 300, 288])

Running with autocast setting: {'enabled': False}
torch.Size([16, 300, 288])
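
The error looks like an autocast/dtype interaction rather than anything model-specific: under autocast the new hidden states come out in fp16/bf16, while the preallocated state buffer (created from the model's float32 parameters) stays fp32, so the indexed in-place write rejects the mismatched source. A minimal PyTorch-only sketch of the same failure (my reading of the error, not RecurrentGemma code):

import torch

# fp32 destination buffer, standing in for a preallocated cache / recurrent state
cache = torch.zeros(1, 4, 8, dtype=torch.float32)
# fp16 source, standing in for an activation produced inside an autocast region
update = torch.randn(1, 2, 8, dtype=torch.float16)
positions = torch.tensor([0, 1])

# Raises: "Index put requires the source and destination dtypes match,
# got Float for the destination and Half for the source."
cache[:, positions] = update
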
ArthurZucker commented 5 months ago

Will merge the PR!
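
The usual remedy for this class of error is to cast the source tensor to the destination's dtype before the indexed in-place write, so the buffer keeps its precision while still accepting autocast outputs. A generic sketch of that pattern (an assumption about the shape of the fix, not the contents of the merged PR):

import torch

def write_to_buffer(buffer: torch.Tensor, update: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    # Cast the (possibly fp16/bf16) update to the buffer's dtype so the
    # indexed write succeeds both with and without autocast.
    buffer[:, positions] = update.to(buffer.dtype)
    return buffer

# Usage: an fp16 update now lands in an fp32 buffer without raising.
buffer = torch.zeros(1, 4, 8, dtype=torch.float32)
write_to_buffer(buffer, torch.randn(1, 2, 8, dtype=torch.float16), torch.tensor([0, 1]))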