SparkJiao opened 1 year ago
Hello,
when I try to use flash attention, I encounter the following problem:
```
/export/home2/fangkai/merit-v2/trainer_base_ds_mp.py:346 in main

    343             logger.info("Resuming training from the latest checkpoint:
    344             continue_from_global_step = int(checkpoint.split('-')[-1])
    345
  ❱ 346         global_step, tr_loss = train(cfg, model_pipe, tokenizer, conti
    347         logger.info(" global_step = %s, average loss = %s", global_ste
    348
    349

/export/home2/fangkai/merit-v2/trainer_base_ds_mp.py:236 in train

    233                     continue
    234
    235                 model.train()
  ❱ 236                 loss = model.train_batch(data_iter=sub_train_dataloade
    237                 global_step += 1
    238
    239                 tr_loss += loss.item()

/export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py:336 in train_batch

    333         sched = schedule.TrainSchedule(micro_batches=self.micro_batch
    334                                        stages=self.num_stages,
    335                                        stage_id=self.stage_id)
  ❱ 336         self._exec_schedule(sched)
    337         self.agg_train_loss = self._aggregate_total_loss()
    338
    339         self.timers('train_batch').stop()

/export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py:1307 in _exec_schedule

   1304
   1305                 # Equivalent to: self._exec_forward_pass(buffer_id=0)
   1306                 self._exec_instr = MethodType(self._INSTRUCTION_MAP[t
  ❱ 1307                 self._exec_instr(**cmd.kwargs)
   1308

/export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py:996 in _exec_send_grads

    993                     if not buffer.is_floating_point():
    994                         assert buffer.grad is None
    995                         continue
  ❱ 996                     assert buffer.grad is not None
    997                     p2p.send(buffer.grad, self.prev_stage)
    998
    999         # We can free up the input buffer now

AssertionError
```
I also tested `torch.nn.functional.scaled_dot_product_attention`, which implements flash attention in PyTorch 2.0, and hit the same problem. May I know if you have encountered this problem?
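For context, the PyTorch 2.0 path I tried looks roughly like the sketch below (the `attention` helper, tensor shapes, and `is_causal` setting are illustrative, not my exact code):

```python
import torch
import torch.nn.functional as F

def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Fused attention via PyTorch 2.0; q/k/v are (batch, heads, seq, head_dim).

    With fp16/bf16 CUDA tensors and no explicit attn_mask, this call can
    dispatch to the flash-attention kernel; is_causal=True applies a causal
    mask inside the kernel instead of a materialized mask tensor.
    """
    return F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
```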
Thanks very much for your help!
Best, Fangkai