I'm running 32k long-context SFT for Qwen2 72B. When I set `seq_parallel_world_size` to a value greater than 1 and `use_varlen_attn` to True, an error occurs.
The error is an assertion failure indicating that the length of my `input_ids` sequence must be divisible by `seq_parallel_world_size`. Padding the sequence to an appropriate length resolved this error. However, after several training iterations, the loss becomes NaN.
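A minimal sketch of the padding step described above, for reference. The helper name `pad_to_multiple` and the pad token id of 0 are assumptions for illustration and are not part of any library API. The padded label positions are set to -100, which is the default `ignore_index` of PyTorch's cross-entropy loss, so they would not contribute to the loss.

```python
def pad_to_multiple(input_ids, labels, multiple, pad_token_id=0, ignore_index=-100):
    """Right-pad both lists so that their length is divisible by `multiple`.

    `pad_token_id=0` is an illustrative assumption; use the tokenizer's real pad id.
    """
    if len(input_ids) != len(labels):
        raise ValueError("input_ids and labels must have the same length")
    # Number of tokens needed to reach the next multiple (0 if already aligned).
    pad_len = (multiple - len(input_ids) % multiple) % multiple
    padded_ids = input_ids + [pad_token_id] * pad_len
    # Padded positions are masked out of the loss via ignore_index.
    padded_labels = labels + [ignore_index] * pad_len
    return padded_ids, padded_labels


ids, labels = pad_to_multiple([11, 12, 13, 14, 15], [11, 12, 13, 14, 15], multiple=4)
print(len(ids), labels[-3:])  # 8 [-100, -100, -100]
```

This sketch only covers `input_ids` and `labels`. Whether other per-sample fields used by variable-length attention (for example cumulative sequence lengths) also need to be extended to cover the padded tokens is not shown here and may be worth checking.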
Here is my specific config:
use_varlen_attn = True
prompt_template = PROMPT_TEMPLATE.qwen_chat
max_length = 32768
pack_to_max_length = True
parallel
sequence_parallel_size = 4
Scheduler & Optimizer
batch_size = 1 # per_device
accumulative_counts = 32
accumulative_counts *= sequence_parallel_size
dataloader_num_workers = 4
max_epochs = 2
optim_type = AdamW
lr = 2e-6
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1 # grad clip
warmup_ratio = 0.1
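To make the accumulation arithmetic in the config concrete, here is a small sketch. It assumes that the GPUs in one sequence-parallel group jointly process the same sample, so the data-parallel size is the world size divided by `sequence_parallel_size`. The function name and the world size of 32 GPUs are hypothetical values chosen only for illustration.

```python
def samples_per_optimizer_step(world_size, sp_size, per_device_batch, accum_steps):
    """Distinct samples consumed per optimizer step.

    Assumes each group of `sp_size` GPUs shares one sample (sequence parallelism).
    """
    if world_size % sp_size != 0:
        raise ValueError("world_size must be divisible by sp_size")
    data_parallel_size = world_size // sp_size
    return data_parallel_size * per_device_batch * accum_steps


sequence_parallel_size = 4
batch_size = 1
accumulative_counts = 32
accumulative_counts *= sequence_parallel_size  # 32 * 4 = 128

world_size = 32  # hypothetical GPU count, not taken from the config above
with_sp = samples_per_optimizer_step(world_size, sequence_parallel_size, batch_size, accumulative_counts)
without_sp = samples_per_optimizer_step(world_size, 1, batch_size, 32)
print(with_sp, without_sp)  # 1024 1024
```

Under these assumptions, multiplying `accumulative_counts` by `sequence_parallel_size` keeps the number of samples per optimizer step the same as in a run without sequence parallelism.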