pytorch / torchrec

PyTorch domain library for recommendation systems
https://pytorch.org/torchrec/

RuntimeError & questions about bert4rec example code #1057

Open nippleshot opened 1 year ago

nippleshot commented 1 year ago

Hello, while trying to run the bert4rec example code I have run into two problems, and I hope I can get some feedback.

1) When I try to run the bert4rec model with multiple TITAN XP GPUs, it fails with a RuntimeError.

(I used the MovieLens 1M dataset, by the way.)

2) When I run the bert4rec model on a single TITAN XP GPU, it runs without any RuntimeError. However, the metrics I get are much lower than the ones in Recommendation Metrics Reproduce, and I don't understand why this happens.

```
Epoch 1, metrics {'Recall@1': 0.0126953125, 'Recall@5': 0.0634765625, 'Recall@10': 0.12141927083333333, 'NDCG@5': 0.03733245619029427, 'NDCG@10': 0.05605129435813675}
Epoch 50, metrics {'Recall@1': 0.01220703125, 'Recall@5': 0.06711154516475897, 'Recall@10': 0.12320963541666667, 'NDCG@5': 0.039204553118906915, 'NDCG@10': 0.05708927156714102}
Epoch 100, metrics {'Recall@1': 0.010904947916666666, 'Recall@5': 0.05805121532951792, 'Recall@10': 0.11387803824618459, 'NDCG@5': 0.033518949057906866, 'NDCG@10': 0.05138459533918649}
Epoch 150, metrics {'Recall@1': 0.01123046875, 'Recall@5': 0.06141493058142563, 'Recall@10': 0.11474609375, 'NDCG@5': 0.035709500865777954, 'NDCG@10': 0.052741181144180395}
Epoch 190, metrics {'Recall@1': 0.014051649331425628, 'Recall@5': 0.06287977433142562, 'Recall@10': 0.11404079866285126, 'NDCG@5': 0.03795023405109532, 'NDCG@10': 0.05423267767764628}
...
```


* As you can see above, the average loss does not drop much at the beginning of training, so it seems training is not going well.
* I also tried the same [argument values shown in the test method](https://github.com/pytorch/torchrec/blob/ec85e941de8c7dea65f427d514542241e3d48556/examples/bert4rec/tests/test_bert4rec_main.py#L183), and the average loss and metrics barely change from epoch to epoch.
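
To be concrete about what I mean by "too low", here is a rough sketch of how Recall@k and NDCG@k are usually computed in this leave-one-out setup (one held-out item per user). This is just the standard definition I am comparing my numbers against, not necessarily the exact code in `_calculate_metrics`:

```python
import torch

def recall_ndcg_at_k(scores: torch.Tensor, target: torch.Tensor, k: int):
    """scores: [batch, num_items] model scores; target: [batch] held-out item ids."""
    topk = scores.topk(k, dim=1).indices        # [batch, k] item ids of the top-k predictions
    hits = topk == target.unsqueeze(1)          # [batch, k] True where the held-out item appears
    hit_any = hits.any(dim=1)                   # [batch] did the held-out item make the top-k?
    recall = hit_any.float().mean().item()      # hit rate == Recall@k with one relevant item
    ranks = hits.float().argmax(dim=1) + 1      # 1-based rank of the hit (only valid where hit_any)
    dcg = 1.0 / torch.log2(ranks.float() + 1.0) # single relevant item => IDCG == 1
    ndcg = torch.where(hit_any, dcg, torch.zeros_like(dcg)).mean().item()
    return recall, ndcg

# e.g. recall, ndcg = recall_ndcg_at_k(model_scores, heldout_items, k=10)
```

By definitions like these, a Recall@10 around 0.12 on ml-1m is well below the numbers in Recommendation Metrics Reproduce.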
YLGH commented 1 year ago

Hmm, the first issue looks like a potential edge case from a recent change to input_dist. @joshuadeng do you know off the top of your head?

For the second issue, it looks like the settings aren't exactly the same (num_epochs, batch_size, etc.)?

nippleshot commented 1 year ago

@YLGH For the 2nd issue: when I used the same argument settings as that test method, the average loss and per-epoch metrics stayed almost flat. The default arguments look a little better to me, but neither run trains well.

```
Epoch 1, metrics {'Recall@1': 0.012763843250771364, 'Recall@5': 0.060915227668980755, 'Recall@10': 0.11845531811316808, 'NDCG@5': 0.036595141515135765, 'NDCG@10': 0.05504566555221876}
Epoch 10, metrics {'Recall@1': 0.016010485201453168, 'Recall@5': 0.06351082772016525, 'Recall@10': 0.11540570172170798, 'NDCG@5': 0.039478599869956575, 'NDCG@10': 0.05607946729287505}
Epoch 20, metrics {'Recall@1': 0.01390316616743803, 'Recall@5': 0.059990062999228634, 'Recall@10': 0.1187722726414601, 'NDCG@5': 0.036320349046339594, 'NDCG@10': 0.055121896167596184}
Epoch 30, metrics {'Recall@1': 0.01286663922170798, 'Recall@5': 0.06471011508256197, 'Recall@10': 0.12051123877366383, 'NDCG@5': 0.03837353542136649, 'NDCG@10': 0.05636422669825455}
```

joshuadeng commented 1 year ago

Hi @nippleshot, I'm unable to reproduce your error on a setup with V100 GPUs. Can you try running on random data to see if you get this error as well?

nippleshot commented 1 year ago

Hello @joshuadeng,

These are the results for the random dataset with mode=dmp:

torchx run -s local_cwd dist.ddp -j 1x8 --script bert4rec_main.py -- --dataset_name random --mode dmp
```
...
bert4rec_main/0 [0]:Traceback (most recent call last):
bert4rec_main/0 [0]:  File "/torchrec/examples/bert4rec/bert4rec_main.py", line 565, in <module>
bert4rec_main/0 [0]:    main(sys.argv[1:])
bert4rec_main/0 [0]:  File "/torchrec/examples/bert4rec/bert4rec_main.py", line 548, in main
bert4rec_main/0 [0]:    train_val_test(
bert4rec_main/0 [0]:  File "/torchrec/examples/bert4rec/bert4rec_main.py", line 391, in train_val_test
bert4rec_main/0 [0]:    _validate(model, val_loader, device, -1, metric_ks)
bert4rec_main/0 [0]:  File "/torchrec/examples/bert4rec/bert4rec_main.py", line 344, in _validate
bert4rec_main/0 [0]:    metrics = _calculate_metrics(model, batch, metric_ks, device)
bert4rec_main/0 [0]:  File "/torchrec/examples/bert4rec/bert4rec_main.py", line 242, in _calculate_metrics
bert4rec_main/0 [0]:    scores = model(kjt)
bert4rec_main/0 [0]:  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
bert4rec_main/0 [0]:    return forward_call(*args, **kwargs)
bert4rec_main/0 [0]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/model_parallel.py", line 250, in forward
bert4rec_main/0 [0]:    return self._dmp_wrapped_module(*args, **kwargs)
bert4rec_main/0 [0]:  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
bert4rec_main/0 [0]:    return forward_call(*args, **kwargs)
bert4rec_main/0 [0]:  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1152, in forward
bert4rec_main/0 [0]:    output = self._run_ddp_forward(*inputs, **kwargs)
bert4rec_main/0 [0]:  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1105, in _run_ddp_forward
bert4rec_main/0 [0]:    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
bert4rec_main/0 [0]:  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
bert4rec_main/0 [0]:    return forward_call(*args, **kwargs)
bert4rec_main/0 [0]:  File "/torchrec/examples/bert4rec/models/bert4rec.py", line 493, in forward
bert4rec_main/0 [0]:    x = self.history(input)
bert4rec_main/0 [0]:  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
bert4rec_main/0 [0]:    return forward_call(*args, **kwargs)
bert4rec_main/0 [0]:  File "/torchrec/examples/bert4rec/models/bert4rec.py", line 392, in forward
bert4rec_main/0 [0]:    jt_dict = self.ec(id_list_features)
bert4rec_main/0 [0]:  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
bert4rec_main/0 [0]:    return forward_call(*args, **kwargs)
bert4rec_main/0 [0]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/types.py", line 594, in forward
bert4rec_main/0 [0]:    dist_input = self.input_dist(ctx, *input, **kwargs).wait().wait()
bert4rec_main/0 [0]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/types.py", line 222, in wait
bert4rec_main/0 [0]:    ret: W = self._wait_impl()
bert4rec_main/0 [0]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/embedding_sharding.py", line 269, in _wait_impl
bert4rec_main/0 [0]:    tensors_awaitables = [w.wait() for w in self.awaitables]
bert4rec_main/0 [0]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/embedding_sharding.py", line 269, in <listcomp>
bert4rec_main/0 [0]:    tensors_awaitables = [w.wait() for w in self.awaitables]
bert4rec_main/0 [0]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/types.py", line 222, in wait
bert4rec_main/0 [0]:    ret: W = self._wait_impl()
bert4rec_main/0 [0]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/dist_data.py", line 349, in _wait_impl
bert4rec_main/0 [0]:    return KJTAllToAllTensorsAwaitable(
bert4rec_main/0 [0]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/dist_data.py", line 243, in __init__
bert4rec_main/0 [0]:    awaitable = dist.all_to_all_single(
bert4rec_main/0 [0]:  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1422, in wrapper
bert4rec_main/0 [0]:    return func(*args, **kwargs)
bert4rec_main/0 [0]:  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3128, in all_to_all_single
bert4rec_main/0 [0]:    work = group.alltoall_base(
bert4rec_main/0 [0]:RuntimeError: Split sizes doesn't match total dim 0 size
...
torchx 2023-03-07 03:03:02 INFO Job finished: FAILED
```
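
For context, my understanding of this error is that `dist.all_to_all_single` splits its input along dim 0 according to `input_split_sizes`, so the splits must sum to exactly `input.size(0)`; on this 8-GPU run the splits computed in the input_dist path apparently disagree with the actual tensor size. A toy illustration of that invariant (made-up sizes, no process group involved):

```python
import torch

values = torch.arange(10)       # pretend this is the tensor being redistributed: dim 0 size == 10
input_split_sizes = [3, 3, 3]   # rows that each destination rank is supposed to receive

# This is roughly the condition being checked: 3 + 3 + 3 == 9 != 10, so
# all_to_all_single rejects it with "Split sizes doesn't match total dim 0 size".
assert sum(input_split_sizes) == values.size(0)
```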
torchx run -s local_cwd dist.ddp -j 1x2 --script bert4rec_main.py -- --dataset_name random --mode dmp
```
Epoch 1, average loss 3.7423146963119507
Epoch 50, average loss 2.0246068835258484
Epoch 100, average loss 1.7793037295341492
Epoch 1, metrics {'Recall@1': 0.5000000149011612, 'Recall@5': 0.7000000178813934, 'Recall@10': 0.9000000059604645, 'NDCG@5': 0.6130929589271545, 'NDCG@10': 0.6776201725006104}
Epoch 50, metrics {'Recall@1': 0.4000000059604645, 'Recall@5': 0.7000000178813934, 'Recall@10': 0.9000000059604645, 'NDCG@5': 0.5492282956838608, 'NDCG@10': 0.6181823760271072}
Epoch 100, metrics {'Recall@1': 0.30000001192092896, 'Recall@5': 0.5000000074505806, 'Recall@10': 0.9000000059604645, 'NDCG@5': 0.4130929708480835, 'NDCG@10': 0.5425000041723251}
Test, metrics {'Recall@1': 0.0, 'Recall@5': 0.30000000447034836, 'Recall@10': 0.800000011920929, 'NDCG@5': 0.13613532111048698, 'NDCG@10': 0.2952927201986313}
torchx 2023-03-07 02:48:40 INFO Job finished: SUCCEEDED
```
torchx run -s local_cwd dist.ddp -j 1x1 --script bert4rec_main.py -- --dataset_name random --mode dmp
```
Epoch 1, average loss 3.7459120750427246
Epoch 50, average loss 1.6112332344055176
Epoch 100, average loss 1.4439857006072998
Epoch 1, metrics {'Recall@1': 0.4000000059604645, 'Recall@5': 0.6000000238418579, 'Recall@10': 1.0, 'NDCG@5': 0.4930676519870758, 'NDCG@10': 0.6198346614837646}
Epoch 50, metrics {'Recall@1': 0.30000001192092896, 'Recall@5': 0.5, 'Recall@10': 0.9000000357627869, 'NDCG@5': 0.4000000059604645, 'NDCG@10': 0.5314474105834961}
Epoch 100, metrics {'Recall@1': 0.20000000298023224, 'Recall@5': 0.5, 'Recall@10': 1.0, 'NDCG@5': 0.35616064071655273, 'NDCG@10': 0.5185484290122986}
Test, metrics {'Recall@1': 0.10000000149011612, 'Recall@5': 0.6000000238418579, 'Recall@10': 0.9000000357627869, 'NDCG@5': 0.34922829270362854, 'NDCG@10': 0.44530197978019714}
torchx 2023-03-07 02:44:57 INFO Job finished: SUCCEEDED
```
torchx run -s local_cwd dist.ddp -j 1x8 --script bert4rec_main.py -- --dataset_name random --random_user_count 3000 --random_item_count 500000 --random_size 8000000 --lr 0.001 --mask_prob 0.2 --train_batch_size 8 --val_batch_size 8 --max_len 16 --emb_dim 128 --num_epochs 2 --mode dmp
```
...
bert4rec_main/0 [5]:Traceback (most recent call last):
bert4rec_main/0 [5]:  File "/torchrec/examples/bert4rec/bert4rec_main.py", line 565, in <module>
bert4rec_main/0 [5]:    main(sys.argv[1:])
bert4rec_main/0 [5]:  File "/torchrec/examples/bert4rec/bert4rec_main.py", line 495, in main
bert4rec_main/0 [5]:    model = DMP(
bert4rec_main/0 [5]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/model_parallel.py", line 220, in __init__
bert4rec_main/0 [5]:    self._dmp_wrapped_module: nn.Module = self._init_dmp(module)
bert4rec_main/0 [5]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/model_parallel.py", line 278, in _init_dmp
bert4rec_main/0 [5]:    return self._shard_modules_impl(module)
bert4rec_main/0 [5]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/model_parallel.py", line 325, in _shard_modules_impl
bert4rec_main/0 [5]:    child = self._shard_modules_impl(
bert4rec_main/0 [5]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/model_parallel.py", line 325, in _shard_modules_impl
bert4rec_main/0 [5]:    child = self._shard_modules_impl(
bert4rec_main/0 [5]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/model_parallel.py", line 316, in _shard_modules_impl
bert4rec_main/0 [5]:    module = self._sharder_map[sharder_key].shard(
bert4rec_main/0 [5]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/embedding.py", line 713, in shard
bert4rec_main/0 [5]:    return ShardedEmbeddingCollection(
bert4rec_main/0 [5]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/embedding.py", line 337, in __init__
bert4rec_main/0 [5]:    self._create_lookups()
bert4rec_main/0 [5]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/embedding.py", line 577, in _create_lookups
bert4rec_main/0 [5]:    self._lookups.append(sharding.create_lookup())
bert4rec_main/0 [5]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/sharding/rw_sequence_sharding.py", line 122, in create_lookup
bert4rec_main/0 [5]:    return GroupedEmbeddingsLookup(
bert4rec_main/0 [5]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/embedding_lookup.py", line 120, in __init__
bert4rec_main/0 [5]:    self._emb_modules.append(_create_lookup(config))
bert4rec_main/0 [5]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/embedding_lookup.py", line 107, in _create_lookup
bert4rec_main/0 [5]:    return BatchedFusedEmbedding(
bert4rec_main/0 [5]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/batched_embedding_kernel.py", line 490, in __init__
bert4rec_main/0 [5]:    self._optim: EmbeddingFusedOptimizer = EmbeddingFusedOptimizer(
bert4rec_main/0 [5]:  File "/opt/conda/lib/python3.10/site-packages/torchrec/distributed/batched_embedding_kernel.py", line 197, in __init__
bert4rec_main/0 [5]:    weight = ShardedTensor._init_from_local_shards_and_global_metadata(
bert4rec_main/0 [5]:  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/_shard/sharded_tensor/api.py", line 922, in _init_from_local_shards_and_global_metadata
bert4rec_main/0 [5]:    sharded_tensor = super(
bert4rec_main/0 [5]:  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/_shard/sharded_tensor/api.py", line 232, in _init_from_local_shards_and_global_metadata
bert4rec_main/0 [5]:    check_tensor(shards_metadata, list(sharded_tensor_metadata.size))
bert4rec_main/0 [5]:  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/_shard/sharding_spec/_internals.py", line 151, in check_tensor
bert4rec_main/0 [5]:    raise ValueError(
bert4rec_main/0 [5]:ValueError: Total volume of shards: 64000128 does not match tensor volume: 64000256, in other words all the individual shards do not cover the entire tensor
...
```
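
A quick sanity check on the numbers in that ValueError, using only values reported above (`--emb_dim 128` from the command; the row counts below are just the volumes divided by emb_dim, and my guess is the two rows on top of `--random_item_count 500000` are special/padding ids):

```python
emb_dim = 128                       # --emb_dim 128 from the command above
tensor_volume = 64000256            # "tensor volume" reported in the ValueError
shards_volume = 64000128            # "total volume of shards" reported in the ValueError

print(tensor_volume // emb_dim)     # 500002 rows in the full embedding table
print(shards_volume // emb_dim)     # 500001 rows covered by the 8 row-wise shards
print((tensor_volume - shards_volume) // emb_dim)   # 1 -> exactly one row is left uncovered
```

So with 500002 rows sharded row-wise across 8 ranks (500002 is not divisible by 8), the shard sizes come out exactly one row short of the full table.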

When I run "random dataset" with "mode=ddp" I can use multi-GPU :

torchx run -s local_cwd dist.ddp -j 1x8 --script bert4rec_main.py -- --dataset_name random --mode ddp
```
Epoch 1, average loss 3.7946859896183014
Epoch 50, average loss 2.9830795228481293
Epoch 100, average loss 2.832148551940918
Epoch 1, metrics {'Recall@1': 0.1875, 'Recall@5': 0.375, 'Recall@10': 0.8125, 'NDCG@5': 0.28510039672255516, 'NDCG@10': 0.42677366733551025}
Epoch 50, metrics {'Recall@1': 0.0, 'Recall@5': 0.5625, 'Recall@10': 0.875, 'NDCG@5': 0.2791401036083698, 'NDCG@10': 0.38196960277855396}
Epoch 100, metrics {'Recall@1': 0.125, 'Recall@5': 0.5, 'Recall@10': 0.875, 'NDCG@5': 0.30653970688581467, 'NDCG@10': 0.4270668085664511}
Test, metrics {'Recall@1': 0.125, 'Recall@5': 0.375, 'Recall@10': 0.8125, 'NDCG@5': 0.2467787005007267, 'NDCG@10': 0.3932377118617296}
torchx 2023-03-07 03:07:43 INFO Job finished: SUCCEEDED
```
torchx run -s local_cwd dist.ddp -j 1x1 --script bert4rec_main.py -- --dataset_name random --mode ddp
```
Epoch 1, average loss 3.6442973613739014
Epoch 50, average loss 1.3965630531311035
Epoch 100, average loss 1.3751325607299805
Epoch 1, metrics {'Recall@1': 0.10000000149011612, 'Recall@5': 0.4000000059604645, 'Recall@10': 0.9000000357627869, 'NDCG@5': 0.23175294697284698, 'NDCG@10': 0.3853926360607147}
Epoch 50, metrics {'Recall@1': 0.30000001192092896, 'Recall@5': 0.4000000059604645, 'Recall@10': 0.800000011920929, 'NDCG@5': 0.3386852741241455, 'NDCG@10': 0.4610254466533661}
Epoch 100, metrics {'Recall@1': 0.30000001192092896, 'Recall@5': 0.4000000059604645, 'Recall@10': 0.9000000357627869, 'NDCG@5': 0.3386852741241455, 'NDCG@10': 0.4937684237957001}
Test, metrics {'Recall@1': 0.20000000298023224, 'Recall@5': 0.5, 'Recall@10': 0.800000011920929, 'NDCG@5': 0.3317529261112213, 'NDCG@10': 0.43454083800315857}
torchx 2023-03-07 03:13:15 INFO Job finished: SUCCEEDED
```

(It seems the average loss does get smaller when I run with the random dataset.)

So I tried the ml-1m and ml-20m datasets with mode=ddp, but the model metrics still barely change during training:

torchx run -s local_cwd dist.ddp -j 1x8 --script bert4rec_main.py -- --dataset_name ml-1m --dataset_path /datasets/ml-1m --lr 0.001 --mask_prob 0.2 --weight_decay 0.00001 --train_batch_size 256 --val_batch_size 256 --test_batch_size 256 --max_len 100 --emb_dim 256 --num_epochs 30 --mode ddp
```
Epoch 1, average loss 7.683864032174205
Epoch 10, average loss 7.553736465317862
Epoch 20, average loss 7.54727636195801
Epoch 30, average loss 7.5460506566278225
Epoch 1, metrics {'Recall@1': 0.010846675645249586, 'Recall@5': 0.057607140081624195, 'Recall@10': 0.11378158659984668, 'NDCG@5': 0.03428887668997049, 'NDCG@10': 0.0523198526352644}
Epoch 10, metrics {'Recall@1': 0.013067049245970946, 'Recall@5': 0.06541963992640376, 'Recall@10': 0.12494507618248464, 'NDCG@5': 0.03886623719396691, 'NDCG@10': 0.05785229227816065}
Epoch 20, metrics {'Recall@1': 0.012046949918537091, 'Recall@5': 0.06313228692548971, 'Recall@10': 0.12327393470332026, 'NDCG@5': 0.03669133122699956, 'NDCG@10': 0.05595556739717722}
Epoch 30, metrics {'Recall@1': 0.01370067618942509, 'Recall@5': 0.06479472030575076, 'Recall@10': 0.12531413277611136, 'NDCG@5': 0.03884579039489229, 'NDCG@10': 0.058175960633282855}
Test, metrics {'Recall@1': 0.01587751298211515, 'Recall@5': 0.08367157944788535, 'Recall@10': 0.13462027783195177, 'NDCG@5': 0.04960138816386461, 'NDCG@10': 0.06602672704805931}
torchx 2023-03-07 03:33:42 INFO Job finished: SUCCEEDED
```
torchx run -s local_cwd dist.ddp -j 1x8 --script bert4rec_main.py -- --dataset_name ml-20m --dataset_path /datasets/ml-20m --lr 0.001 --mask_prob 0.2 --weight_decay 0.00001 --train_batch_size 64 --val_batch_size 64 --max_len 200 --emb_dim 64 --num_epochs 10 --mode ddp
```
Epoch 1, average loss 8.366310296969377
Epoch 5, average loss 8.330542541664412
Epoch 10, average loss 8.32964715178926
Epoch 1, metrics {'Recall@1': 0.021361854243542436, 'Recall@5': 0.07934300046125461, 'Recall@10': 0.13430437038745388, 'NDCG@5': 0.050259967119780084, 'NDCG@10': 0.06783557287718056}
Epoch 5, metrics {'Recall@1': 0.022421298431734314, 'Recall@5': 0.08239160516605167, 'Recall@10': 0.1408195917896679, 'NDCG@5': 0.05224268066798203, 'NDCG@10': 0.07095942989561355}
Epoch 10, metrics {'Recall@1': 0.02205373616236162, 'Recall@5': 0.08266547509225092, 'Recall@10': 0.1418213791512915, 'NDCG@5': 0.05196374788754912, 'NDCG@10': 0.0709183489503442}
Test, metrics {'Recall@1': 0.020960190716911763, 'Recall@5': 0.08042997472426472, 'Recall@10': 0.13665412454044118, 'NDCG@5': 0.05045869417688877, 'NDCG@10': 0.06848290052114274}
torchx 2023-03-07 06:18:57 INFO Job finished: SUCCEEDED
```

I used the vanilla code, and I still don't understand why the models trained in my environment (on the ml-1m and ml-20m datasets) perform so differently from the reported results.