Using distributed or parallel set-up in script?: False
Who can help?
@ArthurZucker
Information
[X] The official example scripts
[X] My own modified scripts
Tasks
[X] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
[X] My own task or dataset (give details below)
Reproduction
import torch
from transformers import RecurrentGemmaConfig, RecurrentGemmaForCausalLM


def main():
    V = 288
    B = 16
    T = 300
    device = "cuda"

    config = RecurrentGemmaConfig(
        vocab_size=V,
        num_hidden_layers=12,
        hidden_size=1024,
        num_attention_heads=8,
        intermediate_size=6144,
        attention_window_size=T,
    )
    model = RecurrentGemmaForCausalLM._from_config(config, torch_dtype=torch.float32).to(device)

    autocast_settings = [
        {"dtype": torch.float16, "enabled": True},
        {"dtype": torch.bfloat16, "enabled": True},
        {"enabled": False},
    ]

    for autocast_setting in autocast_settings:
        print(f"\nRunning with autocast setting: {autocast_setting}:")
        try:
            with torch.cuda.amp.autocast(**autocast_setting):
                outputs = model(input_ids=torch.randint(0, V, (B, T), device=device))
                print(outputs.logits.shape)
        except RuntimeError as e:
            print(e)


if __name__ == "__main__":
    main()
Expected behavior
The script should run without errors when autocast is enabled. As it stands, the forward pass fails under both float16 and bfloat16 autocast, so training with AMP is not possible.
Output of the script above:
Running with autocast setting: {'dtype': torch.float16, 'enabled': True}
Index put requires the source and destination dtypes match, got Float for the destination and Half for the source.
Running with autocast setting: {'dtype': torch.bfloat16, 'enabled': True}
Index put requires the source and destination dtypes match, got Float for the destination and BFloat16 for the source.
Running with autocast setting: {'enabled': False}
torch.Size([16, 300, 288])
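The error text points at an indexed write where the destination is float32 and the source is a reduced-precision tensor. The snippet below is a minimal sketch of that class of failure, independent of transformers and runnable on CPU. It does not claim to show where RecurrentGemma fails: the `cache`, `positions`, and `update` tensors and the `write_to_cache` helper are hypothetical stand-ins, and casting the source to the destination dtype is just one possible workaround.

```python
import torch


def write_to_cache(cache: torch.Tensor, positions: torch.Tensor, update: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: write `update` into rows `positions` of `cache`.

    The source is cast to the cache dtype first, so the indexed write
    always sees matching source and destination dtypes.
    """
    cache[positions] = update.to(cache.dtype)
    return cache


# A float32 buffer, standing in for state allocated with the model's dtype.
cache = torch.zeros(4, 8, dtype=torch.float32)
positions = torch.tensor([0, 2])
# A float16 source, standing in for an activation produced under autocast.
update = torch.ones(2, 8, dtype=torch.float16)

# Direct indexed assignment with mismatched dtypes, done on a scratch copy.
# Any RuntimeError is captured and printed rather than allowed to propagate.
scratch = cache.clone()
try:
    scratch[positions] = update
    direct_error = None
except RuntimeError as e:
    direct_error = str(e)
print("direct write error:", direct_error)

# The same write with an explicit cast to the destination dtype.
write_to_cache(cache, positions, update)
print(cache.dtype, cache[positions].sum().item())
```

If the model's internal buffers are allocated in float32 while autocast produces float16 or bfloat16 activations, a cast of this kind at the write site would be one way to reconcile the two. Whether that is the right fix for the model code is for the maintainers to judge.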
System Info
transformers version: 4.40.2