QiYangLin closed this issue 2 years ago.
Hi, thanks for your interest in our work. For Figure 1, we are actually analyzing MobileNetV2. We used analytic profiling, which is hardware- and compiler-agnostic. You can check the first paragraph of the Experiment section for details, but in short, such profiling counts the input and output activation sizes of each layer (for multi-branch structures, you need to consider all concurrently live activations).
We calculate the memory usage for batch size = 1. The memory usage is the same for any input of the same size, so there is no need to average over a dataset.
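As a rough illustration of the counting rule (the helper name and example shapes below are only for exposition; a full snippet appears later in this thread): with int8 activations, running one convolution needs roughly its input elements plus its output elements in memory, and the model-level peak is the maximum over layers.

    import numpy as np

    def conv_activation_bytes(in_shape, out_shape, bytes_per_elem=1):
        # in_shape / out_shape: (C, H, W) activation shapes for batch size 1
        # int8 => 1 byte per element; the layer needs input and output alive at once
        return (np.prod(in_shape) + np.prod(out_shape)) * bytes_per_elem

    # example: the first conv of MobileNetV2 maps 3x224x224 -> 32x112x112
    peak_kb = conv_activation_bytes((3, 224, 224), (32, 112, 112)) / 1024
    print(f'{peak_kb:.1f} kB')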
Thanks for answering my question!
The peak memory usage is the maximum feature map size in the block. Am I correct? Could you demonstrate the calculation of peak memory usage for MobileNetV2? I can't get the same result as Figure 1. Additionally, there are only 17 bottlenecks in MobileNetV2. Does Block 0 indicate the first regular convolution?
Hi, sorry for the delay. Yes, Block 0 indicates the first convolution.
Here is a code snippet to reproduce the results in Figure 1:
import torch
import torch.nn as nn
import numpy as np
import torchvision

DATA_TYPE = 'int8'
INPLACE_DW = True

assert DATA_TYPE in ['int8']


def record_in_out_size(m, x, y):
    x = x[0]
    m.input_shape = list(x.shape)
    m.output_shape = list(y.shape)
    # we discard the batch dimension for now
    m.input_size = np.prod(m.input_shape[1:]).item()
    m.output_size = np.prod(m.output_shape[1:]).item()


def add_io_hooks(m_):
    # let's register the hook for all
    m_.register_forward_hook(record_in_out_size)


def estimate_conv(m):
    assert isinstance(m, nn.Conv2d), type(m)
    if INPLACE_DW and m.in_channels == m.out_channels == m.groups:  # inplace depthwise from MCUNetV1
        return m.input_size + np.prod(m.output_shape[2:])  # add a buffer with size equal to one channel
    else:  # normal conv
        return m.input_size + m.output_size


def estimate_mbconv(m):
    assert len(m.conv) in [3, 4]
    if DATA_TYPE == 'int8':  # normal implementation for int8 quantized case
        if m.use_res_connect:
            if len(m.conv) == 3:
                return max([
                    estimate_conv(m.conv[0][0]),
                    estimate_conv(m.conv[1]),
                ])
            else:
                return max([
                    estimate_conv(m.conv[0][0]),
                    estimate_conv(m.conv[1][0]) + m.input_size,
                    estimate_conv(m.conv[2])
                ])
        else:
            if len(m.conv) == 3:
                return max([
                    estimate_conv(m.conv[0][0]),
                    estimate_conv(m.conv[1])
                ])
            else:
                return max([
                    estimate_conv(m.conv[0][0]),
                    estimate_conv(m.conv[1][0]),
                    estimate_conv(m.conv[2])
                ])
    else:
        raise NotImplementedError


def nelem2kb(n):
    # converting number of elements to kB according to the precision
    if DATA_TYPE == 'fp':
        n = n * 4  # 4 bytes per floating point
    return n / 1024


def estimate_model(model, x):
    model.eval()
    model.apply(add_io_hooks)
    with torch.no_grad():
        _ = model(x)

    memory_usage = list()
    # now estimate the memory of the backbone
    # first conv
    memory_usage.append(estimate_conv(model.features[0][0]))
    # blocks
    for b in model.features[1:-1]:
        memory_usage.append(estimate_mbconv(b))
    # now convert everything into kB
    memory_usage = [nelem2kb(m) for m in memory_usage]
    print('memory usage in kB:', [round(m, 1) for m in memory_usage])  # kB

    import matplotlib.pyplot as plt
    plt.bar(range(len(memory_usage)), memory_usage)
    plt.show()


if __name__ == '__main__':
    net = torchvision.models.mobilenet_v2()
    sample_input = torch.randn(1, 3, 224, 224)
    estimate_model(net, sample_input)
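If you only want the single number reported as the peak memory, one possible extension (assuming you change estimate_model above to end with return memory_usage) is to take the maximum over the per-block values:

    # continuing the script above; assumes estimate_model now returns memory_usage
    net = torchvision.models.mobilenet_v2()
    memory_usage = estimate_model(net, torch.randn(1, 3, 224, 224))
    print('peak memory over all blocks (kB):', round(max(memory_usage), 1))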
Thanks for sharing.
One more question: for the use_res_connect case in estimate_mbconv, the residual connection crosses all three convolutions.
From my understanding, the estimate should look like the code below. Please correct me if I have misunderstood. Thank you.
return max([
    estimate_conv(m.conv[0][0]) + m.input_size,
    estimate_conv(m.conv[1][0]) + m.input_size,
    estimate_conv(m.conv[2]) + m.input_size
])
Hi, the first conv shares its input with the residual connection, so we do not need to count that activation twice. Likewise, the last conv shares its output with the residual connection, so only the depthwise conv needs the extra m.input_size to keep the residual alive.
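To spell out that accounting, here is a rough sketch (the helper below is only illustrative, but it is consistent with estimate_mbconv above, assuming int8 and the inplace depthwise trick) of which buffers are alive during each of the three convolutions of a residual MBConv:

    def residual_mbconv_peak(in_size, hidden_size, one_channel_size):
        # in_size:          block input = block output = residual size (elements)
        # hidden_size:      expanded activation size (elements)
        # one_channel_size: H*W of the depthwise output (inplace-depthwise buffer)
        expand_1x1 = in_size + hidden_size                    # input doubles as the residual: count once
        depthwise = hidden_size + one_channel_size + in_size  # the residual must stay alive here
        project_1x1 = hidden_size + in_size                   # output is where the residual is added: count once
        return max(expand_1x1, depthwise, project_1x1)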
We are currently cleaning up the code and will release some MCUNetV2-related code here: https://github.com/mit-han-lab/mcunet.
Great work!
Could I get some details about Figure 1 (the imbalanced memory usage distribution) in MCUNetV2, such as which hardware and compiler are used for the base case? Do you run inference over lots of data and then average, or just over one data point? How do you measure and compute the diagram? If there is code for the measurement, that would be fantastic.