HuangLK / transpeeder

Train LLaMA on a single A100 80G node using 🤗 transformers and 🚀 DeepSpeed Pipeline Parallelism
Apache License 2.0

Flash attention integration failed #36

Open SparkJiao opened 1 year ago

SparkJiao commented 1 year ago

Hello,

When I try to use flash attention, I run into the following error:

│ /export/home2/fangkai/merit-v2/trainer_base_ds_mp.py:346 in main             │
│                                                                              │
│   343 │   │   │   logger.info("Resuming training from the latest checkpoint: │
│   344 │   │   │   continue_from_global_step = int(checkpoint.split('-')[-1]) │
│   345 │   │                                                                  │
│ ❱ 346 │   │   global_step, tr_loss = train(cfg, model_pipe, tokenizer, conti │
│   347 │   │   logger.info(" global_step = %s, average loss = %s", global_ste │
│   348                                                                        │
│   349                                                                        │
│                                                                              │
│ /export/home2/fangkai/merit-v2/trainer_base_ds_mp.py:236 in train            │
│                                                                              │
│   233 │   │   │   │   │   continue                                           │
│   234 │   │   │   │                                                          │
│   235 │   │   │   │   model.train()                                          │
│ ❱ 236 │   │   │   │   loss = model.train_batch(data_iter=sub_train_dataloade │
│   237 │   │   │   │   global_step += 1                                       │
│   238 │   │   │   │                                                          │
│   239 │   │   │   │   tr_loss += loss.item()                                 │
│                                                                              │
│ /export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/de │
│ epspeed/runtime/pipe/engine.py:336 in train_batch                            │
│                                                                              │
│    333 │   │   sched = schedule.TrainSchedule(micro_batches=self.micro_batch │
│    334 │   │   │   │   │   │   │   │   │      stages=self.num_stages,        │
│    335 │   │   │   │   │   │   │   │   │      stage_id=self.stage_id)        │
│ ❱  336 │   │   self._exec_schedule(sched)                                    │
│    337 │   │   self.agg_train_loss = self._aggregate_total_loss()            │
│    338 │   │                                                                 │
│    339 │   │   self.timers('train_batch').stop()                             │
│                                                                              │
│ /export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/de │
│ epspeed/runtime/pipe/engine.py:1307 in _exec_schedule                        │
│                                                                              │
│   1304 │   │   │   │                                                         │
│   1305 │   │   │   │   # Equivalent to: self._exec_forward_pass(buffer_id=0) │
│   1306 │   │   │   │   self._exec_instr = MethodType(self._INSTRUCTION_MAP[t │
│ ❱ 1307 │   │   │   │   self._exec_instr(**cmd.kwargs)                        │
│   1308                                                                       │
│                                                                              │
│ /export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/de │
│ epspeed/runtime/pipe/engine.py:996 in _exec_send_grads                       │
│                                                                              │
│    993 │   │   │   │   │   if not buffer.is_floating_point():                │
│    994 │   │   │   │   │   │   assert buffer.grad is None                    │
│    995 │   │   │   │   │   │   continue                                      │
│ ❱  996 │   │   │   │   │   assert buffer.grad is not None                    │
│    997 │   │   │   │   │   p2p.send(buffer.grad, self.prev_stage)            │
│    998 │   │                                                                 │
│    999 │   │   # We can free up the input buffer now                         │
╰──────────────────────────────────────────────────────────────────────────────╯
AssertionError
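For context, the failing check at engine.py:996 (shown above) asserts that every floating-point input buffer of a pipeline stage has a .grad after the backward pass, because that gradient is what gets sent to the previous stage. The snippet below is a minimal, hypothetical sketch in plain PyTorch, not DeepSpeed code: the function name input_grad_after_backward, the Linear projection, and the tensor shapes are illustrative assumptions. It shows one way such an input can end up with grad None, namely when it is cut out of the autograd graph (here via detach()) before the attention call. Whether this is what actually happens on the flash attention path in this issue is not established.

```python
import torch
import torch.nn.functional as F


def input_grad_after_backward(detach_input: bool):
    """Hypothetical single 'stage': project the incoming activations, then
    run causal self-attention with scaled_dot_product_attention (PyTorch >= 2.0)."""
    torch.manual_seed(0)
    proj = torch.nn.Linear(8, 8)  # trainable params keep backward() valid either way
    # Stand-in for the activation buffer received from the previous stage;
    # layout is (batch, heads, seq_len, head_dim).
    hidden = torch.randn(1, 2, 4, 8, requires_grad=True)
    # Detaching severs the graph, so no gradient can reach `hidden`.
    x = hidden.detach() if detach_input else hidden
    q = proj(x)
    out = F.scaled_dot_product_attention(q, q, q, is_causal=True)
    out.sum().backward()
    # A pipeline engine would send this back upstream; None trips the assertion.
    return hidden.grad


print(input_grad_after_backward(False) is not None)  # gradient reaches the input
print(input_grad_after_backward(True) is None)       # input grad is missing
```

On CPU with float32 inputs, scaled_dot_product_attention falls back to its math implementation, so this runs without a GPU. It only illustrates the condition the assertion checks, not the flash kernel itself.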

I also tried torch.nn.functional.scaled_dot_product_attention, which can dispatch to a flash attention kernel in PyTorch 2.0, but I hit the same error. Have you encountered this problem?

Thank you very much for your help!

Best, Fangkai