Open zhangqijun opened 2 months ago
In single-GPU mode I can run training successfully on an RTX 3090, but it takes too long. In DDP mode we get an OOM at `LlamaForCausalLM = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False)`.
Same issue. I found that it seems to be related to how DDP is initialized.
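One frequent cause of an OOM at the DDP wrapping step is that every rank allocates its model on `cuda:0` before the wrap, so GPU 0 holds one copy per process. Binding each process to its own GPU with `torch.cuda.set_device` *before* loading the model usually avoids this. A minimal sketch, assuming a `torchrun` launch (the `Linear` stands in for the real `LlamaForCausalLM` load, which is not shown in the thread):

```python
import os


def pick_device(local_rank: int) -> str:
    # Each rank must bind to its own GPU; otherwise every rank
    # defaults to cuda:0 and the replicas pile up there (OOM).
    return f"cuda:{local_rank}"


def main():
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
    # Bind this process to its GPU BEFORE any model/tensor allocation.
    torch.cuda.set_device(pick_device(local_rank))
    dist.init_process_group(backend="nccl")

    # Placeholder model; in the issue this would be the LlamaForCausalLM
    # checkpoint, loaded directly onto this rank's GPU.
    model = torch.nn.Linear(4096, 4096).to(pick_device(local_rank))

    model = DDP(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        broadcast_buffers=False,
    )
    # ... training loop ...


if __name__ == "__main__" and "LOCAL_RANK" in os.environ:
    main()
```

Launched as e.g. `torchrun --nproc_per_node=2 train.py`; the script is a no-op when run outside a distributed launcher.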