mit-han-lab / tinyml

MCUNet_V2 imbalanced memory usage distribution #14

Closed · QiYangLin closed 2 years ago

QiYangLin commented 2 years ago

Great work!

Could I get some details on Figure 1 (imbalanced memory usage distribution) in MCUNetV2, such as which hardware and compiler are used for the base case? Does the model run inference on many data points and average them, or just one data point? How do you measure and compute the diagram? If there is code for the measurement, that would be fantastic.

tonylins commented 2 years ago

Hi, thanks for your interest in our work. For Figure 1, we are actually analyzing MobileNetV2. We used analytic profiling, which is hardware- and compiler-agnostic. You can check the first paragraph of the Experiment section for details. In short, such profiling counts the input and output activation sizes (for multi-branch structures, you need to consider all the concurrently live activations).

We calculate the memory usage for batch size=1. The memory usage is the same for the same input size, so we do not need to average.
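
To make the counting concrete, here is a minimal worked example for Block 0 of MobileNetV2 (the first convolution) at the standard 224x224 input resolution, assuming int8 so one element is one byte:

```python
# Block 0 of MobileNetV2: a regular 3x3 conv, stride 2, 3 -> 32 channels
input_size = 3 * 224 * 224     # 150,528 elements
output_size = 32 * 112 * 112   # 401,408 elements
# analytic peak memory = input + output activations, batch size 1
print((input_size + output_size) / 1024)  # 539.0 kB
```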

QiYangLin commented 2 years ago

Thanks for answering my question!

The peak memory usage is the maximum feature map size within a block. Am I correct? Could you demonstrate the calculation of peak memory usage for MobileNetV2, since I can't get the same result as Figure 1? Additionally, there are only 17 bottlenecks in MobileNetV2. Does Block 0 indicate the first regular convolution?

tonylins commented 2 years ago

Hi, sorry for the delay. Yes, Block 0 indicates the first convolution.

Here is a code snippet to reproduce the results in Figure 1:

```python
import torch
import torch.nn as nn
import numpy as np
import torchvision

DATA_TYPE = 'int8'
INPLACE_DW = True

assert DATA_TYPE in ['int8']

def record_in_out_size(m, x, y):
    x = x[0]
    m.input_shape = list(x.shape)
    m.output_shape = list(y.shape)
    # we discard the batch dimension for now
    m.input_size = np.prod(m.input_shape[1:]).item()
    m.output_size = np.prod(m.output_shape[1:]).item()

def add_io_hooks(m_):
    # let's register the hook for all
    m_.register_forward_hook(record_in_out_size)

def estimate_conv(m):
    assert isinstance(m, nn.Conv2d), type(m)
    if INPLACE_DW and m.in_channels == m.out_channels == m.groups:  # inplace depthwise from MCUNetV1
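        # a depthwise conv handles each channel independently, so the output can
        # overwrite the input in place; only a one-channel temporary buffer is needed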
        return m.input_size + np.prod(m.output_shape[2:])  # add a buffer with size equal to one channel
    else:  # normal conv
        return m.input_size + m.output_size

def estimate_mbconv(m):
    assert len(m.conv) in [3, 4]
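    # torchvision's InvertedResidual lays out m.conv as:
    #   len == 4: [pw expand, dw conv, pw-linear conv, batchnorm]  (expand_ratio > 1)
    #   len == 3: [dw conv, pw-linear conv, batchnorm]             (expand_ratio == 1)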
    if DATA_TYPE == 'int8':  # normal implementation for int8 quantized case
        if m.use_res_connect:
            if len(m.conv) == 3:
                return max([
                    estimate_conv(m.conv[0][0]),
                    estimate_conv(m.conv[1]),
                ])
            else:
                return max([
                    estimate_conv(m.conv[0][0]),
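                    # the residual input must stay in memory while the dw conv runs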
                    estimate_conv(m.conv[1][0]) + m.input_size,
                    estimate_conv(m.conv[2])
                ])
        else:
            if len(m.conv) == 3:
                return max([
                    estimate_conv(m.conv[0][0]),
                    estimate_conv(m.conv[1])
                ])
            else:
                return max([
                    estimate_conv(m.conv[0][0]),
                    estimate_conv(m.conv[1][0]),
                    estimate_conv(m.conv[2])
                ])
    else:
        raise NotImplementedError

def nelem2kb(n):
    # converting number of elements to kB according to the precision
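    # for int8, one element is one byte, so no scaling is needed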
    if DATA_TYPE == 'fp':
        n = n * 4  # 4 bytes per floating point
    return n / 1024

def estimate_model(model, x):
    model.eval()
    model.apply(add_io_hooks)

    with torch.no_grad():
        _ = model(x)

    memory_usage = list()

    # now estimate the memory of the backbone
    # first conv
    memory_usage.append(estimate_conv(model.features[0][0]))
    # blocks
    for b in model.features[1:-1]:
        memory_usage.append(estimate_mbconv(b))

    # now convert everything into kB
    memory_usage = [nelem2kb(m) for m in memory_usage]

    print('memory usage in kB:', [round(m, 1) for m in memory_usage])  # kB

    import matplotlib.pyplot as plt
    plt.bar(range(len(memory_usage)), memory_usage)
    plt.show()

if __name__ == '__main__':
    net = torchvision.models.mobilenet_v2()
    sample_input = torch.randn(1, 3, 224, 224)
    estimate_model(net, sample_input)
```

QiYangLin commented 2 years ago

Thanks for sharing.

One more question. For the use_res_connect case in estimate_mbconv, the residual connection crosses all three convolutions.

From my understanding, it would be like the code below. Please correct me if I misunderstand. Thank you.

```python
return max([
    estimate_conv(m.conv[0][0]) + m.input_size,
    estimate_conv(m.conv[1][0]) + m.input_size,
    estimate_conv(m.conv[2]) + m.input_size
])
```

tonylins commented 2 years ago

Hi, the first conv shares its input with the residual connection, so we do not need to count that activation twice. Likewise, the last conv shares its output with the residual connection (the residual can be added into the output buffer in place). Only the middle depthwise conv needs the extra m.input_size term, since the residual has to stay alive while it runs.
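
Restating the estimate_mbconv accounting above for the expanded residual block (the len(m.conv) == 4, int8 branch):

```python
# peak-memory terms for a residual MBConv with expansion (int8):
peak = max(
    estimate_conv(m.conv[0][0]),                 # pw expand: its input IS the residual
    estimate_conv(m.conv[1][0]) + m.input_size,  # dw conv: residual kept alive alongside it
    estimate_conv(m.conv[2]),                    # pw-linear: residual added into its output
)
```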

We are currently cleaning up the code and will release some MCUNetV2-related code here: https://github.com/mit-han-lab/mcunet.