mlcommons / training

Reference implementations of MLPerf™ training benchmarks
https://mlcommons.org/en/groups/training
Apache License 2.0

[DLRM v2] How to modify the default training script of DLRM v2 to train the model with limited GPU memory #655

Open · JJingL opened this issue 1 year ago

JJingL commented 1 year ago

Hi team, I tried to run the default DLRM v2 training script, but my GPUs do not have enough memory for the default settings. The only change I made to the training script was this flag:

--num_embeddings_per_feature 26000000,39060,17295,7424,20265,3,7122,1543,63,26000000,3067956,405282,10,2209,11938,155,4,976,14,26000000,26000000,26000000,590152,12973,108,36

However, eval_accuracy did not increase during training, and the final result is around 0.70x. Does anyone have an idea why?

P.S. Here is the exact command I ran:

export TOTAL_TRAINING_SAMPLES=4195197692
export GLOBAL_BATCH_SIZE=16384
export WORLD_SIZE=8

torchx run -s local_cwd dist.ddp -j 1x8 --script dlrm_main.py -- \
    --embedding_dim 128 \
    --dense_arch_layer_sizes 512,256,128 \
    --over_arch_layer_sizes 1024,1024,512,256,1 \
    --in_memory_binary_criteo_path /workspace/DLRM/numpy_contiguous_shuffled_output_dataset_dir \
    --num_embeddings_per_feature 24000000,39060,17295,7424,20265,3,7122,1543,63,24000000,3067956,405282,10,2209,11938,155,4,976,14,24000000,24000000,24000000,590152,12973,108,36 \
    --validation_freq_within_epoch $((TOTAL_TRAINING_SAMPLES / (GLOBAL_BATCH_SIZE * 40))) \
    --epochs 1 \
    --adagrad \
    --pin_memory \
    --mmap_mode \
    --batch_size $((GLOBAL_BATCH_SIZE / WORLD_SIZE)) \
    --interaction_type=dcn \
    --dcn_num_layers=3 \
    --dcn_low_rank_dim=512 \
    --learning_rate 0.004 \
    --shuffle_batches \
    --multi_hot_distribution_type uniform \
    --multi_hot_sizes=3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1 \
    --print_sharding_plan
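To put the memory problem in numbers, the shell sketch below reproduces the arithmetic in the command (the per-GPU --batch_size and the value passed to --validation_freq_within_epoch) and estimates how large the embedding tables are for the --num_embeddings_per_feature and --embedding_dim values used above. It is a rough estimate under assumptions that do not come from the reference code: fp32 table weights, one fp32 row-wise Adagrad accumulator per row for --adagrad, a perfectly even split across the 8 GPUs, and no allowance for activations, communication buffers, or the dense/DCN/over layers.

```shell
#!/bin/sh
# Rough memory/iteration arithmetic for the command above (POSIX sh, 64-bit integers).
# Assumptions, not taken from the reference code: fp32 embedding weights (4 bytes each),
# one fp32 row-wise Adagrad accumulator per row, a perfectly even split across GPUs,
# and no allowance for activations, comm buffers, or the dense/DCN/over layers.

TOTAL_TRAINING_SAMPLES=4195197692
GLOBAL_BATCH_SIZE=16384
WORLD_SIZE=8
EMBEDDING_DIM=128
NUM_EMBEDDINGS="24000000,39060,17295,7424,20265,3,7122,1543,63,24000000,3067956,405282,10,2209,11938,155,4,976,14,24000000,24000000,24000000,590152,12973,108,36"

# Same shell arithmetic as the --batch_size and --validation_freq_within_epoch flags.
per_gpu_batch=$(( GLOBAL_BATCH_SIZE / WORLD_SIZE ))
val_every=$(( TOTAL_TRAINING_SAMPLES / (GLOBAL_BATCH_SIZE * 40) ))

# Sum the row counts of all 26 embedding tables.
total_rows=0
for r in $(printf '%s\n' "$NUM_EMBEDDINGS" | tr ',' ' '); do
  total_rows=$(( total_rows + $r ))
done

weight_bytes=$(( total_rows * EMBEDDING_DIM * 4 ))  # fp32 table weights (assumed)
adagrad_bytes=$(( total_rows * 4 ))                 # one fp32 accumulator per row (assumed)
total_bytes=$(( weight_bytes + adagrad_bytes ))
per_gpu_bytes=$(( total_bytes / WORLD_SIZE ))       # ideal even split; real sharding is uneven

echo "per-GPU batch size:          $per_gpu_batch"
echo "validation interval:         $val_every"
echo "total embedding rows:        $total_rows"
echo "tables + Adagrad state (MB): $(( total_bytes / 1000000 ))"
echo "per GPU, even split (MB):    $(( per_gpu_bytes / 1000000 ))"
```

With the values from the command it reports a per-GPU batch of 2048 and a validation interval of 6401 (if --validation_freq_within_epoch counts training iterations, that is about 40 evaluations per epoch), plus roughly 64,079 MB of table weights and optimizer state, or about 8,009 MB per GPU even with a perfect split. A single 24,000,000-row table at dimension 128 is already about 12,288 MB on its own, so the real per-GPU peak depends on how the planner shards the tables, which --print_sharding_plan reports. Table memory scales linearly with both the row counts and --embedding_dim, while the per-GPU batch size mainly affects activation memory.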

menghongtao commented 6 months ago

Hi, I am facing the same problem. I changed the embedding dimension, and the evaluation AUC ends up around 0.70*. Have you solved your issue?
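If you change --embedding_dim, as described in the comment above, the rest of the model configuration has to stay consistent with it. To our understanding, TorchRec's DLRM model classes require the last entry of --dense_arch_layer_sizes (128 in the command from the original post) to equal the embedding dimension, and raise an error at construction time otherwise. The sketch below is a hypothetical pre-flight check, not part of the reference scripts; EMBEDDING_DIM=64 and DENSE_ARCH_LAYER_SIZES=512,256,64 are example values only, and the fp32 byte count is an assumption.

```shell
#!/bin/sh
# Hypothetical pre-flight check (not part of the MLPerf reference scripts):
# the final dense-arch layer size should match --embedding_dim.
EMBEDDING_DIM=64                     # example: reduced from 128
DENSE_ARCH_LAYER_SIZES="512,256,64"  # example: last entry updated to match

last_dense=${DENSE_ARCH_LAYER_SIZES##*,}  # text after the last comma

if [ "$last_dense" -ne "$EMBEDDING_DIM" ]; then
  echo "mismatch: last dense layer $last_dense != embedding_dim $EMBEDDING_DIM" >&2
  check_ok=0
else
  echo "ok: dense arch output ($last_dense) matches embedding_dim"
  check_ok=1
fi

# Table memory scales linearly with the embedding dimension:
# bytes = rows * dim * 4 for fp32 weights (assumed), so dim 64 is half of dim 128.
big_table_bytes=$(( 24000000 * EMBEDDING_DIM * 4 ))
echo "one 24,000,000-row table at dim $EMBEDDING_DIM: $(( big_table_bytes / 1000000 )) MB"
```

Under the fp32 assumption, going from 128 to 64 halves the table memory (6,144 MB instead of 12,288 MB for each 24,000,000-row table); the check says nothing about whether the smaller model still reaches the target quality.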