mlcommons / algorithmic-efficiency

MLCommons Algorithmic Efficiency is a benchmark and competition measuring neural network training speedups due to algorithmic improvements in both training algorithms and models.
https://mlcommons.org/en/groups/research-algorithms/
Apache License 2.0

Resnet DDP Warning: grad strides do not match bucket view strides #741

Open Niccolo-Ajroldi opened 4 months ago

Niccolo-Ajroldi commented 4 months ago

Description

On the imagenet_resnet workload, I get the following warning when running with DDP and the PyTorch framework.

/u/najroldi/miniconda3/envs/alpe/lib/python3.8/site-packages/torch/autograd/__init__.py:251: UserWarning: Grad strides do not match bucket view strides. This may indicate grad was not created according to the gradient layout contract, or that the param's strides changed since DDP was constructed. This is not an error, but may impair performance. grad.sizes() = [512, 2048, 1, 1], strides() = [2048, 1, 2048, 2048]
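One way to read the stride tuple in the warning (my interpretation, not confirmed anywhere in this thread): for the 4-D shape [512, 2048, 1, 1], which looks like a 1x1 conv weight with 512 output and 2048 input channels, a default (NCHW-contiguous) tensor has strides (2048, 1, 1, 1). The reported strides (2048, 1, 2048, 2048) are exactly what a channels_last (NHWC) layout gives for that shape. The two helpers below are hypothetical plain-Python reimplementations of the standard stride formulas, not PyTorch APIs. They are only a sketch to make that arithmetic concrete.

```python
def contiguous_strides(shape):
    """Row-major (NCHW-contiguous) strides, in elements, for a tensor shape."""
    strides = [1] * len(shape)
    # Each dim's stride is the product of the sizes of all dims to its right.
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    return tuple(strides)


def channels_last_strides(shape):
    """Strides for a 4-D NCHW-shaped tensor stored in channels_last (NHWC) order."""
    _, c, h, w = shape
    # Memory order is N, H, W, C: C varies fastest, then W, then H, then N.
    return (h * w * c, 1, w * c, c)


# Gradient shape quoted in the DDP warning.
shape = (512, 2048, 1, 1)
reported = (2048, 1, 2048, 2048)  # grad.strides() from the warning

print(contiguous_strides(shape))                 # (2048, 1, 1, 1)
print(channels_last_strides(shape))              # (2048, 1, 2048, 2048)
print(channels_last_strides(shape) == reported)  # True
```

Because H and W are both 1 here, the NCHW and NHWC layouts store the elements in the same memory order. So the mismatch DDP reports is in stride metadata only, which fits the warning's "This is not an error". In PyTorch itself the same strides can be checked with `torch.empty(shape, memory_format=torch.channels_last).stride()`.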

Is this a known problem? Is there a known fix?

Relatedly, I noticed that the prize_qualification_baseline logs start with python3 submission_runner.py. For self-reported results, and for the final scoring, should we use DDP?

Steps to Reproduce

torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py --framework=pytorch --workload=imagenet_resnet ...
priyakasimbeg commented 3 months ago

Niccolo, sorry for the delay; this issue seems to have fallen through the cracks. I noticed this warning on our baselines too and will investigate further.

The prize qualification logs are from JAX runs, which is why they start with python3 submission_runner.py. Please use DDP for self-reported runs.

priyakasimbeg commented 3 months ago

@msaroufim, have you encountered this warning before? The full message is in this log. So far it seems to happen only on ResNet. I just noticed that this warning has been in our test logs for more than 3 months; I think all of those runs used PyTorch 2.1. I'm trying to find the first run in which this warning occurred.

priyakasimbeg commented 3 months ago

It looks like this warning is printed only in the first step. I'm not sure how it affects the speed of the rest of training. I last compared the PyTorch and JAX run times in Dec 2023, and the PyTorch ResNet timing runs from then also show this warning in their logs. ResNet is about 6% slower in PyTorch than in JAX, which I think is within tolerance.