Hi Michael, I will check whether I can reproduce this on our hardware.
Confirming that I can reproduce this on our hardware. I just realized we encountered the same issue in JAX several weeks ago; see https://github.com/mlcommons/algorithmic-efficiency/issues/644. To resolve it, we updated the target-setting configs to use bsz=512 for these variants. We should also have updated the baseline algorithm modules to override the batch size for these ResNet variants, but that slipped through the cracks. So the solution here is to reduce the batch size on these variants for your submission. I will also update the baseline configs to reflect the reduced batch size for these variants.
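For anyone working around this in the meantime, here is a minimal sketch of such an override, assuming the `get_batch_size(workload_name)` hook from the submission API and that the variant names match the `--workload` flags used below; the non-variant values are illustrative:

```python
def get_batch_size(workload_name: str) -> int:
  """Global batch size per workload; non-variant values are illustrative."""
  # The SiLU/GELU ResNet variants OOM at bsz=1024 on the competition
  # hardware, so drop them to 512 to match the target-setting configs.
  if workload_name in ('imagenet_resnet_silu', 'imagenet_resnet_gelu'):
    return 512
  if workload_name == 'imagenet_resnet':
    return 1024
  # ...batch sizes for the remaining workloads go here.
  raise ValueError(f'Unhandled workload: {workload_name}')
```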
Description
When running the workload variants for ImageNet ResNet with activation function changes (SiLU/GELU) in PyTorch, we find that the baseline NAdamW method OOMs on AWS with the competition hardware. This is with the default batch size of 1024.
We observe OOM errors such as:
ImageNet ResNet GELU:
ImageNet ResNet SiLU:
We have also tried disabling `torch.compile`, but we still observe the OOM issue.
cc @tsunghsienlee @anana10c @mikerabbat @shintaro-iwasaki @yuchenhao
Steps to Reproduce
To reproduce, you can run:
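The exact invocation depends on your setup; as a sketch, assuming the repo's `submission_runner.py` entry point with README-style flags (the data, experiment, and submission paths below are placeholders for your local copies):

```bash
torchrun --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \
  --framework=pytorch \
  --workload=imagenet_resnet_gelu \
  --data_dir=/path/to/imagenet \
  --experiment_dir=/path/to/experiments \
  --experiment_name=imagenet_resnet_gelu_oom_repro \
  --submission_path=reference_algorithms/paper_baselines/nadamw/pytorch/submission.py \
  --tuning_search_space=reference_algorithms/paper_baselines/nadamw/tuning_search_space.json
```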
and similarly with `--workload=imagenet_resnet_silu`.
Source or Possible Fix
We do not currently have a fix aside from decreasing the batch size for our submission, but we wanted to raise this issue as it also impacts the baseline methods. Please advise on how to proceed. Thanks!