facebookresearch / kill-the-bits

Code for: "And the bit goes down: Revisiting the quantization of neural networks"

About multi-gpu training #25

Closed Fight-hawk closed 5 years ago

Fight-hawk commented 5 years ago

Hi, I would like to ask about some details of multi-GPU training. I followed your reply, but I got a deadlock and I cannot find the reason.

pierrestock commented 5 years ago

Hey Fight-hawk,

Sorry to hear that! Deadlocks usually happen when one GPU is waiting for the others, typically when a collective call such as torch.distributed.broadcast() has not been called on every GPU.
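To make this failure mode concrete, here is a toy, pure-Python model of it (a sketch, not PyTorch code). Each thread stands in for one GPU process, and threading.Barrier stands in for a collective call that every rank has to join. The helper name run_ranks and the short timeout are assumptions made for illustration; a real collective may simply block instead of timing out quickly.

```python
import threading


def run_ranks(world_size, skip_rank=None, timeout=0.5):
    """Simulate `world_size` ranks meeting at a single collective call.

    Returns a dict mapping rank -> "ok", "skipped" or "timed out".
    """
    barrier = threading.Barrier(world_size)
    status = {}

    def worker(rank):
        if rank == skip_rank:
            # This rank takes a code path that never reaches the collective,
            # e.g. an early `continue` or `return` guarded by a rank check.
            status[rank] = "skipped"
            return
        try:
            # Every other rank waits here for all ranks to arrive.
            barrier.wait(timeout=timeout)
            status[rank] = "ok"
        except threading.BrokenBarrierError:
            # Stand-in for the hang: the missing rank never shows up.
            status[rank] = "timed out"

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return status


if __name__ == "__main__":
    print(run_ranks(2))               # all ranks join the collective
    print(run_ranks(2, skip_rank=0))  # rank 0 skips it, rank 1 is left waiting
```

The same reasoning applies to any branch that only some ranks take: if a collective sits after such a branch, every path through the branch still has to reach it.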

Please share relevant parts of your code in the answer if this does not work!

Pierre

Fight-hawk commented 5 years ago

`def main():

# get arguments

global args
args = parser.parse_args()
args.block = '' if args.block == 'all' else args.block

# initialize distributed training
torch.distributed.init_process_group(backend="nccl")

# set the GPU used by each process
torch.cuda.set_device(args.local_rank)
device = torch.device('cuda', args.local_rank)

# student model to quantize
student = models.__dict__[args.model](pretrained=True)
student = student.to(device)
# note: the wrapper returned below is not assigned back, so student itself stays unwrapped
torch.nn.parallel.DistributedDataParallel(student,
                                          device_ids=[args.local_rank],
                                          output_device=args.local_rank)
student.eval()
criterion = nn.CrossEntropyLoss().cuda(device=device)
# print("Process: {:d}  criterion's device: {}".format(os.getpid(), criterion))
cudnn.benchmark = True

# layers to quantize (we do not quantize the first 7x7 convolution layer)
watcher = ActivationWatcher(student)
layers = [layer for layer in watcher.layers[1:] if args.block in layer]

# data loading code
# train_loader, test_loader = load_data(data_path=args.data_path, batch_size=args.batch_size, nb_workers=args.n_workers)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

transf_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize])
transf_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])

trainset = torchvision.datasets.CIFAR10(root='../data', train=True,
                                        download=True, transform=transf_train)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                           num_workers=0, sampler=DistributedSampler(trainset))

testset = torchvision.datasets.CIFAR10(root='../data', train=False,
                                       download=True, transform=transf_test)
test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
                                          num_workers=0, sampler=DistributedSampler(testset))
# parameters for the centroids optimizer
opt_centroids_params_all = []

# book-keeping for compression statistics (in MB)
size_uncompressed = compute_size(student)
size_index = 0
size_centroids = 0
size_other = size_uncompressed

# teacher model
teacher = models.__dict__[args.model](pretrained=True)
teacher.to(device)
teacher = torch.nn.parallel.DistributedDataParallel(teacher,
                                                    device_ids=[args.local_rank],
                                                    output_device=args.local_rank)
teacher.eval()

# Step 1: iteratively quantize the network layers (quantization + layer-wise centroids distillation)
print('Process: {:d}  Step 1: Quantize network'.format(os.getpid()))
t = time.time()
top_1 = 0

for layer in layers:
    #  gather input activations
    n_iter_activations = math.ceil(args.n_activations / args.batch_size)
    watcher = ActivationWatcher(student, layer=layer)
    in_activations_current = watcher.watch(train_loader, criterion, n_iter_activations)
    in_activations_current = in_activations_current[layer]

    # get weight matrix and detach it from the computation graph (.data should be enough, adding .detach() as a safeguard)
    M = attrgetter(layer + '.weight.data')(student).detach()
    sizes = M.size()
    is_conv = len(sizes) == 4

    # get padding and stride attributes
    padding = attrgetter(layer)(student).padding if is_conv else 0
    stride = attrgetter(layer)(student).stride if is_conv else 1
    groups = attrgetter(layer)(student).groups if is_conv else 1

    # block size, distinguish between fully connected and convolutional case
    if is_conv:
        out_features, in_features, k, _ = sizes
        block_size = args.block_size_cv if k > 1 else args.block_size_pw
        n_centroids = args.n_centroids_cv if k > 1 else args.n_centroids_pw
        n_blocks = in_features * k * k // block_size
    else:
        k = 1
        out_features, in_features = sizes
        block_size = args.block_size_fc
        n_centroids = args.n_centroids_fc
        n_blocks = in_features // block_size

    # clamp number of centroids for stability
    powers = 2 ** np.arange(0, 16, 1)
    n_vectors = np.prod(sizes) / block_size
    idx_power = bisect_left(powers, n_vectors / args.n_centroids_threshold)
    n_centroids = min(n_centroids, powers[idx_power - 1])

    # compression ratios
    bits_per_weight = np.log2(n_centroids) / block_size

    # number of bits per weight
    size_index_layer = bits_per_weight * M.numel() / 8 / 1024 / 1024
    size_index += size_index_layer

    # centroids stored in float16
    size_centroids_layer = n_centroids * block_size * 2 / 1024 / 1024
    size_centroids += size_centroids_layer

    # size of non-compressed layers, e.g. BatchNorms or first 7x7 convolution
    size_uncompressed_layer = M.numel() * 4 / 1024 / 1024
    size_other -= size_uncompressed_layer

    # number of samples
    n_samples = dynamic_sampling(layer)

    # print layer size
    print('Process: {:d}  Quantizing layer: {}, size: {}, n_blocks: {}, block size: {}, ' \
          'centroids: {}, bits/weight: {:.2f}, compressed size: {:.2f} MB'.format(os.getpid(),
           layer, list(sizes), n_blocks, block_size, n_centroids,
           bits_per_weight, size_index_layer + size_centroids_layer))

    # quantizer
    quantizer = PQ(in_activations_current, M, n_activations=args.n_activations,
                   n_samples=n_samples, eps=args.eps, n_centroids=n_centroids,
                   n_iter=args.n_iter, n_blocks=n_blocks, k=k,
                   stride=stride, padding=padding, groups=groups, device=device)
    assignments = torch.zeros(out_features).to(device)
    if args.local_rank == 0:
        if len(args.restart) > 0:
            # do not quantize already quantized layers
            try:
                # load centroids and assignments if already stored
                quantizer.load(args.restart, layer)
                centroids = quantizer.centroids
                assignments = quantizer.assignments

                # quantize weight matrix
                M_hat = weight_from_centroids(centroids, assignments, n_blocks, k, is_conv)
                attrgetter(layer + '.weight')(student).data = M_hat
                quantizer.save(args.save, layer)

                # optimizer for global finetuning
                parameters = [p for (n, p) in student.named_parameters() if layer in n and 'bias' not in n]
                centroids_params = {'params': parameters,
                                    'assignments': assignments,
                                    'kernel_size': k,
                                    'n_centroids': n_centroids,
                                    'n_blocks': n_blocks}
                opt_centroids_params_all.append(centroids_params)

                # proceed to next layer
                print('Process: {:d}  Layer already quantized, proceeding to next layer\n'.format(os.getpid()))
                continue

            # otherwise, quantize layer
            except FileNotFoundError:
                print('Process: {:d}  Quantizing layer'.format(os.getpid()))

        # quantize layer
        quantizer.encode()

        # assign quantized weight matrix
        M_hat = quantizer.decode()
        attrgetter(layer + '.weight')(student).data = M_hat

        # top1
        top_1 = evaluate(test_loader, student, criterion).item()
        # book-keeping
        print('Process: {:d}  Quantizing time: {:.0f}min, Top1 after quantization: {:.2f}\n'.format(os.getpid(), (time.time() - t) / 60, top_1))
        assignments = quantizer.assignments
    torch.distributed.barrier()
    torch.distributed.broadcast(attrgetter(layer + '.weight')(student).data.contiguous(), 0)
    torch.distributed.broadcast(assignments.contiguous(), 0)

    t = time.time()
    # Step 2: finetune centroids
    print('Process: {:d}  Finetuning centroids'.format(os.getpid()))

    # optimizer for centroids
    parameters = [p for (n, p) in student.named_parameters() if layer in n and 'bias' not in n]
    print('Process: {:d}  line: {:d}'.format(os.getpid(), 296))
    centroids_params = {'params': parameters,
                        'assignments': assignments,
                        'kernel_size': k,
                        'n_centroids': n_centroids,
                        'n_blocks': n_blocks}

    # remember centroids parameters for finetuning at the end
    opt_centroids_params = [centroids_params]
    opt_centroids_params_all.append(centroids_params)
    print('Process: {:d}  line: {:d}'.format(os.getpid(), 306))
    # custom optimizer
    optimizer_centroids = CentroidSGD(opt_centroids_params, lr=args.lr_centroids,
                                      momentum=args.momentum_centroids,
                                      weight_decay=args.weight_decay_centroids)

    # standard training loop
    n_iter = args.finetune_centroids
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer_centroids, step_size=1, gamma=0.1)
    print('Process: {:d}  line: {:d}'.format(os.getpid(), 315))
    for epoch in range(1):
        print('Process: {:d}  line: {:d}'.format(os.getpid(), 317))
        finetune_centroids(train_loader, student, teacher, criterion, optimizer_centroids, n_iter=n_iter)
        print('Process: {:d}  line: {:d}'.format(os.getpid(), 319))
        top_1 = evaluate(test_loader, student, criterion)
        scheduler.step()
        print('Process: {:d}  Epoch: {}, Top1: {:.2f}'.format(os.getpid(), epoch, top_1))

    print('Process: {:d}  After {} iterations with learning rate {}, Top1: {:.2f}'.format(os.getpid(), n_iter, args.lr_centroids, top_1))

    # book-keeping
    print('Process: {:d}  Finetuning centroids time: {:.0f}min, Top1 after finetuning centroids: {:.2f}\n'.format(os.getpid(), (time.time() - t) / 60, top_1))
    t = time.time()

    # saving
    M_hat = attrgetter(layer + '.weight')(student).data
    centroids = centroids_from_weights(M_hat, assignments, n_centroids, n_blocks)
    quantizer.centroids = centroids
    quantizer.save(args.save, layer)

# End of compression + finetuning of centroids
size_compressed = size_index + size_centroids + size_other
print('Process: {:d}  End of compression, non-compressed teacher model: {:.2f}MB, compressed student model ' \
      '(indexing + centroids + other): {:.2f}MB + {:.2f}MB + {:.2f}MB = {:.2f}MB, compression ratio: {:.2f}x\n'.format(
      os.getpid(), size_uncompressed, size_index, size_centroids, size_other, size_compressed, size_uncompressed / size_compressed))

# Step 3: finetune whole network
print('Process: {:d}  Step 3: Finetune whole network'.format(os.getpid()))
t = time.time()

# custom optimizer
optimizer_centroids_all = CentroidSGD(opt_centroids_params_all, lr=args.lr_whole,
                                  momentum=args.momentum_whole,
                                  weight_decay=args.weight_decay_whole)

# standard training loop
n_iter = args.finetune_whole
scheduler = torch.optim.lr_scheduler.StepLR(optimizer_centroids_all, step_size=args.finetune_whole_step_size, gamma=0.1)
for epoch in range(args.finetune_whole_epochs):
    student.train()
    finetune_centroids(train_loader, student, teacher, criterion, optimizer_centroids_all, n_iter=n_iter)
    top_1 = evaluate(test_loader, student, criterion)
    scheduler.step()
    print('Process: {:d}  Epoch: {}, Top1: {:.2f}'.format(os.getpid(), epoch, top_1))

# state dict of compressed model
state_dict_compressed = {}

# save conv1 (not quantized)
state_dict_compressed['conv1'] = student.conv1.state_dict()

# save biases of the classifier
state_dict_compressed['fc_bias'] = {'bias': student.fc.bias}

# save batch norms
bn_layers = watcher._get_bn_layers()

for bn_layer in bn_layers:
    state_dict_compressed[bn_layer] = attrgetter(bn_layer)(student).state_dict()

# save quantized layers
for layer in layers:

    # stats
    M = attrgetter(layer + '.weight.data')(student).detach()
    sizes = M.size()
    is_conv = len(sizes) == 4

    # get padding and stride attributes
    padding = attrgetter(layer)(student).padding if is_conv else 0
    stride = attrgetter(layer)(student).stride if is_conv else 1
    groups = attrgetter(layer)(student).groups if is_conv else 1

    # block size, distinguish between fully connected and convolutional case
    if is_conv:
        out_features, in_features, k, _ = sizes
        block_size = args.block_size_cv if k > 1 else args.block_size_pw
        n_centroids = args.n_centroids_cv
        n_blocks = in_features * k * k // block_size
    else:
        k = 1
        out_features, in_features = sizes
        block_size = args.block_size_fc
        n_centroids = args.n_centroids_fc
        n_blocks = in_features // block_size

    # clamp number of centroids for stability
    powers = 2 ** np.arange(0, 16, 1)
    n_vectors = np.prod(sizes) / block_size
    idx_power = bisect_left(powers, n_vectors / args.n_centroids_threshold)
    n_centroids = min(n_centroids, powers[idx_power - 1])

    # save
    quantizer.load(args.save, layer)
    assignments = quantizer.assignments
    M_hat = attrgetter(layer + '.weight')(student).data
    centroids = centroids_from_weights(M_hat, assignments, n_centroids, n_blocks)
    quantizer.centroids = centroids
    quantizer.save(args.save, layer)
    state_dict_layer = {
        'centroids': centroids.half(),
        'assignments': assignments.short() if 'fc' in layer else assignments.byte(),
        'n_blocks': n_blocks,
        'is_conv': is_conv,
        'k': k
    }
    state_dict_compressed[layer] = state_dict_layer

# save model
torch.save(state_dict_compressed, os.path.join(args.save, 'state_dict_compressed.pth'))

# book-keeping
print('Process: {:d}  Finetuning whole network time: {:.0f}min, Top1 after finetuning centroids: {:.2f}\n'.format(os.getpid(), (time.time() - t) / 60, top_1))`

Thank you for your reply. This is the main function.

Fight-hawk commented 5 years ago

I found that each process hangs at input = input.to(device=device) and target = target.to(device=device) in the function finetune_centroids in train.py.

pierrestock commented 5 years ago

Hi again,

I think the line torch.distributed.broadcast(assignments.contiguous(), 0) is causing the deadlock. Indeed, you can only broadcast a variable that exists on all the GPUs. What you can do is

Could you try this?
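For reference, one common way to make sure the broadcast buffer exists on every rank with a matching shape and dtype is sketched below. This is a minimal sketch under assumptions, not code from this repository: broadcast_from_src is a hypothetical helper, it relies on torch.distributed.broadcast_object_list (available in recent PyTorch versions), and with the nccl backend it assumes torch.cuda.set_device has already been called and that device is the local CUDA device.

```python
import torch
import torch.distributed as dist


def broadcast_from_src(tensor, src=0, device="cpu"):
    """Broadcast `tensor` from rank `src`; the other ranks may pass None.

    Every rank must call this function, including ranks with nothing to send.
    """
    rank = dist.get_rank()

    # 1) Share shape and dtype so that receivers can allocate a matching buffer.
    meta = [(tuple(tensor.shape), tensor.dtype)] if rank == src else [None]
    dist.broadcast_object_list(meta, src=src)
    shape, dtype = meta[0]

    # 2) The sender makes its tensor contiguous, receivers allocate an empty one.
    if rank == src:
        buf = tensor.to(device).contiguous()
    else:
        buf = torch.empty(shape, dtype=dtype, device=device)

    # 3) Broadcast in place into `buf` and return that same object, so the
    #    caller keeps a reference to the tensor that was actually filled.
    dist.broadcast(buf, src=src)
    return buf
```

In the script above this would be used as assignments = broadcast_from_src(assignments if args.local_rank == 0 else None, device=device), called by all ranks. Keeping the returned tensor matters because .contiguous() returns a copy when its input is not already contiguous, in which case broadcasting into a temporary x.contiguous() would leave x itself unchanged on the receiving ranks.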

Fight-hawk commented 5 years ago

Thank you very much, I will try. However, my code does reach the lines after that one.

pierrestock commented 5 years ago

Oh sorry, I didn't see that you already did this. OOM errors are also hard to catch in distributed mode. Did you try reducing the batch size?

Please re-open the issue if you are still in trouble.