mlcommons / algorithmic-efficiency

MLCommons Algorithmic Efficiency is a benchmark and competition measuring neural network training speedups due to algorithmic improvements in both training algorithms and models.
https://mlcommons.org/en/groups/research-algorithms/
Apache License 2.0

OOM with ImageNet ResNet SiLU/GELU Workloads in PyTorch #744

Closed hjmshi closed 3 months ago

hjmshi commented 4 months ago

When running the ImageNet ResNet workload variants with activation-function changes (SiLU/GELU) in PyTorch, we find that the baseline NAdamW method OOMs on AWS with the competition hardware. This occurs with the default batch size of 1024.

Description

We observe OOM errors such as:

ImageNet ResNet GELU:

File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/codecache.py", line 374, in __call__    
    return self.get_current_callable()(inputs)                                                                                                                                                                 
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/compile_fx.py", line 628, in run        
    return model(new_inputs)                                                                                                                                                                                   
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/codecache.py", line 401, in _run_from_cache                                                                                                     
    return compiled_graph.compiled_artifact(inputs)                                                    
  File "/tmp/torchinductor_root/2i/c2i5knmigfe2haqboyqnscdhekkkt3qk75trzvwq4z23mpdj26f2.py", line 1587, in call                                                                                                
    buf245 = empty_strided((128, 1024, 14, 14), (200704, 1, 14336, 1024), device='cuda', dtype=torch.float32)                                                                                                  
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 98.00 MiB. GPU 5 has a total capacty of 15.77 GiB of which 56.56 MiB is free. Process 2502141 has 15.71 GiB memory in use. Of the allocated 
memory 14.68 GiB is allocated by PyTorch, and 74.72 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See document
ation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
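
For reference, the (128, 1024, 14, 14) activation buffer in the trace above is consistent with the per-device batch size: 1024 (global) / 8 GPUs = 128 images per GPU.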

ImageNet ResNet SiLU:

File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/triton_heuristics.py", line 282, in bench                                                                                                       
    return do_bench(kernel_call, rep=40, fast_flush=True)                                                                                                                                                      
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/utils.py", line 75, in do_bench                                                                                                                 
    return triton_do_bench(*args, **kwargs)[0]                                                         
  File "/usr/local/lib/python3.8/dist-packages/triton/testing.py", line 111, in do_bench                                                                                                                       
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')                                                                                                                                       
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 246.00 MiB. GPU 3 has a total capacty of 15.77 GiB of which 32.56 MiB is free. Process 2502139 has 15.73 GiB memory in use. Of the allocated
 memory 14.77 GiB is allocated by PyTorch, and 94.22 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documen
tation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

We have also tried disabling torch.compile but still observe the OOM issue.
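
As the allocator message suggests, one mitigation we could try is setting max_split_size_mb via PYTORCH_CUDA_ALLOC_CONF, although since roughly 14.7 GiB is actually allocated by PyTorch (not just reserved), this is unlikely to fully resolve the OOM. A minimal sketch, with an arbitrary example value, set before the first CUDA allocation:

import os

# Sketch only: allocator hint taken from the error message above. Must be set
# before the first CUDA allocation (e.g. at the top of submission_runner.py,
# or exported as an environment variable before launching torchrun). The
# 256 MiB split size is an arbitrary example, not a tuned value.
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'max_split_size_mb:256')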

cc @tsunghsienlee @anana10c @mikerabbat @shintaro-iwasaki @yuchenhao

Steps to Reproduce

To reproduce, you can run:

torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 \
    --standalone \
    --nnodes=1 \
    --nproc_per_node=8 \
    submission_runner.py \
    --framework=pytorch \
    --data_dir=/data/imagenet/pytorch/ \
    --workload=imagenet_resnet_gelu \
    --experiment_dir=/experiment_runs \
    --experiment_name=nadamw_baseline \
    --submission_path=prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py \
    --tuning_search_space=prize_qualification_baselines/external_tuning/tuning_search_space.json \
    --imagenet_v2_data_dir=/data/imagenet/

and similarly with --workload=imagenet_resnet_silu.

Source or Possible Fix

We do not currently have a fix aside from decreasing the batch size for our submission, but we wanted to raise this issue as it also impacts the baseline methods. Please advise on how to proceed. Thanks!

priyakasimbeg commented 4 months ago

Hi Michael, I will check if I can reproduce this on my end with our hardware.

priyakasimbeg commented 4 months ago

Confirming that I can reproduce this on our hardware. I just realized we encountered the same issue in JAX several weeks ago in https://github.com/mlcommons/algorithmic-efficiency/issues/644. To resolve it, we updated the target-setting configs to use a batch size of 512 for these variants. We should have also updated the baseline algorithm modules to override the batch size for these ResNet variants, but that slipped through the cracks. So the solution here is to reduce the batch size for these variants in your submission. I will also modify the baseline configs to reflect the reduced batch size for these variants.
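
For reference, a minimal sketch of what the per-workload batch-size override could look like in a submission module, assuming the runner consults the submission's get_batch_size(workload_name) hook with the variant names used above (imagenet_resnet_gelu / imagenet_resnet_silu). The 512 value follows the reduction described here; the other entries are illustrative.

# Sketch only: override the global batch size per workload in the submission.
# The workload-name keys and the non-512 values here are assumptions for
# illustration, not confirmed baseline settings.
def get_batch_size(workload_name: str) -> int:
  batch_sizes = {
      'imagenet_resnet': 1024,       # default batch size used by the baseline
      'imagenet_resnet_gelu': 512,   # reduced to avoid OOM on 16 GiB GPUs
      'imagenet_resnet_silu': 512,   # reduced to avoid OOM on 16 GiB GPUs
  }
  if workload_name not in batch_sizes:
    raise ValueError(f'Unknown workload: {workload_name}')
  return batch_sizes[workload_name]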