AnirudhDagar opened this issue 3 years ago
I benchmarked the training loop and tried to find the bottleneck. Interestingly, just switching from the scratch implementation `d2l.multibox_target` to the built-in `npx.multibox_target` produced the desired speedup. This indicates that the custom `multibox_target` function implemented from scratch for MXNet is the speed bottleneck. However, the same scratch implementation of `d2l.multibox_target` in PyTorch is fast.
**Scratch implementation using `d2l.multibox_target` is slow in SSD training**

```python
...
anchors, cls_preds, bbox_preds = net(X)
# Label the category and offset of each anchor box
bbox_labels, bbox_masks, cls_labels = d2l.multibox_target(anchors, Y)
# Calculate the loss function using the predicted and labeled
# category and offset values
...
print(f'{len(train_iter._dataset) / timer.stop():.1f} examples/sec on '
      f'{str(device)}')
```

>>> 2818.4 examples/sec on gpu
**Scratch implementation, but using `npx.multibox_target`, is much faster in SSD training**

```python
...
anchors, cls_preds, bbox_preds = net(X)
# Label the category and offset of each anchor box
bbox_labels, bbox_masks, cls_labels = npx.multibox_target(
    anchors, Y, cls_preds.transpose(0, 2, 1))
# Calculate the loss function using the predicted and labeled
# category and offset values
...
print(f'{len(train_iter._dataset) / timer.stop():.1f} examples/sec on '
      f'{str(device)}')
```

>>> 5244.6 examples/sec on gpu
These are `%%timeit` results for the `multibox_target` function in PyTorch vs MXNet:

```python
%%timeit
#@tab pytorch
labels = multibox_target(anchors.unsqueeze(dim=0),
                         ground_truth.unsqueeze(dim=0))
```

>>> 651 µs ± 867 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

```python
%%timeit
labels = multibox_target(np.expand_dims(anchors, axis=0),
                         np.expand_dims(ground_truth, axis=0))
```

>>> 21.5 ms ± 24.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
This requires further investigation: even though the scratch `multibox_target` implements the same logic in both frameworks, MXNet is roughly 30× slower than PyTorch (21.5 ms vs 651 µs), which is unexpected. Either there is room for speedup in the MXNet implementation of the function, or PyTorch is simply faster, in which case we can't do much about this issue. As it stands, the MXNet scratch version of SSD is much slower than its PyTorch scratch counterpart.