bytedance / byteps

A high performance and generic framework for distributed DNN training

How to use gradient accumulation in BytePS torch DDP? #417

Open wuyujiji opened 3 years ago

wuyujiji commented 3 years ago

Do you have a demo of gradient accumulation in BytePS torch DDP? I cannot find one in byteps/torch/example.

aDecisionTree commented 3 years ago

I'm also interested in this~

ymjiang commented 3 years ago

bps.DistributedOptimizer supports gradient accumulation with the backward_passes_per_step option (see the sketch below).

bps.DistributedDataParallel does not support it for now. We will add this feature.
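
For example, a minimal sketch of setting the option at construction time (assuming the Horovod-style signature, where named_parameters and backward_passes_per_step are keyword arguments of bps.DistributedOptimizer):

import torch
import byteps.torch as bps

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer = bps.DistributedOptimizer(optimizer,
                                     named_parameters=model.named_parameters(),
                                     backward_passes_per_step=4)  # accumulate 4 backward passes per update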

wuyujiji commented 3 years ago

Could you please share a complete gradient accumulation demo for bps.DistributedOptimizer?

ymjiang commented 3 years ago

Here is a general workflow:

import byteps.torch as bps

optimizer = bps.DistributedOptimizer(optimizer)
optimizer.set_backward_passes_per_step(accumulation_steps)
model.zero_grad()
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)
    loss = loss_function(predictions, labels)
    loss = loss / accumulation_steps               # optional: average over the accumulated steps
    loss.backward()                                # gradients accumulate locally between optimizer steps
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()                           # apply the update once every accumulation_steps batches
        model.zero_grad()

We will consider adding an example later.

wuyujiji commented 3 years ago

Thanks for the quick reply! If I want to use torch.cuda.amp in the above code, how should I add it?
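
For reference, one common way to combine the loop above with torch.cuda.amp is sketched below. This is not an official BytePS example; it assumes bps.DistributedOptimizer can be driven through torch.cuda.amp.GradScaler like a plain torch optimizer.

import torch
import byteps.torch as bps

scaler = torch.cuda.amp.GradScaler()

optimizer = bps.DistributedOptimizer(optimizer)
optimizer.set_backward_passes_per_step(accumulation_steps)
model.zero_grad()
for i, (inputs, labels) in enumerate(training_set):
    with torch.cuda.amp.autocast():
        predictions = model(inputs)
        loss = loss_function(predictions, labels)
        loss = loss / accumulation_steps       # optional: average over the accumulated steps
    scaler.scale(loss).backward()              # backward on the scaled loss
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)                 # unscales gradients, then calls optimizer.step()
        scaler.update()
        model.zero_grad()

How the wrapped optimizer synchronizes gradients may interact with the loss scaler (Horovod's AMP recipe, for instance, synchronizes explicitly before scaler.step), so this should be verified against the BytePS source.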