eric-haibin-lin opened 5 years ago
Curious to know why we need an estimator. Can an estimated number stop users from trying a larger input size? I think they won't stop until they see the OOM error. Compared with a memory usage estimator, I think it would be better if MXNet had a memory profiler.
@TaoLv This feature is more geared towards federated training, where some of the model shapes are inferred from the user-provided dataset. In that case, to utilize memory efficiently, we have to understand how to set other parameters, e.g. batch_size, accordingly. Essentially, we aim to "try a larger input size" inside a script (see the sketch below).
An example would be fitting word_language_model with different vocabulary sizes, where filtering by word frequency is not an option.
Additionally, it would be useful to consider memory planning for sparse data. In that case, we could provide a 'max_num_nonzeros' parameter to reduce the uncertainty.
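To make the "try it inside a script" idea concrete, here is a minimal sketch of what that looks like today without an estimator. `max_fitting_batch_size` and `make_batch` are hypothetical names, not an existing MXNet API; the only real machinery is that an allocation failure surfaces as `mx.base.MXNetError` once execution is forced:

```python
import mxnet as mx
from mxnet import autograd

def max_fitting_batch_size(net, loss_fn, make_batch, candidates):
    """Return the largest candidate batch size whose forward/backward pass
    fits in GPU memory; make_batch(bs) builds a (data, label) pair on GPU."""
    best = None
    for bs in sorted(candidates):
        try:
            data, label = make_batch(bs)
            with autograd.record():
                loss = loss_fn(net(data), label)
            loss.backward()
            mx.nd.waitall()  # force execution so a deferred OOM surfaces here
            best = bs
        except mx.base.MXNetError:  # cudaMalloc failure arrives as MXNetError
            break
    return best
```

An estimator API would let the script pick batch_size up front instead of paying for these throwaway trial runs.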
GPU memory is limited. It would be great to have a utility function that estimates the memory usage of training a model given the shape of the input. Currently, the only way is to run the model with different hidden_size and batch_size values (trial and error). MXNet could provide an API that makes this process automatic so that it's easier for the user. @Roshrini @pinaraws @cgraywang @yifeim feel free to add additional context
For example, there is a third-party library that prints the theoretical memory consumption of a PyTorch model's forward/backward intermediate data entries, weights, and gradients: https://github.com/jacobkimmel/pytorch_modelsize
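The weights-and-gradients part of that estimate is easy to port to Gluon, since parameter shapes are known after initialization. A rough sketch, assuming float32 and ignoring intermediates and optimizer state (`estimate_param_bytes` is an illustrative helper, not an existing API):

```python
import numpy as np
import mxnet as mx
from mxnet import gluon

def estimate_param_bytes(net, bytes_per_element=4):
    """Memory for weights plus gradients (hence the factor of 2)."""
    n_params = sum(np.prod(p.shape) for p in net.collect_params().values())
    return 2 * n_params * bytes_per_element

net = gluon.nn.Dense(1024, in_units=512)
net.initialize()
print(estimate_param_bytes(net) / 1024**2, "MiB for weights + grads")
```

The intermediate forward/backward buffers are the harder part; estimating those is exactly where hooking into MXNet's memory planner (below) would help.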
In MXNet we can record the memory planning of a model and report the memory usage given input shapes. This does not include the temporary memory requested at runtime (e.g. by MKLDNN/CUDNN). If reporting the planned memory usage is not accurate enough, we can simply run trials and return the actual peak memory usage at runtime.
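A minimal sketch of that fallback, assuming `mx.context.gpu_memory_info` (available in recent MXNet versions, returning free/total bytes): run one training step and report how much GPU memory it consumed. Because the memory pool caches allocations rather than returning them to the driver, free-before minus free-after is a reasonable proxy for the step's peak footprint:

```python
import mxnet as mx
from mxnet import autograd

def measure_step_bytes(net, data, label, loss_fn, device_id=0):
    """Run one forward/backward pass and report the GPU memory it grabbed."""
    free_before, _ = mx.context.gpu_memory_info(device_id)
    with autograd.record():
        loss = loss_fn(net(data), label)
    loss.backward()
    mx.nd.waitall()  # make sure all allocations have actually happened
    free_after, _ = mx.context.gpu_memory_info(device_id)
    return free_before - free_after
```

Unlike the planner-based estimate, this does capture the MKLDNN/CUDNN scratch space, at the cost of actually running the model once per configuration.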