k2-fsa / snowfall

Moved to https://github.com/k2-fsa/icefall
Apache License 2.0

Convergence & decoding problem #170

Open zhichaowang opened 3 years ago

zhichaowang commented 3 years ago

Hi guys, I have installed the latest lhotse, k2 and snowfall and run the 100h LibriSpeech example using mmi_att_transformer_train.py. The model does not seem to converge well. From epoch 2 on, the training objf stays at about 1.82, and the valid objf barely improves after the second epoch: it fluctuates between about 1.74 and 1.89. There is also an error during decoding. The training log and the decoding error are below:

==============================================================
Training log:
model_path: exp-conformer-noam-mmi-att-musan-sa/epoch-1.pt
epoch: 1
learning rate: 0.0007081674413583273
objf: 1.8649312054523002
best objf: 2.1748975509666995
valid objf: 2.6318537894212493
best valid objf: 2.104838694051142
best epoch: 0
model_path: exp-conformer-noam-mmi-att-musan-sa/epoch-2.pt
epoch: 2
learning rate: 0.0006982041823453956
objf: 1.82354900688806
best objf: 1.82354900688806
valid objf: 1.7432625692363357
best valid objf: 1.7432625692363357
best epoch: 2
model_path: exp-conformer-noam-mmi-att-musan-sa/epoch-3.pt
epoch: 3
learning rate: 0.0005700931854704197
objf: 1.8225307280371177
best objf: 1.82354900688806
valid objf: 1.7847899278731816
best valid objf: 1.7432625692363357
best epoch: 2
model_path: exp-conformer-noam-mmi-att-musan-sa/epoch-4.pt
epoch: 4
learning rate: 0.000493735721369404
objf: 1.8221216481449267
best objf: 1.8221216481449267
valid objf: 1.7376060413361014
best valid objf: 1.7376060413361014
best epoch: 4
model_path: exp-conformer-noam-mmi-att-musan-sa/epoch-5.pt
epoch: 5
learning rate: 0.0004416106543607532
objf: 1.8229302839085555
best objf: 1.8221216481449267
valid objf: 1.892576669785212
best valid objf: 1.7376060413361014
best epoch: 4
model_path: exp-conformer-noam-mmi-att-musan-sa/epoch-6.pt
epoch: 6
learning rate: 0.0004031251426055444
objf: 1.8223825581813689
best objf: 1.8221216481449267
valid objf: 1.807029839322421
best valid objf: 1.7376060413361014
best epoch: 4
model_path: exp-conformer-noam-mmi-att-musan-sa/epoch-7.pt
epoch: 7
learning rate: 0.00037322912346643487
objf: 1.8228606284666071
best objf: 1.8221216481449267
valid objf: 1.8653112919027355
best valid objf: 1.7376060413361014
best epoch: 4
===============================================
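To make the plateau easier to see, the sketch below parses the "key: value" log above and prints the valid objf for each epoch along with its change from the previous epoch. Lower is better here, which matches how the log's "best valid objf" and "best epoch" fields move. parse_training_log is a hypothetical helper, not part of snowfall. LOG is an excerpt of the log above. Because the regex anchors at the start of each line, the "objf", "best objf", "best valid objf" and "best epoch" lines of the full log are skipped too.

```python
import re

# Excerpt of the training log above: only the lines this sketch needs.
LOG = """
epoch: 1
valid objf: 2.6318537894212493
epoch: 2
valid objf: 1.7432625692363357
epoch: 3
valid objf: 1.7847899278731816
epoch: 4
valid objf: 1.7376060413361014
epoch: 5
valid objf: 1.892576669785212
epoch: 6
valid objf: 1.807029839322421
epoch: 7
valid objf: 1.8653112919027355
"""


def parse_training_log(text):
    """Return {epoch: valid_objf} from 'key: value' log lines (hypothetical helper)."""
    result = {}
    epoch = None
    for line in text.splitlines():
        # re.match anchors at the line start, so "best epoch:" etc. do not match.
        m = re.match(r"\s*(epoch|valid objf):\s*(\S+)", line)
        if not m:
            continue
        key, value = m.groups()
        if key == "epoch":
            epoch = int(value)
        elif epoch is not None:
            result[epoch] = float(value)
    return result


valid = parse_training_log(LOG)
best_epoch = min(valid, key=valid.get)  # lower objf is better in this log
for epoch, objf in sorted(valid.items()):
    prev = valid.get(epoch - 1)
    # A positive or near-zero delta means the epoch did not help.
    delta = "" if prev is None else f"{objf - prev:+.3f}"
    print(f"epoch {epoch}: valid objf {objf:.3f} {delta}")
print("lowest valid objf at epoch", best_epoch)
```

With these values the lowest valid objf is at epoch 4, matching the log's "best epoch: 4". It is only about 0.006 below epoch 2.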

Decoding error:
CUDA_VISIBLE_DEVICES=0 python3 ./mmi_att_transformer_decode.py --epoch=9 --avg=1 --use-lm-rescoring=1 --max-duration=30 --num-path=10 
2021-04-21 17:25:16,133 DEBUG [mmi_att_transformer_decode.py:256] About to load model
2021-04-21 17:25:16,337 INFO [common.py:62] load checkpoint from exp-conformer-noam-mmi-att-musan-sa/epoch-8.pt
2021-04-21 17:25:19,530 DEBUG [mmi_att_transformer_decode.py:319] Loading pre-compiled HLG
2021-04-21 17:25:24,036 INFO [mmi_att_transformer_decode.py:327] Rescoring with n-best list, n is 10
2021-04-21 17:25:24,083 DEBUG [mmi_att_transformer_decode.py:344] Loading pre-compiled G_4_gram.pt
2021-04-21 17:25:24,086 DEBUG [mmi_att_transformer_decode.py:358] convert HLG to device
2021-04-21 17:25:25,518 DEBUG [librispeech.py:55] About to get test cuts
2021-04-21 17:25:25,650 DEBUG [librispeech.py:55] About to get test cuts
2021-04-21 17:25:25,731 DEBUG [asr_datamodule.py:198] About to create test dataset
2021-04-21 17:25:25,748 DEBUG [asr_datamodule.py:204] About to create test dataloader
2021-04-21 17:25:25,748 DEBUG [asr_datamodule.py:198] About to create test dataset
2021-04-21 17:25:25,766 DEBUG [asr_datamodule.py:204] About to create test dataloader
2021-04-21 17:25:25,766 INFO [mmi_att_transformer_decode.py:371] * DECODING: test-clean
/anaconda3/lib/python3.8/site-packages/lhotse/dataset/sampling.py:303: UserWarning: The first cut drawn in batch collection violates the max_frames or max_cuts constraints - we'll return it anyway. Consider increasing max_frames/max_cuts.
  warnings.warn("The first cut drawn in batch collection violates the max_frames or max_cuts "
Traceback (most recent call last):
  File "./mmi_att_transformer_decode.py", line 398, in <module>
    main()
  File "./mmi_att_transformer_decode.py", line 373, in main
    results = decode(dataloader=test_dl,
  File "./mmi_att_transformer_decode.py", line 79, in decode
    best_paths = decode_with_lm_rescoring(
  File "/anaconda3/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/snowfall/snowfall/decoding/lm_rescore.py", line 331, in decode_with_lm_rescoring
    return rescore_with_n_best_list(lats, G, num_paths)
  File "/anaconda3/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/snowfall/snowfall/decoding/lm_rescore.py", line 190, in rescore_with_n_best_list
    lm_path_lats = _intersect_device(G,
  File "/snowfall/snowfall/decoding/lm_rescore.py", line 24, in _intersect_device
    return k2.intersect_device(a_fsas,
  File "/anaconda3/lib/python3.8/site-packages/k2/fsa_algo.py", line 175, in intersect_device
    + index_select(b_value, b_arc_map)
  File "/anaconda3/lib/python3.8/site-packages/k2/ops.py", line 134, in index_select
    ans = _IndexSelectFunction.apply(src, index)
  File "/anaconda3/lib/python3.8/site-packages/k2/ops.py", line 49, in forward
    return _k2.index_select(src, index)
RuntimeError: CUDA error: invalid argument
Exception raised from getDeviceFromPtr at /pytorch/aten/src/ATen/cuda/CUDADevice.h:13 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7ffae0ea62f2 in /anaconda3/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x5b (0x7ffae0ea367b in /anaconda3/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0xe70136 (0x7ffae2189136 in /anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0x2cb81 (0x7ffa78f9bb81 in /anaconda3/lib/python3.8/site-packages/_k2.cpython-38-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0xb9644 (0x7ffa79028644 in /anaconda3/lib/python3.8/site-packages/_k2.cpython-38-x86_64-linux-gnu.so)
frame #5: <unknown function> + 0x67440 (0x7ffa78fd6440 in /anaconda3/lib/python3.8/site-packages/_k2.cpython-38-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x69b22 (0x7ffa78fd8b22 in /anaconda3/lib/python3.8/site-packages/_k2.cpython-38-x86_64-linux-gnu.so)
frame #7: <unknown function> + 0x6d25d (0x7ffa78fdc25d in /anaconda3/lib/python3.8/site-packages/_k2.cpython-38-x86_64-linux-gnu.so)
frame #8: <unknown function> + 0x24697 (0x7ffa78f93697 in /anaconda3/lib/python3.8/site-packages/_k2.cpython-38-x86_64-linux-gnu.so)
<omitting python frames>
frame #14: THPFunction_apply(_object*, _object*) + 0x8fd (0x7ffb2dc728ed in /anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #54: __libc_start_main + 0xf3 (0x7ffb304ce6a3 in /lib64/libc.so.6)
danpovey commented 3 years ago

@csukuangfj will make a PR tomorrow or the day after that should help convergence by using a teacher model. Did you run the tests in k2 (cd build; ctest, or "make test")? That looks to me like a compilation/build problem.

danpovey commented 3 years ago

Actually, I believe there is a flag to turn off LM rescoring. That is a new feature and may be buggy.
But since the model didn't converge, I doubt decoding will work.

zhichaowang commented 3 years ago


You are right. I ran the tests and some of them failed:

The following tests FAILED:
  1 - Test.Cuda.cu_algorithms_test (Not Run)
  2 - Test.Cuda.cu_array_ops_test (Not Run)
  3 - Test.Cuda.cu_array_test (Not Run)
  4 - Test.Cuda.cu_dtype_test (Not Run)
  5 - Test.Cuda.cu_fsa_algo_test (Not Run)
  6 - Test.Cuda.cu_fsa_test (Not Run)
  7 - Test.Cuda.cu_fsa_utils_test (Not Run)
  8 - Test.Cuda.cu_hash_test (Not Run)
  9 - Test.Cuda.cu_host_shim_test (Not Run)
  10 - Test.Cuda.cu_intersect_test (Not Run)
  11 - Test.Cuda.cu_log_test (Not Run)
  12 - Test.Cuda.cu_macros_test (Not Run)
  13 - Test.Cuda.cu_math_test (Not Run)
  14 - Test.Cuda.cu_nvtx_test (Not Run)
  15 - Test.Cuda.cu_pinned_context_test (Not Run)
  16 - Test.Cuda.cu_ragged_shape_test (Not Run)
  17 - Test.Cuda.cu_ragged_test (Not Run)
  18 - Test.Cuda.cu_ragged_utils_test (Not Run)
  19 - Test.Cuda.cu_rand_test (Not Run)
  20 - Test.Cuda.cu_rm_epsilon_test (Not Run)
  21 - Test.Cuda.cu_tensor_ops_test (Not Run)
  22 - Test.Cuda.cu_tensor_test (Not Run)
  23 - Test.Cuda.cu_thread_pool_test (Not Run)
  24 - Test.Cuda.cu_top_sort_test (Not Run)
  25 - Test.Cuda.cu_utils_test (Not Run)
  26 - Test.arcsort_test (Not Run)
  27 - Test.array_test (Not Run)
  28 - Test.aux_labels_test (Not Run)
  29 - Test.connect_test (Not Run)
  30 - Test.determinize_test (Not Run)
  31 - Test.fsa_equivalent_test (Not Run)
  32 - Test.fsa_renderer_test (Not Run)
  33 - Test.fsa_test (Not Run)
  34 - Test.fsa_util_test (Not Run)
  35 - Test.intersect_test (Not Run)
  36 - Test.properties_test (Not Run)
  37 - Test.rmepsilon_test (Not Run)
  38 - Test.topsort_test (Not Run)
  39 - Test.weights_test (Not Run)
  80 - host_arcsort_test_py (Failed)
  81 - host_array_test_py (Failed)
  82 - host_aux_labels_test_py (Failed)
  83 - host_connect_test_py (Failed)
  84 - host_determinize_test_py (Failed)
  85 - host_fsa_equivalent_test_py (Failed)
  86 - host_fsa_test_py (Failed)
  87 - host_intersect_test_py (Failed)
  88 - host_properties_test_py (Failed)
  89 - host_rmepsilon_test_py (Failed)
  90 - host_topsort_test_py (Failed)
  91 - host_weights_test_py (Failed)
Errors while running CTest

I rebuilt k2. The output is as follows:

cmake -DCMAKE_BUILD_TYPE=Release ..
-- The CUDA compiler identification is NVIDIA 10.2.89
-- The CXX compiler identification is GNU 8.3.1
-- Check for working CUDA compiler: /usr/local/cuda/bin/nvcc
-- Check for working CUDA compiler: /usr/local/cuda/bin/nvcc -- works
-- Detecting CUDA compiler ABI info
-- Detecting CUDA compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- K2_OS: CentOS Linux release 8.2.2004 (Core)
-- Found Git: /usr/bin/git (found version "2.18.2")
-- Looking for C++ include cxxabi.h
-- Looking for C++ include cxxabi.h - found
-- Looking for C++ include execinfo.h
-- Looking for C++ include execinfo.h - found
-- Performing Test K2_COMPILER_SUPPORTS_CXX14
-- Performing Test K2_COMPILER_SUPPORTS_CXX14 - Success
-- C++ Standard version: 14
-- Autodetected CUDA architecture(s): 6.1 6.1 6.1 6.1 6.1 6.1 6.1 6.1
-- K2_COMPUTE_ARCH_FLAGS: -gencode;arch=compute_61,code=sm_61
CMake Warning at CMakeLists.txt:123 (message):
  arch 62/72 are not supported for now

-- Skipping arch 35
-- Skipping arch 50
-- Skipping arch 60
-- Adding arch 61
-- Skipping arch 70
-- Skipping arch 75
-- K2_COMPUTE_ARCHS: 61
-- Found Valgrind: /usr/bin
-- Found Valgrind: /usr/bin/valgrind
-- To check memory, run ctest -R <NAME> -D ExperimentalMemCheck
-- Downloading pybind11
-- pybind11 is downloaded to /k2/build/_deps/pybind11-src
-- pybind11 v2.6.0
-- Found PythonInterp: /anaconda3/bin/python (found version "3.8.5")
-- Found PythonLibs: /anaconda3/lib/libpython3.8.so
-- Performing Test HAS_FLTO
-- Performing Test HAS_FLTO - Success
-- Python executable: /anaconda3/bin/python
-- Looking for C++ include pthread.h
-- Looking for C++ include pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Found CUDA: /usr/local/cuda (found version "10.2")
-- Caffe2: CUDA detected: 10.2
-- Caffe2: CUDA nvcc is: /usr/local/cuda/bin/nvcc
-- Caffe2: CUDA toolkit directory: /usr/local/cuda
-- Caffe2: Header version is: 10.2
-- Found CUDNN: /usr/local/cuda/lib64/libcudnn.so
-- Found cuDNN: v8.0.5 (include: /usr/local/cuda/include, library: /usr/local/cuda/lib64/libcudnn.so)
-- /usr/local/cuda/lib64/libnvrtc.so shorthash is 08c4863f
-- Autodetected CUDA architecture(s): 6.1 6.1 6.1 6.1 6.1 6.1 6.1 6.1
-- Added CUDA NVCC flags for: -gencode;arch=compute_61,code=sm_61
-- Found Torch: /anaconda3/lib/python3.8/site-packages/torch/lib/libtorch.so
-- PyTorch version: 1.8.0
-- PyTorch cuda version: 10.2
-- Downloading cub
-- cub is downloaded to /k2/build/_deps/cub-src
-- Downloading moderngpu
-- moderngpu is downloaded to /k2/build/_deps/moderngpu-src
-- Downloading googletest
-- googletest is downloaded to /k2/build/_deps/googletest-src
-- googletest's binary dir is /k2/build/_deps/googletest-build
-- The C compiler identification is GNU 8.3.1
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Generated /k2/build/k2/csrc/version.h
-- Configuring done
-- Generating done
-- Build files have been written to: /k2/build

csukuangfj commented 3 years ago

Did you run "make -j" before running ctest?

The log is from the configure phase, not from the build phase.
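This distinction shows up in the ctest summary earlier in this thread. Every C++ test there is marked (Not Run), which CTest typically reports when it cannot run the test executable, for example because it was never built. The host_*_test_py tests, by contrast, did run and reported (Failed). The sketch below tallies failing tests by status from such a summary. summarize_ctest is a hypothetical helper, and SUMMARY is a short excerpt of the summary above.

```python
import re

# Excerpt of the ctest summary quoted above, flattened onto one line as in the thread.
SUMMARY = (
    "The following tests FAILED: "
    "1 - Test.Cuda.cu_algorithms_test (Not Run) "
    "2 - Test.Cuda.cu_array_ops_test (Not Run) "
    "26 - Test.arcsort_test (Not Run) "
    "80 - host_arcsort_test_py (Failed) "
    "91 - host_weights_test_py (Failed) "
    "Errors while running CTest"
)

# One entry per failing test: "<number> - <test name> (<status>)".
ENTRY = re.compile(r"(\d+) - (\S+) \(([^)]+)\)")


def summarize_ctest(text):
    """Group failing test names by their CTest status (hypothetical helper)."""
    groups = {}
    for _number, name, status in ENTRY.findall(text):
        groups.setdefault(status, []).append(name)
    return groups


groups = summarize_ctest(SUMMARY)
print({status: len(names) for status, names in groups.items()})
if "Not Run" in groups:
    print('Some test executables never ran; check that the build step ("make -j") completed.')
```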

csukuangfj commented 3 years ago

@zhichaowang

Could you try https://github.com/k2-fsa/snowfall/pull/174? I think it will help.

zhichaowang commented 3 years ago

@csukuangfj OK, I'll try it.

zhichaowang commented 3 years ago

@csukuangfj I re-installed lhotse, k2 and snowfall, and the convergence problem was solved.

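The following are the results w/o ali_model: [image] https://user-images.githubusercontent.com/8521283/115982617-33c08380-a5cf-11eb-8982-8c1626f5877b.png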

During decoding, I turned off the LM rescoring.

danpovey commented 3 years ago

Thanks! I suppose it's unclear from this whether the alignment is helpful in terms of WER. We can turn it off once we no longer have convergence problems, but for now let's keep it in the code, because it makes it easier to play with new models without worrying so much about whether they will converge.

On Sun, Apr 25, 2021 at 2:06 PM ZhichaoWang wrote:

@csukuangfj I re-installed lhotse, k2 and snowfall, and the convergence problem was solved.

The following are the results of w/o ali_model: [image: image] https://user-images.githubusercontent.com/8521283/115982617-33c08380-a5cf-11eb-8982-8c1626f5877b.png

During decoding, I turned off the LM rescoring.


danpovey commented 3 years ago

And please try the LM rescoring code.

On Sun, Apr 25, 2021 at 2:38 PM Daniel Povey wrote:

Thanks! I suppose it's unclear from this whether the alignment is helpful in terms of WER. We can turn it off when we don't have problems with convergence, but for now let's keep it in the code because it makes it easier to play with new models and not worry so much about will it converge.


zhichaowang commented 3 years ago

@danpovey Here are the results with LM rescoring: [image]

The LM rescoring configuration: --use-lm-rescoring=True --num-path=100 --max-duration=200 --output-beam-size=20

I didn't rescore with the whole lattice because it ran out of GPU memory (OOM), even when I reduced max-duration to 10 and output-beam-size to 3. The GPU used for decoding is a Tesla P40 with 22GB of memory.
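For readers following the --num-path setting, here is a generic view of what n-best rescoring does. A limited number of paths are taken from the lattice, each is scored as its acoustic score plus the n-gram LM score, and the best total wins. Whole-lattice rescoring applies the LM to the entire lattice instead, which tends to need considerably more memory. The sketch is a toy, pure-Python illustration with made-up hypotheses and scores. It is not snowfall's rescore_with_n_best_list, which works on k2 FSAs on the GPU. The Path class, the duplicate-collapsing rule and the lm_scale weight are assumptions for illustration.

```python
from dataclasses import dataclass


@dataclass
class Path:
    words: tuple      # word sequence of one path through the lattice
    am_score: float   # log-domain acoustic (+ decoding graph) score of the path
    lm_score: float   # log-domain score of the word sequence under the n-gram LM


def rescore_n_best(paths, lm_scale=1.0):
    """Return the path with the highest am_score + lm_scale * lm_score.

    Paths with identical word sequences are collapsed first, keeping the one
    with the higher acoustic score (a simplification). lm_scale is a
    hypothetical weight, not a snowfall option.
    """
    unique = {}
    for p in paths:
        if p.words not in unique or p.am_score > unique[p.words].am_score:
            unique[p.words] = p
    return max(unique.values(), key=lambda p: p.am_score + lm_scale * p.lm_score)


# Made-up 3-path list for one utterance: the LM overturns the acoustic ranking.
n_best = [
    Path(("the", "cat", "sat"), am_score=-10.0, lm_score=-9.0),
    Path(("the", "cats", "at"), am_score=-9.5, lm_score=-14.0),
    Path(("the", "cat", "sat"), am_score=-10.4, lm_score=-9.0),
]
best = rescore_n_best(n_best)
print(" ".join(best.words))  # prints "the cat sat" (total -19.0 vs -23.5)
```

Raising the number of paths gives the LM more candidates to choose from, at the cost of more computation per utterance.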

csukuangfj commented 3 years ago

even when I reduced max-duration to 10 and output-beam-size to 3

Could you try the latest snowfall? I just fixed a bug with output-beam-size in https://github.com/k2-fsa/snowfall/pull/178: the command-line option --output-beam-size was never actually used. Sorry for the inconvenience.

csukuangfj commented 3 years ago

I didn't use the whole lattice for LM rescoring

Rescoring with the whole lattice should give you a WER less than 6% on test-clean.

zhichaowang commented 3 years ago


I tried the latest snowfall, and the output-beam-size option works now. With --num-path=100 --output-beam-size=20, the WER on test-clean increased from 6.31% to 6.47%. The OOM problem still occurs when rescoring with the whole lattice, even with --max-duration=10 --output-beam-size=1.

csukuangfj commented 3 years ago

The OOM problem still exists

Perhaps your GPU has very limited memory. How about decreasing output_beam_size further? It can be a floating-point number, e.g., 0.5, 0.1, ...
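The command lines in this thread pass --use-lm-rescoring as both 1 and True, and --output-beam-size as a float such as 0.01. They also use --num-path in two ways: with --num-path=10 the decode log reports "Rescoring with n-best list, n is 10", and with --num-path=-1 it reports "Rescoring with the whole lattice". The sketch below shows one way such options could be declared with argparse. It is hypothetical and not copied from mmi_att_transformer_decode.py: str2bool, rescoring_mode and all defaults are illustrative, so check the script's own --help for the real definitions.

```python
import argparse


def str2bool(value):
    """Accept 1/0, true/false, yes/no (case-insensitive), as used in the commands above."""
    v = value.strip().lower()
    if v in ("1", "true", "yes", "y"):
        return True
    if v in ("0", "false", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    # Hypothetical option definitions; the defaults are illustrative only.
    p = argparse.ArgumentParser(description="decoding options (sketch)")
    p.add_argument("--use-lm-rescoring", type=str2bool, default=False)
    p.add_argument("--num-path", type=int, default=100,
                   help="n for n-best rescoring; a negative value means whole lattice")
    p.add_argument("--output-beam-size", type=float, default=8.0,
                   help="parsed as float, so 0.5 or 0.01 are accepted")
    return p


def rescoring_mode(args):
    """Mirror the mode messages seen in the decode logs of this thread."""
    if not args.use_lm_rescoring:
        return "no LM rescoring"
    if args.num_path < 0:
        return "whole lattice"
    return f"n-best list, n is {args.num_path}"


args = build_parser().parse_args(
    ["--use-lm-rescoring=True", "--num-path=-1", "--output-beam-size=0.01"])
print(rescoring_mode(args), args.output_beam_size)
```

Declaring the beam with type=float is what lets values such as 0.5 or 0.1 through. With type=int, argparse would reject them.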

zhichaowang commented 3 years ago


I found that the --output-beam-size option has nothing to do with the OOM problem. The OOM occurs before the search starts, while arc-sorting G; the log follows. The GPU capacity is 22GB.

CUDA_VISIBLE_DEVICES=7 ./mmi_att_transformer_decode.py --epoch=10 --avg=5 --use-lm-rescoring=True --num-path=-1 --max-duration=10 --output-beam-size=0.01
2021-04-26 11:35:34,393 DEBUG [mmi_att_transformer_decode.py:246] output_beam_size: 0.01
2021-04-26 11:35:34,640 DEBUG [mmi_att_transformer_decode.py:259] About to load model
2021-04-26 11:35:34,847 INFO [common.py:111] average over checkpoints ['exp-conformer-noam-mmi-att-musan-sa/epoch-5.pt', 'exp-conformer-noam-mmi-att-musan-sa/epoch-6.pt', 'exp-conformer-noam-mmi-att-musan-sa/epoch-7.pt', 'exp-conformer-noam-mmi-att-musan-sa/epoch-8.pt', 'exp-conformer-noam-mmi-att-musan-sa/epoch-9.pt']
2021-04-26 11:35:38,789 DEBUG [mmi_att_transformer_decode.py:322] Loading pre-compiled HLG
2021-04-26 11:35:43,492 INFO [mmi_att_transformer_decode.py:328] Rescoring with the whole lattice
2021-04-26 11:35:43,516 DEBUG [mmi_att_transformer_decode.py:347] Loading pre-compiled G_4_gram.pt
Traceback (most recent call last):
  File "./mmi_att_transformer_decode.py", line 402, in <module>
    main()
  File "./mmi_att_transformer_decode.py", line 355, in main
    G = k2.arc_sort(G)
  File "/k2/k2/python/k2/fsa_algo.py", line 492, in arc_sort
    ragged_arc, arc_map = _k2.arc_sort(fsa.arcs, need_arc_map=need_arc_map)
RuntimeError: CUDA out of memory. Tried to allocate 1.18 GiB (GPU 0; 22.38 GiB total capacity; 20.76 GiB already allocated; 301.56 MiB free; 21.33 GiB reserved in total by PyTorch)
Exception raised from malloc at /pytorch/c10/cuda/CUDACachingAllocator.cpp:288 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f33d69362f2 in /anaconda3/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x1bc01 (0x7f33d6b94c01 in /anaconda3/lib/python3.8/site-packages/torch/lib/libc10_cuda.so)
frame #2: <unknown function> + 0x1c924 (0x7f33d6b95924 in /anaconda3/lib/python3.8/site-packages/torch/lib/libc10_cuda.so)
frame #3: <unknown function> + 0x1cf43 (0x7f33d6b95f43 in /anaconda3/lib/python3.8/site-packages/torch/lib/libc10_cuda.so)
frame #4: k2::PytorchCudaContext::Allocate(unsigned long, void*) + 0x24 (0x7f336de31c94 in /k2/build/lib/libk2context.so)
frame #5: (anonymous namespace)::ModernGpuAllocator::alloc(unsigned long, mgpu::memory_space_t) + 0x1f (0x7f336dd1f15f in /k2/build/lib/libk2context.so)
frame #6: mgpu::detail::segsort_t<mgpu::empty_t, k2::Arc, int, k2::LessThan >::segsort_t(k2::Arc, int, int, k2::LessThan, mgpu::context_t&) + 0x27b (0x7f336dc46d8b in /k2/build/lib/libk2context.so)
frame #7: void k2::SortSublists<k2::Arc, k2::LessThan >(k2::Ragged, k2::Array1) + 0xe4 (0x7f336dc48164 in /k2/build/lib/libk2context.so)
frame #8: k2::ArcSort(k2::Ragged&, k2::Ragged, k2::Array1*) + 0x19f (0x7f336dc3246f in /k2/build/lib/libk2context.so)
frame #9: <unknown function> + 0x58ca9 (0x7f336ea57ca9 in /k2/build/lib/_k2.cpython-38-x86_64-linux-gnu.so)
frame #10: <unknown function> + 0x24697 (0x7f336ea23697 in /k2/build/lib/_k2.cpython-38-x86_64-linux-gnu.so)
<omitting python frames>
frame #28: __libc_start_main + 0xf3 (0x7f3425f5e6a3 in /lib64/libc.so.6)
csukuangfj commented 3 years ago

The OOM problem occurred before searching process, following is the log.

One workaround is to run k2.arc_sort on the CPU:

G = k2.arc_sort(G.to('cpu')).to(device)

But I am not sure whether you will encounter CUDA OOM in the later stages.
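The log above shows why arc_sort fails on the GPU: the allocator asked for 1.18 GiB while only 301.56 MiB was free. csukuangfj's one-liner always sorts on the CPU. The sketch below is a hypothetical variant that tries the current device first and falls back to the CPU only on an out-of-memory error. It assumes only that the object has a .to(device) method, as k2.Fsa and torch tensors do. It detects OOM by the "out of memory" text in the RuntimeError message, as in the traceback above, and re-raises any other error. with_cpu_fallback is not a k2 or snowfall function.

```python
def with_cpu_fallback(op, fsa, device):
    """Run op(fsa) on its current device; on CUDA OOM, retry on a CPU copy.

    The CPU result is moved back to `device`, which matches
    G = k2.arc_sort(G.to('cpu')).to(device). Other RuntimeErrors, such as
    "CUDA error: invalid argument", are re-raised unchanged.
    """
    try:
        return op(fsa)
    except RuntimeError as err:
        if "out of memory" not in str(err):
            raise
        return op(fsa.to("cpu")).to(device)


# Usage in the decode script (assumes k2, G and device exist as in the thread):
#     G = with_cpu_fallback(k2.arc_sort, G, device)
```

As csukuangfj notes, this only gets past the arc-sort step. Later stages of whole-lattice rescoring may still run out of GPU memory.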