d2l-ai / d2l-en

Interactive deep learning book with multi-framework code, math, and discussions. Adopted at 500 universities from 70 countries including Stanford, MIT, Harvard, and Cambridge.
https://D2L.ai

[MXNet] Speedup SSD Scratch Implementation #1582

Open AnirudhDagar opened 3 years ago

AnirudhDagar commented 3 years ago

The MXNet scratch implementation of SSD is currently much slower than its PyTorch scratch counterpart.

astonzhang commented 3 years ago

Can you fix it?

AnirudhDagar commented 3 years ago

I benchmarked the training loop to find the bottleneck. Interestingly, just replacing the scratch d2l.multibox_target implementation with the built-in npx.multibox_target method produced the desired speedup. This indicates that the custom multibox_target function implemented from scratch for MXNet is the speed bottleneck. Yet the exact same scratch implementation of d2l.multibox_target in PyTorch is just as fast.
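One way to isolate such a bottleneck is to accumulate wall-clock time per stage of the training step across iterations. The sketch below uses hypothetical stand-in functions (not the real d2l API) purely to illustrate the measurement approach:

```python
import time
from collections import defaultdict

# Hypothetical stand-ins for the real training-step stages; the actual
# loop runs the SSD forward pass and then d2l.multibox_target.
def forward(X):
    return [x * 2 for x in X]

def label_anchors(preds):
    return [p + 1 for p in preds]

# Accumulate wall-clock time per stage to locate the bottleneck
stage_time = defaultdict(float)
for X in [[1, 2, 3]] * 100:
    t0 = time.perf_counter()
    preds = forward(X)
    stage_time['forward'] += time.perf_counter() - t0

    t0 = time.perf_counter()
    labels = label_anchors(preds)
    stage_time['labeling'] += time.perf_counter() - t0

bottleneck = max(stage_time, key=stage_time.get)
print(bottleneck, dict(stage_time))
```

Swapping a single stage (here, the anchor-labeling step) for an alternative implementation and re-timing is then enough to attribute the slowdown.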

Scratch implementation using d2l.multibox_target is slow in SSD training

.
.
.
            anchors, cls_preds, bbox_preds = net(X)
            # Label the category and offset of each anchor box
            bbox_labels, bbox_masks, cls_labels = d2l.multibox_target(anchors,
                                                                      Y)
            # Calculate the loss function using the predicted and labeled
            # category and offset values
.
.
.
print(f'{len(train_iter._dataset) / timer.stop():.1f} examples/sec on '
      f'{str(device)}')

>>> 2818.4 examples/sec on gpu

Scratch Implementation, but using npx.multibox_target is much faster in SSD Training

.
.
.
            anchors, cls_preds, bbox_preds = net(X)
            # Label the category and offset of each anchor box
            bbox_labels, bbox_masks, cls_labels = npx.multibox_target(
                anchors, Y, cls_preds.transpose(0, 2, 1))
            # Calculate the loss function using the predicted and labeled
            # category and offset values
.
.
.
print(f'{len(train_iter._dataset) / timer.stop():.1f} examples/sec on '
      f'{str(device)}')

>>> 5244.6 examples/sec on gpu
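The examples/sec figure in both runs is the dataset size divided by the epoch's elapsed time. A minimal timer consistent with the `timer.stop()` usage above (a sketch, not the actual d2l.Timer) could look like this:

```python
import time

class Timer:
    """Accumulate elapsed wall-clock intervals across start/stop calls."""
    def __init__(self):
        self.times = []
        self.start()

    def start(self):
        self.tik = time.time()

    def stop(self):
        # Record and return the seconds elapsed since the last start()
        self.times.append(time.time() - self.tik)
        return self.times[-1]

timer = Timer()
time.sleep(0.01)          # stand-in for one training epoch
elapsed = timer.stop()
num_examples = 1000       # hypothetical dataset size
print(f'{num_examples / elapsed:.1f} examples/sec')
```

Because the elapsed time covers the whole loop body, the throughput reflects every stage (forward pass, labeling, loss, backward), which is why swapping only the labeling op nearly doubles it.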

These are %%timeit results for the multibox_target function in PyTorch vs. MXNet:

%%timeit
#@tab pytorch
labels = multibox_target(anchors.unsqueeze(dim=0),
                         ground_truth.unsqueeze(dim=0))

>>> 651 µs ± 867 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
labels = multibox_target(np.expand_dims(anchors, axis=0),
                         np.expand_dims(ground_truth, axis=0))

>>> 21.5 ms ± 24.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
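The figures above are the standard IPython %%timeit summary (mean ± std. dev. over 7 runs). Outside a notebook, an equivalent measurement can be made with the stdlib timeit module; the workload below is a placeholder, not the real multibox_target call:

```python
import timeit
import statistics

def workload():
    # Placeholder for the multibox_target call being benchmarked
    return sum(i * i for i in range(1000))

# 7 runs, mirroring %%timeit; number = loops per run
loops = 100
runs = timeit.repeat(workload, repeat=7, number=loops)
per_loop = [t / loops for t in runs]
mean = statistics.mean(per_loop)
std = statistics.stdev(per_loop)
print(f'{mean * 1e6:.1f} µs ± {std * 1e6:.1f} µs per loop '
      f'(mean ± std. dev. of 7 runs, {loops} loops each)')
```

For asynchronous frameworks like MXNet, the timed statement should force computation to finish (e.g. by synchronizing or materializing the result), otherwise the loop only measures operator dispatch.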

This requires further investigation: even though the scratch multibox_target function implements the same logic in both frameworks, the MXNet version is roughly 30x slower than the PyTorch one (21.5 ms vs. 651 µs), which is not expected. Either there is room for speedup in the MXNet implementation of the function, or PyTorch's imperative ops are simply faster, in which case we can't really do much about this issue.
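One place such a gap typically comes from is the anchor-to-ground-truth assignment, which needs a pairwise IoU matrix: if any step falls back to Python-level loops or many small kernel launches, per-op overhead dominates. A fully broadcast NumPy sketch of pairwise IoU (illustrative only, not the d2l code) shows the loop-free formulation:

```python
import numpy as np

def box_iou(boxes1, boxes2):
    """Pairwise IoU of two sets of (xmin, ymin, xmax, ymax) boxes,
    fully vectorized via broadcasting -- no Python-level loops."""
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    # Upper-left and lower-right corners of all intersections,
    # shape (n1, n2, 2) via broadcasting
    lt = np.maximum(boxes1[:, None, :2], boxes2[None, :, :2])
    rb = np.minimum(boxes1[:, None, 2:], boxes2[None, :, 2:])
    wh = np.clip(rb - lt, 0, None)      # clamp empty intersections to 0
    inter = wh[..., 0] * wh[..., 1]
    return inter / (area1[:, None] + area2[None, :] - inter)

anchors = np.array([[0.0, 0.0, 1.0, 1.0], [0.5, 0.5, 1.5, 1.5]])
gt = np.array([[0.0, 0.0, 1.0, 1.0]])
print(box_iou(anchors, gt))
```

If the MXNet and PyTorch scratch versions are both written in this vectorized style, the remaining 30x gap would point at per-operator dispatch cost rather than algorithmic differences.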