BaguaSys / bagua

Bagua Speeds up PyTorch
https://tutorials-8ro.pages.dev/
MIT License

I use bagua and hit the following error ( bagua.broadcast(ps, 0, comm=comm) ) #333

Closed lixiangMindSpore closed 2 years ago

lixiangMindSpore commented 2 years ago

Describe the bug: the error is shown in a screenshot attached to the original issue.

Environment

Reproducing

Please provide a minimal working example. This means the runnable code.

Please also write what exact commands are required to reproduce your results.

Additional context Add any other context about the problem here.

lixiangMindSpore commented 2 years ago
comm = bagua.communication._get_default_group().get_global_communicator()  ###

# Broadcast init parameters
for ps in backbone.parameters():
    bagua.broadcast(ps, 0, comm=comm)     ###
shjwudp commented 2 years ago

This error usually means the model is not on the GPU device it is expected to be on. You can check whether the GPU the model is on matches bagua.get_local_rank().

If it still does not work, please provide a minimal script that reproduces the bug.
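
For example, a minimal sketch of that check could look like the following (the helper name broadcast_params and the assertion are only illustrative, not part of your script):

import torch
import bagua.torch_api as bagua

def broadcast_params(module, comm):
    # Each process is expected to own the GPU returned by bagua.get_local_rank();
    # parameters must already live on that device before broadcasting from rank 0.
    device = torch.device("cuda", bagua.get_local_rank())
    for ps in module.parameters():
        assert ps.device == device, "parameter on {}, expected {}".format(ps.device, device)
        bagua.broadcast(ps, 0, comm=comm)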

lixiangMindSpore commented 2 years ago

import os
import argparse
import time
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.data.distributed
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from backbones.model_irse import IR_SE_50, IR_SE_101
from config import config as cfg
from utils import *
from dataset import MXFaceDataset, DataLoaderX

from dataset import RecgDataset_mask as RecgDataset  # mask augmentation

from partial_classifier import DistSampleClassifier
from partial_loss import MarginSoftmax
from sgd import SGD
from torchsummary import summary
import bagua.torch_api as bagua  ###
from bagua.torch_api.algorithms import gradient_allreduce  ###
from common_utils import find_free_port

torch.backends.cudnn.benchmark = True

def should_distribute(): return dist.is_available() and world_size >= 1

def is_distributed(): return dist.is_available() and dist.is_initialized()

def _init_bagua_env(rank, env):
    # init bagua distributed process group
    torch.cuda.set_device(bagua.get_local_rank())
    bagua.init_process_group()

def main(local_rank, rank, world_size, cfg):

# dataloader

print('loading data...')
comm = bagua.communication._get_default_group().get_global_communicator()  ###
trainset = RecgDataset(cfg)

train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, shuffle=True)
train_loader = DataLoaderX(local_rank=local_rank,
                           dataset=trainset,
                           batch_size=cfg.batch_size,
                           sampler=train_sampler,
                           num_workers=0,
                           pin_memory=True,
                           drop_last=True)

# model
print('loading model...')
backbone = IR_SE_50(cfg.input_size).to(local_rank)
# backbone = IR_SE_101(cfg.input_size)
# Memory classifer
dist_sample_classifer = DistSampleClassifier(trainset.classes, rank=rank, local_rank=local_rank,
                                             world_size=world_size)
# Margin softmax
margin_softmax = MarginSoftmax(s=64.0, m=0.4)

# Optimizer for backbone and classifer
# optimizer = SGD([{'params': backbone.parameters()}, {'params': dist_sample_classifer.parameters()}],
#                lr=0.1, momentum=0.9, weight_decay=cfg.weight_decay, rescale=world_size)

# Broadcast init parameters
# for ps in backbone.parameters():
#    dist.broadcast(ps, 0)

backbone_path = os.path.join(cfg.model_save + 'backbone/')
head_path = os.path.join(cfg.model_save + 'head/')
log_path = os.path.join(cfg.log_save + 'shows/')

cfg.model_resume = cfg.model_save
backbone_resume = os.path.join(cfg.model_resume + 'backbone/')
head_resume = os.path.join(cfg.model_resume + 'head/')

# if cfg.resume and os.path.isdir(backbone_resume) and os.path.isdir(head_resume):
if cfg.resume and os.path.isdir(backbone_resume):
    print('resume~~~~~~~~~~~~~~~~~~~~~~~~~~')
    backbone_list = os.listdir(backbone_resume)
    if backbone_list:
        # pre_flags = [eval(x.split('Epoch_')[1].split('_Time')[0]) for x in backbone_list]
        # tar_flag = max(pre_flags)
        # print(tar_flag)
        # backbone_ckpt = torch.load(backbone_resume + '/' + backbone_list[0], map_location=torch.device('cpu'))
        # backbone_ckpt = torch.load(backbone_path + '/' + str(tar_flag) + '_backbone.pth')
        print(backbone_resume + '/' + backbone_list[0])
        backbone_ckpt = torch.load(backbone_resume + '/' + backbone_list[0], map_location=torch.device('cpu'))
        backbone.load_state_dict(backbone_ckpt['backbone'])
        print('load backbone ~')
        # optimizer.load_state_dict(backbone_ckpt['optimizer'])
        # print(optimizer.param_groups[0]['lr'])
        # start_epoch = backbone_ckpt['epoch'] + 1
        start_epoch = 1
        fg = 0
        if rank == 0 and fg:
            head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head0.pth')
            dist_sample_classifer.load_state_dict(head_ckpt['head'])
        if rank == 1 and fg:
            head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head1.pth')
            dist_sample_classifer.load_state_dict(head_ckpt['head'])
        if rank == 2 and fg:
            head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head2.pth')
            dist_sample_classifer.load_state_dict(head_ckpt['head'])
        if rank == 3 and fg:
            head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head3.pth')
            dist_sample_classifer.load_state_dict(head_ckpt['head'])

else:
    start_epoch = 0
    print("Train from Scratch")
print("=" * 60)

backbone = backbone.to(local_rank)

optimizer = SGD([{'params': backbone.parameters()}, {'params': dist_sample_classifer.parameters()}],
                lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay, rescale=world_size)

# Broadcast init parameters
for ps in backbone.parameters():
    bagua.broadcast(ps, 0, comm)     ###

# bagua
# backbone = torch.nn.SyncBatchNorm.convert_sync_batchnorm(backbone)
algorithm = gradient_allreduce.GradientAllReduceAlgorithm()
backbone = backbone.with_bagua([optimizer], algorithm)

# Lr scheduler
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                              lr_lambda=cfg.lr_func)
os.makedirs(log_path, exist_ok=True)
if local_rank == 0:
    writer = SummaryWriter(log_dir=log_path)

print('trainning...')
global_step = 0
n_epochs = cfg.num_epoch
# NUM_EPOCH_WARM_UP = n_epochs // 25
NUM_EPOCH_WARM_UP = 5
NUM_BATCH_WARM_UP = NUM_EPOCH_WARM_UP * len(train_loader)
backbone.train()
for epoch in range(start_epoch, n_epochs):
    train_sampler.set_epoch(epoch)

    for step, (img, label) in enumerate(train_loader):

        if (epoch + 1 <= NUM_EPOCH_WARM_UP) and (
                global_step + 1 <= NUM_BATCH_WARM_UP) and 1:  # adjust LR for each training batch during warm up
            warm_up_lr(global_step + 1, NUM_BATCH_WARM_UP, cfg.lr, optimizer)

        total_label, norm_weight = dist_sample_classifer.prepare(label, optimizer)
        # print('total_label:', total_label.shape)
        # print('norm_weight:', norm_weight.shape)
        features = backbone(img)  # features are normalized internally

        # Features all-gather
        total_features = torch.zeros(features.size()[0] * world_size, cfg.embedding_size, device=local_rank)
        dist.all_gather(list(total_features.chunk(world_size, dim=0)), features.data)
        total_features.requires_grad = True

        # Calculate logits
        # print('&' * 60)
        # print('total_features:', total_features.shape)
        # print('norm_weight:', norm_weight.shape)
        # print('total_label:', total_label.shape)
        logits = dist_sample_classifer(total_features, norm_weight)  # cos =
        # print('logits1:', logits.shape)
        # print('logits:', logits.shape)
        # print('&' * 60)
        logits = margin_softmax(logits, total_label)
        # print('logits2:', logits.shape)
        # total_logits = torch.zeros(logits.size()[0], len(DataLoaderX), device=local_rank)
        # dist.all_gather(list(total_logits.chunk(world_size, dim=0)),
        #                 logits.data)

        with torch.no_grad():
            max_fc = torch.max(logits, dim=1, keepdim=True)[0]
            # print('max_fc:', max_fc.shape)
            ###dist.all_reduce(max_fc, dist.ReduceOp.MAX)
            recv_max_fc = torch.zeros_like(max_fc)
            bagua.allreduce(max_fc, recv_max_fc, bagua.ReduceOp.MAX, comm=comm)   ###
            # print('#'*10, max_fc)
            # Calculate exp(logits) and all-reduce
            logits_exp = torch.exp(logits - recv_max_fc)
            logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
            recv_logits_sum_exp = torch.zeros_like(logits_sum_exp)
            ###dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)
            bagua.allreduce(logits_sum_exp, recv_logits_sum_exp, bagua.ReduceOp.SUM, comm=comm)  ###

            # Calculate prob
            logits_exp.div_(logits_sum_exp)

            # Get one-hot
            grad = logits_exp
            index = torch.where(total_label != -1)[0]
            one_hot = torch.zeros(index.size()[0], grad.size()[1], device=grad.device)
            one_hot.scatter_(1, total_label[index, None], 1)

            # Calculate loss
            loss = torch.zeros(grad.size()[0], 1, device=grad.device)
            loss[index] = grad[index].gather(1, total_label[index, None])
            recv_loss = torch.zeros_like(loss)
            ###dist.all_reduce(loss, dist.ReduceOp.SUM)
            bagua.allreduce(loss, recv_loss, bagua.ReduceOp.SUM, comm=comm)  ###
            loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

            # Calculate grad
            grad[index] -= one_hot
            grad.div_(features.size()[0])

        logits.backward(grad)
        if total_features.grad is not None:
            total_features.grad.detach_()
        x_grad = torch.zeros_like(features)

        # Feature gradient all-reduce
        ###dist.reduce_scatter(
            ###x_grad, list(total_features.grad.chunk(world_size, dim=0)))
        bagua.reduce_scatter(list(total_features.grad.chunk(world_size, dim=0)), x_grad, comm=comm)  ###
        x_grad.mul_(world_size)
        # Backward backbone
        features.backward(x_grad)
        optimizer.step()

        # Update classifer
        dist_sample_classifer.update()
        optimizer.zero_grad()

        tm = time.asctime().split()[-2]
        if rank == 0 and global_step % cfg.disp_freq == 0:
            writer.add_scalar('loss', loss_v, global_step)
            print('\nEpoch:{}/{} Batch:{}/{}\t'
                  'Loss:{loss:.4f}\t'
                  'lr: {lr:.4f}\t'
                  'TimeNow:{TimeNow}'.format(
                epoch, n_epochs, (step + 1) % len(train_loader), len(train_loader),
                loss=loss_v, lr=optimizer.param_groups[0]['lr'], TimeNow=tm))

        global_step += 1
    scheduler.step()

    if rank == 0:
        os.makedirs(backbone_path, exist_ok=True)
        state = {'backbone': backbone.module.state_dict(),
                 'optimizer': optimizer.state_dict(), 'epoch': epoch}

        torch.save(state,
                   backbone_path + "Backbone_IR_SE_50_Epoch_{}_Time_{}.pth".format(epoch + 1, get_time()))

    os.makedirs(head_path, exist_ok=True)
    if rank == 0:
        state = {'head': dist_sample_classifer.state_dict()}
        torch.save(state,
                   os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head0.pth".format(epoch + 1, get_time())))
    if rank == 1:
        state = {'head': dist_sample_classifer.state_dict()}
        torch.save(state,
                   os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head1.pth".format(epoch + 1, get_time())))
    if rank == 2:
        state = {'head': dist_sample_classifer.state_dict()}
        torch.save(state,
                   os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head2.pth".format(epoch + 1, get_time())))
    if rank == 3:
        state = {'head': dist_sample_classifer.state_dict()}
        torch.save(state,
                   os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head3.pth".format(epoch + 1, get_time())))

###dist.destroy_process_group()

if name == "main": parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') parser.add_argument('--local_rank', type=int, default=1, help='local_rank') args = parser.parse_args() os.environ['NCCL_DEBUG'] = 'INFO' world_size = int(os.environ.get('WORLD_SIZE', 1)) print("lixiang")

env = {
    "WORLD_SIZE": str(world_size),
    "LOCAL_WORLD_SIZE": str(world_size),
    "MASTER_ADDR": "127.0.0.1",
    "MASTER_PORT": str(find_free_port(8000, 8100)),
    "BAGUA_SERVICE_PORT": str(find_free_port(9000, 9100)),
}
print("lixiang2")
_init_bagua_env(args.local_rank, env)
rank = bagua.get_rank()
print(rank)

main(args.local_rank, rank, world_size, cfg)
lixiangMindSpore commented 2 years ago

> This error usually means the model is not on the GPU device it is expected to be on. You can check whether the GPU the model is on matches bagua.get_local_rank().
>
> If it still does not work, please provide a minimal script that reproduces the bug.

If I use DDP, it works fine.

"""
Author: {Yang Xiao, Xiang An, XuHan Zhu} in DeepGlint,
Partial FC: Training 10 Million Identities on a Single Machine
See the original paper: https://arxiv.org/abs/2010.05222
"""
import os
import argparse
import time
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.data.distributed
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from backbones.model_irse import IR_SE_50, IR_SE_101
from config import config as cfg
from utils import *
from dataset import MXFaceDataset, DataLoaderX

from dataset import RecgDataset_mask as RecgDataset  # mask augmentation

from partial_classifier import DistSampleClassifier
from partial_loss import MarginSoftmax
from sgd import SGD
from torchsummary import summary

torch.backends.cudnn.benchmark = True

def should_distribute(): return dist.is_available() and world_size >= 1

def is_distributed(): return dist.is_available() and dist.is_initialized()

def main(local_rank, rank, world_size, cfg):

# dataloader

print('loading data...')
trainset = RecgDataset(cfg)

train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, shuffle=True)
train_loader = DataLoaderX(local_rank=local_rank,
                           dataset=trainset,
                           batch_size=cfg.batch_size,
                           sampler=train_sampler,
                           num_workers=0,
                           pin_memory=True,
                           drop_last=True)

# model
print('loading model...')
backbone = IR_SE_50(cfg.input_size).to(local_rank)

# backbone = IR_SE_101(cfg.input_size)

# Memory classifer
dist_sample_classifer = DistSampleClassifier(trainset.classes, rank=rank, local_rank=local_rank, world_size=world_size)
# Margin softmax
margin_softmax = MarginSoftmax(s=64.0, m=0.4)

# Optimizer for backbone and classifer
# optimizer = SGD([{'params': backbone.parameters()}, {'params': dist_sample_classifer.parameters()}],
#                lr=0.1, momentum=0.9, weight_decay=cfg.weight_decay, rescale=world_size)

# Broadcast init parameters
# for ps in backbone.parameters():
#    dist.broadcast(ps, 0)

backbone_path = os.path.join(cfg.model_save + 'backbone/')
head_path = os.path.join(cfg.model_save + 'head/')
log_path = os.path.join(cfg.log_save + 'shows/')

cfg.model_resume = cfg.model_save
backbone_resume = os.path.join(cfg.model_resume + 'backbone/')
head_resume = os.path.join(cfg.model_resume + 'head/')

# if cfg.resume and os.path.isdir(backbone_resume) and os.path.isdir(head_resume):
if cfg.resume and os.path.isdir(backbone_resume):
    print('resume~~~~~~~~~~~~~~~~~~~~~~~~~~')
    backbone_list = os.listdir(backbone_resume)
    if backbone_list:
        # pre_flags = [eval(x.split('Epoch_')[1].split('_Time')[0]) for x in backbone_list]
        # tar_flag = max(pre_flags)
        # print(tar_flag)
        # backbone_ckpt = torch.load(backbone_resume + '/' + backbone_list[0], map_location=torch.device('cpu'))
        #backbone_ckpt = torch.load(backbone_path + '/' + str(tar_flag) + '_backbone.pth')
        print(backbone_resume + '/' + backbone_list[0])
        backbone_ckpt = torch.load(backbone_resume + '/' + backbone_list[0], map_location=torch.device('cpu'))
        backbone.load_state_dict(backbone_ckpt['backbone'])
        print('load backbone ~')
        #optimizer.load_state_dict(backbone_ckpt['optimizer'])
        #print(optimizer.param_groups[0]['lr'])
        #start_epoch = backbone_ckpt['epoch'] + 1
        start_epoch = 1
        fg = 0
        if rank == 0 and fg:
            head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head0.pth')
            dist_sample_classifer.load_state_dict(head_ckpt['head'])
        if rank == 1 and fg:
            head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head1.pth')
            dist_sample_classifer.load_state_dict(head_ckpt['head'])
        if rank == 2 and fg:
            head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head2.pth')
            dist_sample_classifer.load_state_dict(head_ckpt['head'])
        if rank == 3 and fg:
            head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head3.pth')
            dist_sample_classifer.load_state_dict(head_ckpt['head'])

else:
    start_epoch = 0
    print("Train from Scratch")
print("=" * 60) 

backbone = backbone.to(local_rank)

optimizer = SGD([{'params': backbone.parameters()}, {'params': dist_sample_classifer.parameters()}],
                lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay, rescale=world_size)

# Broadcast init parameters
for ps in backbone.parameters():
    dist.broadcast(ps, 0)

# DDP

backbone = torch.nn.SyncBatchNorm.convert_sync_batchnorm(backbone)

backbone = torch.nn.parallel.DistributedDataParallel(
    module=backbone, broadcast_buffers=False, device_ids=[rank])

# Lr scheduler
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                              lr_lambda=cfg.lr_func)
os.makedirs(log_path, exist_ok=True)
if local_rank == 0:
    writer = SummaryWriter(log_dir=log_path)

print('trainning...')
global_step = 0
n_epochs = cfg.num_epoch
# NUM_EPOCH_WARM_UP = n_epochs // 25
NUM_EPOCH_WARM_UP = 5
NUM_BATCH_WARM_UP = NUM_EPOCH_WARM_UP * len(train_loader)
backbone.train()
for epoch in range(start_epoch, n_epochs):
    train_sampler.set_epoch(epoch)
    print('lixiang0000000')
    print(len(train_loader))

    for step, (img, label) in enumerate(train_loader):
        print('lixiang1111111111111111')

        if (epoch + 1 <= NUM_EPOCH_WARM_UP) and (global_step + 1 <= NUM_BATCH_WARM_UP) and 1: # adjust LR for each training batch during warm up
            warm_up_lr(global_step + 1, NUM_BATCH_WARM_UP, cfg.lr, optimizer)

        total_label, norm_weight = dist_sample_classifer.prepare(label, optimizer)
        # print('total_label:', total_label.shape)
        # print('norm_weight:', norm_weight.shape)
        features = backbone(img)    # features are normalized internally

        # Features all-gather
        total_features = torch.zeros(features.size()[0] * world_size, cfg.embedding_size, device=local_rank)
        dist.all_gather(list(total_features.chunk(world_size, dim=0)), features.data)
        total_features.requires_grad = True

        # Calculate logits
        # print('&' * 60)
        # print('total_features:', total_features.shape)
        # print('norm_weight:', norm_weight.shape)
        # print('total_label:', total_label.shape)
        logits = dist_sample_classifer(total_features, norm_weight)  # cos =
        # print('logits1:', logits.shape)
        # print('logits:', logits.shape)
        # print('&' * 60)
        logits = margin_softmax(logits, total_label)
        # print('logits2:', logits.shape)
        # total_logits = torch.zeros(logits.size()[0], len(DataLoaderX), device=local_rank)
        # dist.all_gather(list(total_logits.chunk(world_size, dim=0)),
        #                 logits.data)

        with torch.no_grad():
            max_fc = torch.max(logits, dim=1, keepdim=True)[0]
            #print('max_fc:', max_fc.shape)
            dist.all_reduce(max_fc, dist.ReduceOp.MAX)
            #print('#'*10, max_fc)
            # Calculate exp(logits) and all-reduce
            logits_exp = torch.exp(logits - max_fc)
            logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
            dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

            # Calculate prob
            logits_exp.div_(logits_sum_exp)

            # Get one-hot
            grad = logits_exp
            index = torch.where(total_label != -1)[0]
            one_hot = torch.zeros(index.size()[0], grad.size()[1], device=grad.device)
            one_hot.scatter_(1, total_label[index, None], 1)

            # Calculate loss
            loss = torch.zeros(grad.size()[0], 1, device=grad.device)
            loss[index] = grad[index].gather(1, total_label[index, None])
            dist.all_reduce(loss, dist.ReduceOp.SUM)
            loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

            # Calculate grad
            grad[index] -= one_hot
            grad.div_(features.size()[0])

        logits.backward(grad)
        if total_features.grad is not None:
            total_features.grad.detach_()
        x_grad = torch.zeros_like(features)

        # Feature gradient all-reduce
        dist.reduce_scatter(
            x_grad, list(total_features.grad.chunk(world_size, dim=0)))
        x_grad.mul_(world_size)
        # Backward backbone
        features.backward(x_grad)
        optimizer.step()

        # Update classifer
        dist_sample_classifer.update()
        optimizer.zero_grad()

        tm = time.asctime().split()[-2]
        print('neighbour:',rank)
        if rank == 0 and global_step % cfg.disp_freq == 0:
            writer.add_scalar('loss', loss_v, global_step)
            print('\nEpoch:{}/{} Batch:{}/{}\t'
                  'Loss:{loss:.4f}\t'
                  'lr: {lr:.4f}\t'
                  'TimeNow:{TimeNow}'.format(
                epoch, n_epochs, (step + 1) % len(train_loader), len(train_loader),
                loss=loss_v, lr=optimizer.param_groups[0]['lr'], TimeNow=tm))

        global_step += 1
    scheduler.step()

    if rank == 0:
        os.makedirs(backbone_path, exist_ok=True)
        state = {'backbone': backbone.module.state_dict(),
                 'optimizer': optimizer.state_dict(), 'epoch': epoch}

        torch.save(state,
                   backbone_path + "Backbone_IR_SE_50_Epoch_{}_Time_{}.pth".format(epoch + 1, get_time()))

    os.makedirs(head_path, exist_ok=True)
    if rank == 0:
        state = {'head': dist_sample_classifer.state_dict()}
        torch.save(state, os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head0.pth".format(epoch+1, get_time())))
    if rank == 1:
        state = {'head': dist_sample_classifer.state_dict()}
        torch.save(state, os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head1.pth".format(epoch+1, get_time())))
    if rank == 2:
        state = {'head': dist_sample_classifer.state_dict()}
        torch.save(state, os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head2.pth".format(epoch+1, get_time())))
    if rank == 3:
        state = {'head': dist_sample_classifer.state_dict()}
        torch.save(state, os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head3.pth".format(epoch+1, get_time())))

dist.destroy_process_group()

if name == "main": parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') parser.add_argument('--local_rank', type=int, default=1, help='local_rank') args = parser.parse_args() os.environ['NCCL_DEBUG'] = 'INFO' world_size = int(os.environ.get('WORLD_SIZE', 1)) if should_distribute(): print('Using distributed PyTorch with {} backend'.format(dist.Backend.NCCL)) dist.init_process_group(backend='nccl', init_method='env://', rank=args.local_rank, world_size=world_size) rank = dist.get_rank() else: rank = torch.cuda.device_count()

torch.cuda.set_device(args.local_rank)

main(args.local_rank, rank, world_size, cfg)
shjwudp commented 2 years ago

@lixiangMindSpore Bagua relaxes the restrictions on user scripts. We don't pass --local_rank to the user process, so args.local_rank won't work as you expect; you should replace it with bagua.get_local_rank().

As I mentioned in the previous comment.
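
For example, the entry point could look roughly like this (a sketch only; main and cfg refer to the script you posted above):

import os
import torch
import bagua.torch_api as bagua

# bind this process to its GPU and initialize bagua, as in your _init_bagua_env
torch.cuda.set_device(bagua.get_local_rank())
bagua.init_process_group()

local_rank = bagua.get_local_rank()   # instead of args.local_rank
rank = bagua.get_rank()
world_size = int(os.environ.get('WORLD_SIZE', 1))
main(local_rank, rank, world_size, cfg)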

NOBLES5E commented 2 years ago

Feel free to reopen the issue if it does not work as expected :)