Closed RakshaPRao closed 3 months ago
Hi, you need to provide your PyTorch version, whether you installed flash-attn2, and your yaml file; otherwise it is difficult to help you out.
Hi, PyTorch version = 2.0.1, flash-attn2 is not installed, and the yaml is:

```yaml
accum_count: 8
accum_steps: 0
adam_beta1: 0.9
adam_beta2: 0.998
batch_size: 4096
batch_size_multiple: 1
batch_type: tokens
bucket_size: 32768
decay_method: noam
decoder_type: transformer
dropout: 0.2
early_stopping: 10
encoder_type: transformer
feat_merge: sum
heads: 12
hidden_size: 768
keep_checkpoint: 20
label_smoothing: 0.1
layers: 6
learning_rate: 2.0
max_generator_batches: 0
max_grad_norm: 0.0
n_src_feats: 1
normalization: tokens
optim: adam
param_init: 0.0
param_init_glorot: 'true'
pool_factor: 8192
position_encoding: 'true'
queue_size: 1024
report_every: 100
save_checkpoint_steps: 5000
seed: 1234
share_vocab: true
src_feats_defaults: N
src_seq_length: 600
src_vocab_size: 38000
tgt_seq_length: 600
train_steps: 1000000
transformer_ff: 3072
valid_batch_size: 16
valid_steps: 5000
warmup_steps: 8000
word_vec_size: 768
```
Thanks!
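As an aside on the config above: with `decay_method: noam`, the learning rate follows the standard Transformer schedule driven by `learning_rate`, `hidden_size`, and `warmup_steps`. A minimal sketch, assuming the usual noam formula (the `noam_lr` helper is illustrative, not part of OpenNMT-py):

```python
def noam_lr(step, learning_rate=2.0, hidden_size=768, warmup_steps=8000):
    # Standard Transformer "noam" schedule:
    # lr = base * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    return learning_rate * hidden_size ** -0.5 * min(
        step ** -0.5, step * warmup_steps ** -1.5
    )

# The rate rises roughly linearly until warmup_steps (8000 here),
# peaks there, then decays proportionally to step^-0.5.
peak = noam_lr(8000)
```

With the values in this yaml, the peak learning rate lands around `2.0 * 768**-0.5 * 8000**-0.5`, i.e. on the order of 8e-4.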
You need PyTorch 2.1 or 2.2; SDPA is buggy in 2.0.1.
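The minimum-version requirement is easy to check programmatically before launching a run. A small sketch, assuming the version string comes from `torch.__version__` (the `meets_min_torch` helper is hypothetical, not an OpenNMT-py API):

```python
def parse_version(v):
    # Keep only the numeric release part, e.g. "2.0.1+cu118" -> (2, 0, 1)
    release = v.split("+")[0]
    parts = release.split(".") + ["0", "0"]
    return tuple(int(p) for p in parts[:3])

def meets_min_torch(installed, minimum="2.1.0"):
    # True when the installed PyTorch release is at least the minimum
    return parse_version(installed) >= parse_version(minimum)

# The reporter's 2.0.1 fails the check; 2.1.x / 2.2.x pass.
ok = meets_min_torch("2.0.1")
```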
Distributed training with source features fails with the error.