awslabs / sockeye

Sequence-to-sequence framework with a focus on Neural Machine Translation based on PyTorch
https://awslabs.github.io/sockeye/
Apache License 2.0

Factors inference is slow (3 seconds/token) on A100 GPU #1110

Open AmitMY opened 6 months ago

AmitMY commented 6 months ago

My use case calls for splitting each input token into 5 parts and each output token into 8. That is, the input has a token plus 4 factors (SignWriting), and the output has a token plus 7 factors (VQ model).
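With a token plus 4 factors per input position, each factored line splits into 5 token-aligned streams: the tokens go to the file passed to --input, and each factor goes to its own file passed to --input-factors. Below is a minimal sketch of that split. It assumes that `|` separates a token from its factors, as in the example sentence used below. The helper name `split_factored_line` is hypothetical and is not part of Sockeye.

```python
from typing import List


def split_factored_line(line: str, num_streams: int) -> List[str]:
    """Split 'tok|f1|f2 ...' into num_streams space-joined, token-aligned lines.

    Hypothetical helper: stream 0 holds the tokens, streams 1..n-1 hold the factors.
    """
    streams: List[List[str]] = [[] for _ in range(num_streams)]
    for token in line.split():
        parts = token.split("|")
        if len(parts) != num_streams:
            raise ValueError(f"expected {num_streams} fields in {token!r}, got {len(parts)}")
        for stream, part in zip(streams, parts):
            stream.append(part)
    return [" ".join(stream) for stream in streams]


if __name__ == "__main__":
    example = "M|c0|r0|p518|p518 S2ff|c0|r0|p482|p483"
    for i, stream_line in enumerate(split_factored_line(example, num_streams=5)):
        # Each stream would be written to its own file, e.g. source_{i}.txt.
        print(f"source_{i}.txt: {stream_line}")
```

Every stream must contain the same number of tokens per line, which is why the helper rejects tokens that have the wrong number of fields.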

I created factored files for the example sentence M|c0|r0|p518|p518 S2ff|c0|r0|p482|p483 and attempted to translate it with:

python -m sockeye.translate --models "$MODEL_DIR/unconstrained/model" \
  --input "$MODEL_DIR/unconstrained/sample/source_0.txt" \
  --input-factors "$MODEL_DIR/unconstrained/sample/source_1.txt" "$MODEL_DIR/unconstrained/sample/source_2.txt" "$MODEL_DIR/unconstrained/sample/source_3.txt" "$MODEL_DIR/unconstrained/sample/source_4.txt" \
  --output-type "translation_with_factors" \
  --max-output-length 10 \
  --beam-size=1

And the output is:

[INFO:main] Processed 1 lines. Total time: 29.1466, sec/sent: 29.1466, sent/sec: 0.0343

Why would translating a single sentence on an A100 GPU, with a small model and no beam search, be this slow? Is there a way to profile the decoding step function?


The full output is:

[INFO:sockeye.utils] Sockeye: 3.1.38, commit , path /data/amoryo/conda/envs/sockeye/lib/python3.11/site-packages/sockeye/__init__.py
[INFO:sockeye.utils] PyTorch: 1.13.1+cu117 (/data/amoryo/conda/envs/sockeye/lib/python3.11/site-packages/torch/__init__.py)
[INFO:sockeye.utils] Command: /data/amoryo/conda/envs/sockeye/lib/python3.11/site-packages/sockeye/translate.py --beam-size=1 --models /shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model --input /shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_0.txt --input-factors /shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_1.txt /shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_2.txt /shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_3.txt /shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_4.txt --output-type translation_with_factors --max-output-length 10
[INFO:sockeye.utils] Arguments: Namespace(config=None, input='/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_0.txt', input_factors=['/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_1.txt', '/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_2.txt', '/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_3.txt', '/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_4.txt'], json_input=False, output=None, models=['/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model'], checkpoints=None, nbest_size=1, beam_size=1, greedy=False, beam_search_stop='all', batch_size=1, chunk_size=None, sample=None, seed=None, ensemble_mode='linear', bucket_width=10, max_input_length=None, max_output_length_num_stds=2, max_output_length=10, restrict_lexicon=None, restrict_lexicon_topk=None, skip_nvs=False, nvs_thresh=0.5, strip_unknown_words=False, prevent_unk=False, output_type='translation_with_factors', length_penalty_alpha=1.0, length_penalty_beta=0.0, brevity_penalty_type='none', brevity_penalty_weight=1.0, brevity_penalty_constant_length_ratio=0.0, dtype=None, clamp_to_dtype=False, device_id=0, use_cpu=False, env=None, tf32=True, quiet=False, quiet_secondary_workers=False, no_logfile=False, loglevel='INFO', loglevel_secondary_workers='INFO', knn_index=None, knn_lambda=0.8)
[INFO:sockeye.utils] CUDA: allow tf32 (float32 but with 10 bits precision)
[INFO:__main__] Translate Device: cuda:0
[INFO:sockeye.utils] CUDA: allow tf32 (float32 but with 10 bits precision)
[INFO:sockeye.model] Loading 1 model(s) from ['/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model'] ...
[INFO:sockeye.vocab] Vocabulary (664 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.src.0.json"
[INFO:sockeye.vocab] Vocabulary (16 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.src.1.json"
[INFO:sockeye.vocab] Vocabulary (24 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.src.2.json"
[INFO:sockeye.vocab] Vocabulary (504 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.src.3.json"
[INFO:sockeye.vocab] Vocabulary (504 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.src.4.json"
[INFO:sockeye.vocab] Vocabulary (1008 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.trg.0.json"
[INFO:sockeye.vocab] Vocabulary (1008 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.trg.1.json"
[INFO:sockeye.vocab] Vocabulary (1008 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.trg.2.json"
[INFO:sockeye.vocab] Vocabulary (1008 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.trg.3.json"
[INFO:sockeye.vocab] Vocabulary (1008 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.trg.4.json"
[INFO:sockeye.vocab] Vocabulary (1008 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.trg.5.json"
[INFO:sockeye.vocab] Vocabulary (1008 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.trg.6.json"
[INFO:sockeye.vocab] Vocabulary (1008 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.trg.7.json"
[INFO:sockeye.model] Model version: 3.1.38
[INFO:sockeye.model] Loaded model config from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/config"
[INFO:sockeye.model] ModelConfig(config_data=DataConfig(data_statistics=DataStatistics(num_sents=562592, num_discarded=315, num_tokens_source=1687776, num_tokens_target=57851323, num_unks_source=0, num_unks_target=0, max_observed_len_source=3, max_observed_len_target=513, size_vocab_source=664, size_vocab_target=1008, length_ratio_mean=34.276659343419496, length_ratio_std=14.879263574710128, buckets=[(8, 8), (16, 16), (24, 24), (32, 32), (40, 40), (48, 48), (56, 56), (64, 64), (72, 72), (80, 80), (88, 88), (96, 96), (104, 104), (112, 112), (120, 120), (128, 128), (136, 136), (144, 144), (152, 152), (160, 160), (168, 168), (176, 176), (184, 184), (192, 192), (200, 200), (208, 208), (216, 216), (224, 224), (232, 232), (240, 240), (248, 248), (256, 256), (264, 264), (272, 272), (280, 280), (288, 288), (296, 296), (304, 304), (312, 312), (320, 320), (328, 328), (336, 336), (344, 344), (352, 352), (360, 360), (368, 368), (376, 376), (384, 384), (392, 392), (400, 400), (408, 408), (416, 416), (424, 424), (432, 432), (440, 440), (448, 448), (456, 456), (464, 464), (472, 472), (480, 480), (488, 488), (496, 496), (504, 504), (512, 512), (513, 513)], num_sents_per_bucket=[3, 1, 19, 498, 2678, 10120, 27611, 47985, 57944, 61020, 53583, 44696, 41032, 34378, 30526, 26818, 22831, 18533, 16196, 12340, 9853, 8577, 5816, 4717, 4510, 3553, 2376, 2354, 1616, 1359, 2422, 1121, 749, 856, 604, 410, 405, 468, 282, 224, 157, 190, 162, 102, 131, 142, 74, 61, 55, 69, 57, 41, 43, 28, 33, 28, 32, 20, 16, 14, 15, 18, 9, 10, 1], average_len_target_per_bucket=[2.0, 15.0, 22.47368421052631, 29.895582329317254, 37.32449589245704, 45.26413043478288, 52.99141646445262, 60.725705949775836, 68.54556123153431, 76.37490986561775, 84.53625216953161, 92.42469124753944, 100.45917820237935, 108.42053057187677, 116.43959247854299, 124.43127750018597, 132.41382331041123, 140.36896347056646, 148.29741911583167, 156.2551863857384, 164.24865523190755, 172.1120438381717, 180.24810866574907, 188.2524909900362, 
195.9474501108644, 204.0509428651839, 212.37668350168337, 219.66864910790179, 228.24566831683163, 236.7424576894776, 243.63831544178396, 252.09901873327397, 260.41655540720996, 268.0724299065416, 276.22516556291396, 284.2926829268292, 291.80493827160535, 300.95726495726507, 307.92907801418494, 315.73660714285717, 324.45859872611453, 332.38421052631577, 340.08024691358037, 348.57843137254906, 356.75572519083966, 364.0845070422536, 372.59459459459464, 380.78688524590166, 387.92727272727274, 396.3623188405797, 403.9824561403508, 412.4146341463415, 420.4418604651163, 428.28571428571433, 435.8181818181818, 443.99999999999994, 452.18750000000006, 460.04999999999995, 467.625, 476.8571428571429, 483.33333333333326, 493.2222222222223, 499.7777777777777, 508.70000000000005, 513.0], length_ratio_stats_per_bucket=[(0.6666666666666666, 0.0), (5.0, 0.0), (7.4912280701754375, 0.6611032870672552), (9.965194109772419, 0.6773854309889832), (12.441498630819025, 0.7323565209289744), (15.088043478260888, 0.7228695573061236), (17.663805488150842, 0.7477814938148047), (20.24190198325863, 0.7557906951613651), (22.84852041051123, 0.76260899086021), (25.458303288539526, 0.7429680382689798), (28.178750723177036, 0.7725985330583485), (30.808230415846378, 0.7612824824502872), (33.48639273412596, 0.7637406751055946), (36.14017685729224, 0.7633158033079295), (38.813197492847344, 0.748047045959258), (41.47709250006236, 0.7784719198398958), (44.137941103470446, 0.7532059978984165), (46.78965449018851, 0.7594518647223019), (49.43247303861048, 0.7731868633658496), (52.08506212857909, 0.7602201079711013), (54.74955174396979, 0.7721951407116094), (57.37068127939054, 0.7726670130819834), (60.08270288858333, 0.7589571502358675), (62.75083033001211, 0.7573209651045687), (65.31581670362168, 0.8010609138746856), (68.0169809550613, 0.7586793214732782), (70.79222783389453, 0.761428337943305), (73.22288303596699, 0.8127171243097352), (76.08188943894403, 0.7619360833015993), (78.91415256315932, 
0.788264227982759), (81.21277181392782, 0.7263146698773092), (84.03300624442463, 0.7446248339832199), (86.80551846906988, 0.7602374250419729), (89.35747663551402, 0.8121537001280452), (92.07505518763794, 0.7250101472791833), (94.76422764227645, 0.7737001598415769), (97.26831275720153, 0.8207845494821012), (100.31908831908838, 0.7219173089327346), (102.64302600472807, 0.7770804608693789), (105.2455357142857, 0.8349566989974272), (108.15286624203819, 0.7994074644452088), (110.79473684210528, 0.7417982858094194), (113.36008230452681, 0.8037791354444079), (116.19281045751636, 0.7503612412672215), (118.91857506361325, 0.6905291735552149), (121.36150234741785, 0.7817289110491428), (124.19819819819818, 0.7472048392700509), (126.92896174863385, 0.7419630195625642), (129.30909090909094, 0.8571848343387725), (132.12077294685983, 0.7634378859958186), (134.6608187134503, 0.7050119324061157), (137.4715447154471, 0.8030281772270442), (140.14728682170545, 0.8665543325570307), (142.76190476190473, 0.8060148038005153), (145.27272727272725, 0.7761362712039772), (148.0, 0.7182430061427759), (150.72916666666663, 0.7702700644723462), (153.35, 0.8463975950396387), (155.87499999999997, 0.7252872993970579), (158.95238095238093, 0.6884205854667195), (161.11111111111106, 0.7568616162633894), (164.4074074074074, 0.733295921230492), (166.59259259259258, 0.6241592424575069), (169.56666666666666, 0.8950481054731718), (171.0, 0.0)]), max_seq_len_source=513, max_seq_len_target=513, num_source_factors=5, num_target_factors=8, eop_id=-1), vocab_source_size=664, vocab_target_size=1008, config_embed_source=EmbeddingConfig(vocab_size=664, num_embed=512, dropout=0.5, num_factors=5, factor_configs=[FactorConfig(vocab_size=16, num_embed=512, combine='sum', share_embedding=False), FactorConfig(vocab_size=24, num_embed=512, combine='sum', share_embedding=False), FactorConfig(vocab_size=504, num_embed=512, combine='sum', share_embedding=False), FactorConfig(vocab_size=504, num_embed=512, combine='sum', 
share_embedding=False)], allow_sparse_grad=False), config_embed_target=EmbeddingConfig(vocab_size=1008, num_embed=512, dropout=0.5, num_factors=8, factor_configs=[FactorConfig(vocab_size=1008, num_embed=512, combine='sum', share_embedding=False), FactorConfig(vocab_size=1008, num_embed=512, combine='sum', share_embedding=False), FactorConfig(vocab_size=1008, num_embed=512, combine='sum', share_embedding=False), FactorConfig(vocab_size=1008, num_embed=512, combine='sum', share_embedding=False), FactorConfig(vocab_size=1008, num_embed=512, combine='sum', share_embedding=False), FactorConfig(vocab_size=1008, num_embed=512, combine='sum', share_embedding=False), FactorConfig(vocab_size=1008, num_embed=512, combine='sum', share_embedding=False)], allow_sparse_grad=False), config_encoder=TransformerConfig(model_size=512, attention_heads=8, feed_forward_num_hidden=2048, act_type='relu', num_layers=6, dropout_attention=0.2, dropout_act=0.2, dropout_prepost=0.2, positional_embedding_type='fixed', preprocess_sequence='n', postprocess_sequence='dr', max_seq_len_source=513, max_seq_len_target=513, decoder_type='transformer', block_prepended_cross_attention=False, use_lhuc=False, depth_key_value=512, use_glu=False), config_decoder=TransformerConfig(model_size=512, attention_heads=8, feed_forward_num_hidden=2048, act_type='relu', num_layers=6, dropout_attention=0.2, dropout_act=0.2, dropout_prepost=0.2, positional_embedding_type='fixed', preprocess_sequence='n', postprocess_sequence='dr', max_seq_len_source=513, max_seq_len_target=513, decoder_type='transformer', block_prepended_cross_attention=False, use_lhuc=False, depth_key_value=512, use_glu=False), config_length_task=None, weight_tying_type='trg_softmax', lhuc=False, dtype='float32', neural_vocab_selection=None, neural_vocab_selection_block_loss=False)
[INFO:sockeye.model] Loaded params from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/params.best" to "cuda:0"
[INFO:sockeye.model] Model dtype: torch.float32
[INFO:sockeye.model] 1 model(s) loaded in 1.3760s
[INFO:sockeye.beam_search] Enabled skipping softmax for a single model and greedy decoding.
[INFO:sockeye.inference] Translator (1 model(s) beam_size=1 algorithm=BeamSearch, beam_search_stop=all max_input_length=512 nbest_size=1 ensemble_mode=None max_batch_size=1 dtype=torch.float32 skip_nvs=False nvs_thresh=0.5)
[INFO:__main__] Translating...
/data/amoryo/conda/envs/sockeye/lib/python3.11/site-packages/torch/jit/_trace.py:976: TracerWarning: Encountering a list at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.
  module._c._create_method_from_trace(
/data/amoryo/conda/envs/sockeye/lib/python3.11/site-packages/torch/nn/modules/module.py:1194: UserWarning: FALLBACK path has been taken inside: runCudaFusionGroup. This is an indication that codegen Failed for some reason.
To debug try disable codegen fallback path via setting the env variable `export PYTORCH_NVFUSER_DISABLE=fallback`
 (Triggered internally at ../torch/csrc/jit/codegen/cuda/manager.cpp:331.)
  return forward_call(*input, **kwargs)
266|507|468|505|698|460|419|418 266|507|468|505|698|460|419|418 266|507|468|505|698|460|419|418 266|507|468|505|698|460|419|418 266|507|468|505|698|460|419|418 266|507|468|505|698|460|419|418 266|507|468|505|698|460|419|418 266|507|468|505|698|460|419|418 266|507|468|505|698|460|419|418 266|507|468|505|698|460|419|418
[INFO:__main__] Processed 1 lines. Total time: 29.1466, sec/sent: 29.1466, sent/sec: 0.0343

Besides the fact that the output repeats the same token over and over, it is in the expected format.

mjdenkowski commented 6 months ago

Hi Amit,

That's a good question. I don't know that anyone has tested Sockeye with that many factors.

One hypothesis is that the factor code repeatedly switches between Python, C++, and GPU execution, so looping over more factors leads to a greater slowdown. @fhieber may have more information about decoding with factors.

For profiling, you could take a look at the PyTorch Profiler.
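Here is a minimal sketch of the PyTorch Profiler applied to a decoding-style loop. The toy model is a hypothetical stand-in, not Sockeye's decoder: eight 512→1008 `nn.Linear` output layers that mirror the model size and target vocabulary size in the log. To profile Sockeye itself, the same `profile` and `record_function` context managers would have to wrap its decoding code.

```python
import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile, record_function


def build_toy_decoder(model_size: int = 512, vocab_size: int = 1008, num_factors: int = 8) -> nn.ModuleList:
    """Hypothetical stand-in: one output layer per target factor (not Sockeye's model)."""
    return nn.ModuleList(nn.Linear(model_size, vocab_size) for _ in range(num_factors))


def profile_decode_steps(num_steps: int = 10, trace_path: str = "trace.json"):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    layers = build_toy_decoder().to(device).eval()
    hidden = torch.randn(1, 512, device=device)  # batch size 1, as in the issue

    activities = [ProfilerActivity.CPU]
    if device == "cuda":
        activities.append(ProfilerActivity.CUDA)

    with torch.inference_mode(), profile(activities=activities, record_shapes=True) as prof:
        for _ in range(num_steps):
            # Label each step so it appears as its own row in the summary table.
            with record_function("decode_step"):
                _ = [layer(hidden) for layer in layers]

    # Sorting by self CPU time surfaces host-side (Python/launch) overhead.
    print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
    prof.export_chrome_trace(trace_path)  # can be opened in chrome://tracing
    return prof


if __name__ == "__main__":
    profile_decode_steps()
```

Comparing the "Self CPU" and "Self CUDA" columns shows whether the time goes to GPU kernels or to host-side work around them.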

Best, Michael

AmitMY commented 6 months ago

Thanks!

One possible improvement I see: instead of calling the factor output layers one after another at https://github.com/awslabs/sockeye/blob/main/sockeye/model.py#L665C13-L665C79, run them in parallel:

# start each factor output layer as a forked task, then collect the results in order
futures = [torch.jit.fork(fol, decoder_out) for fol in self.factor_output_layers]
outputs += [torch.jit.wait(fut) for fut in futures]
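Below is a self-contained sketch of the proposed change. It uses a hypothetical `FactorOutputs` module in place of Sockeye's model, with seven 512→1008 output layers to match the seven extra target factors in the log. The sketch only checks that the forked version returns the same outputs as the sequential one. According to the PyTorch documentation, `torch.jit.fork` runs asynchronously only inside TorchScript. In plain Python it does not execute in parallel. While tracing it does not run in parallel either, but the fork and wait calls are recorded in the traced graph. Whether this speeds up Sockeye's traced decoder is untested here.

```python
from typing import List

import torch
from torch import nn


class FactorOutputs(nn.Module):
    """Hypothetical stand-in for per-factor output layers (not Sockeye's class)."""

    def __init__(self, model_size: int = 512, vocab_size: int = 1008, num_factors: int = 7):
        super().__init__()
        self.factor_output_layers = nn.ModuleList(
            nn.Linear(model_size, vocab_size) for _ in range(num_factors)
        )

    def forward_sequential(self, decoder_out: torch.Tensor) -> List[torch.Tensor]:
        # Baseline: apply the output layers one after another.
        return [fol(decoder_out) for fol in self.factor_output_layers]

    def forward_forked(self, decoder_out: torch.Tensor) -> List[torch.Tensor]:
        # Proposed: fork one task per layer, then wait on each future in order.
        futures = [torch.jit.fork(fol, decoder_out) for fol in self.factor_output_layers]
        return [torch.jit.wait(fut) for fut in futures]


if __name__ == "__main__":
    module = FactorOutputs().eval()
    x = torch.randn(1, 512)
    with torch.no_grad():
        same = all(torch.equal(a, b) for a, b in zip(module.forward_sequential(x), module.forward_forked(x)))
    print("forked outputs match sequential outputs:", same)
```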

Also, as a side note, target factors do not seem to be embedded during decoding: https://github.com/awslabs/sockeye/blob/main/sockeye/model.py#L654. Am I missing something?

AmitMY commented 6 months ago

With the --use-cpu flag, I get:

[INFO:main] Processed 1 lines. Total time: 1.6748, sec/sent: 1.6748, sent/sec: 0.5971

Compared to an A100 GPU:

[INFO:main] Processed 1 lines. Total time: 29.1466, sec/sent: 29.1466, sent/sec: 0.0343

AmitMY commented 6 months ago

Since the CPU time seems to be huge, here are the profiler's CPU timings:

Self CPU time total: 18.575s
Self CUDA time total: 27.326ms

Profile output:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                forward        96.79%       17.978s        97.03%       18.023s     667.534ms       0.000us         0.00%      12.625ms     467.593us            27  
                                           aten::linear         0.06%      11.700ms         2.21%     410.557ms     896.413us       0.000us         0.00%      13.302ms      29.044us           458  
                                           aten::matmul         0.01%       1.582ms         2.09%     388.146ms       1.578ms       0.000us         0.00%       7.126ms      28.967us           246  
                                               aten::mm         0.05%       9.657ms         2.08%     385.787ms       1.568ms       6.661ms        24.38%       7.126ms      28.967us           246  
                                               cudaFree         2.01%     373.293ms         2.01%     373.293ms     186.647ms     112.000us         0.41%     112.000us      56.000us             2  
                                aten::repeat_interleave         0.03%       5.207ms         0.17%      32.192ms     185.011us     398.000us         1.46%       6.604ms      37.954us           174  
                                       cudaLaunchKernel         0.13%      24.834ms         0.13%      24.834ms       8.611us       1.284ms         4.70%       1.284ms       0.445us          2884  
                                          aten::dropout         0.12%      22.784ms         0.12%      22.784ms      69.463us       0.000us         0.00%       0.000us       0.000us           328  
                                       aten::layer_norm         0.08%      15.266ms         0.11%      20.798ms     127.595us       0.000us         0.00%       1.614ms       9.902us           163  
                                            aten::slice         0.08%      14.190ms         0.08%      14.244ms      11.840us       0.000us         0.00%       0.000us       0.000us          1203  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  

Here is a profile file, trace.json, which can be opened in chrome://tracing.

AmitMY commented 6 months ago

With torch 2.3.0, on GPU:

[INFO:main] Processed 1 lines. Total time: 2.8488, sec/sent: 2.8488, sent/sec: 0.3510

on CPU:

[INFO:main] Processed 1 lines. Total time: 1.6967, sec/sent: 1.6967, sent/sec: 0.5894
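The timings reported so far can be compared directly. The sketch below does that arithmetic on the numbers from this thread. It makes two assumptions that the thread does not state: the first --use-cpu run used torch 1.13.1, and the torch 1.13.1 GPU run decoded 10 steps. The second assumption follows from --max-output-length 10 and the 10 tokens in the printed output.

```python
# Wall-clock "Total time" for one sentence, as reported in this thread.
# Assumption: the first --use-cpu run used torch 1.13.1 (not stated explicitly).
TIMINGS_SEC = {
    "gpu_torch113": 29.1466,
    "cpu_torch113": 1.6748,
    "gpu_torch23": 2.8488,
    "cpu_torch23": 1.6967,
}
NUM_OUTPUT_STEPS = 10  # --max-output-length 10; the printed translation has 10 tokens

# Profiler totals from the torch 1.13.1 GPU run.
SELF_CPU_SEC = 18.575
SELF_CUDA_SEC = 27.326e-3


def ratio(slow: str, fast: str) -> float:
    """How many times longer run `slow` took than run `fast`."""
    return TIMINGS_SEC[slow] / TIMINGS_SEC[fast]


sec_per_step = TIMINGS_SEC["gpu_torch113"] / NUM_OUTPUT_STEPS
cuda_vs_cpu_pct = 100 * SELF_CUDA_SEC / SELF_CPU_SEC

if __name__ == "__main__":
    print(f"torch 1.13 GPU: {sec_per_step:.2f} s per output step")
    print(f"torch 1.13: GPU is {ratio('gpu_torch113', 'cpu_torch113'):.1f}x slower than CPU")
    print(f"GPU: torch 2.3 is {ratio('gpu_torch113', 'gpu_torch23'):.1f}x faster than torch 1.13")
    print(f"torch 2.3: GPU is still {ratio('gpu_torch23', 'cpu_torch23'):.2f}x slower than CPU")
    print(f"Self CUDA time is {cuda_vs_cpu_pct:.2f}% of self CPU time")
```

With these inputs, the script reports about 2.91 s per step and a 17.4x GPU slowdown relative to CPU under torch 1.13. Upgrading to torch 2.3 gives a 10.2x GPU speedup, although the GPU is still 1.68x slower than CPU. Self CUDA time is only about 0.15% of self CPU time, which suggests that the time goes to host-side work rather than GPU kernels.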

Why is Sockeye restricted to torch 1?

mjdenkowski commented 6 months ago

The torch version in Sockeye's requirements.txt (currently torch>=1.10.0,<1.14.0) indicates the latest version of PyTorch that Sockeye is officially tested with.

If you change the line to just torch, you can test Sockeye with the current version of PyTorch.

Best, Michael