1ytic / warp-rnnt

CUDA-Warp RNN-Transducer
MIT License

Add impl of compact layout input following #26 #39

Closed maxwellzh closed 1 year ago

maxwellzh commented 1 year ago

For how the compact layout works, please take a look at the discussion in #26

I added the implementation in core_compact.cu and the PyTorch binding in pytorch_binding/binding.cpp. A test case was added in pytorch_binding/warp_rnnt/test.py. The benchmark script is not updated. (I don't know how to bind it to tensorflow.)

You may see some changes to unrelated code; that's due to auto-formatting during my modification. I hope that doesn't cause you trouble in reviewing.

1ytic commented 1 year ago

I'm trying to benchmark this version. Could you share a use case where it's better? I moved the following packing operations outside the time measurement, but it is still slower than the gather version:

N = xs.size(0)
V = xs.size(-1)
xs = torch.cat([xs[i, :xn[i], :yn[i]+1].reshape(-1, V) for i in range(N)])
ys = torch.cat([ys[i, :yn[i]] for i in range(N)])
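To make the shape arithmetic of this packing concrete, here is a small self-contained sketch with toy sizes (the lengths xn and yn are made up for illustration; only the two torch.cat lines above come from the thread). It packs the valid (t, u) cells of each utterance into one compact tensor:

```python
import torch

# Toy sizes (hypothetical): a batch of 2 utterances with different lengths.
N, T, U, V = 2, 4, 3, 5
xs = torch.randn(N, T, U + 1, V)   # padded joint-net outputs
ys = torch.randint(0, V, (N, U))   # padded label sequences
xn = torch.tensor([4, 2])          # actual frame counts per utterance
yn = torch.tensor([3, 1])          # actual label counts per utterance

# The same packing as in the snippet above: for each utterance keep only
# its valid xn[i] x (yn[i]+1) grid of joint-net cells, flattened row-wise.
xs_compact = torch.cat(
    [xs[i, :xn[i], :yn[i] + 1].reshape(-1, V) for i in range(N)]
)
ys_compact = torch.cat([ys[i, :yn[i]] for i in range(N)])

# Compact row count is sum_i xn[i]*(yn[i]+1) instead of N*T*(U+1).
assert xs_compact.size(0) == sum(int(xn[i]) * (int(yn[i]) + 1) for i in range(N))
assert ys_compact.size(0) == int(yn.sum())
```

With these toy lengths the compact tensor has 4·4 + 2·2 = 20 rows, versus 2·4·4 = 32 cells in the padded layout.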
maxwellzh commented 1 year ago

compact=True potentially gains from variable lengths within the batch, while requiring extra packing operations. Thus, if random_length=False (or N=1), there's no benefit to enabling compact=True. Moreover, compact=True can reduce the computation overhead of the feed-forward-only joint net.
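Where the joint-net saving comes from can be sketched with rough arithmetic (the uniform length distribution below is an assumption for illustration, not taken from the benchmark script): the padded layout always materializes N × T × (U+1) × V joint-net outputs, while the compact layout only materializes the valid cells.

```python
import random

# Benchmark-like sizes; the length distribution is a made-up assumption.
N, T, U, V = 32, 150, 20, 5000

random.seed(0)
xn = [random.randint(T // 2, T) for _ in range(N)]  # frame lengths
yn = [random.randint(U // 2, U) for _ in range(N)]  # label lengths

padded_cells = N * T * (U + 1) * V                       # padded layout
compact_cells = sum(t * (u + 1) for t, u in zip(xn, yn)) * V  # compact layout

print(f"compact/padded ratio: {compact_cells / padded_cells:.2f}")
```

The larger V is, the more absolute work and memory each skipped padding cell saves, which is why the gap between the modes grows with vocabulary size in the tables below.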

I uploaded a new benchmark script, which measures the time and memory usage of both the joint net and the loss computation, so the difference in joint-net overhead can be monitored.

Here are some results on an RTX 3090 (24GB memory). H stands for the hidden size of the joint net inputs.

H=512 T=150 U=40 V=28

| N | compact fwd-only | gather fwd-only | compact fwd-bwd | gather fwd-bwd |
|---|---|---|---|---|
| 16 | 42.613ms, 233.86MB | 17.233ms, 431.41MB | 203.460ms, 342.48MB | 43.280ms, 622.76MB |
| 32 | 75.173ms, 482.91MB | 33.946ms, 863.78MB | 373.194ms, 768.52MB | 82.399ms, 1246.55MB |
| 64 | 140.568ms, 950.64MB | 61.949ms, 1725.46MB | 594.851ms, 1486.58MB | 150.704ms, 2490.67MB |
| 128 | 284.994ms, 2006.94MB | 112.423ms, 3452.43MB | 1.068s, 2851.25MB | 269.747ms, 4983.33MB |

H=512 T=150 U=20 V=5000

| N | compact fwd-only | gather fwd-only | compact fwd-bwd | gather fwd-bwd |
|---|---|---|---|---|
| 16 | 123.887ms, 1742.10MB | 188.205ms, 3145.77MB | 361.767ms, 1800.86MB | 467.143ms, 3259.16MB |
| 32 | 234.453ms, 3392.50MB | 386.951ms, 6276.81MB | 708.439ms, 3769.76MB | 935.766ms, 6494.32MB |
| 64 | 486.062ms, 7015.84MB | 773.315ms, 12544.78MB | 1.435s, 7712.72MB | 1.864s, 12969.04MB |
| 128 | 988.776ms, 14599.57MB | OOM | 3.967s, 15023.55MB | OOM |

H=256 T=150 U=40 V=28

| N | compact fwd-bwd | gather fwd-bwd |
|---|---|---|
| 16 | 225.052ms, 205.61MB | 33.830ms, 317.29MB |
| 32 | 399.421ms, 364.90MB | 51.210ms, 635.86MB |
| 64 | 680.739ms, 736.37MB | 93.312ms, 1268.96MB |
| 128 | 1.156s, 1493.12MB | 159.072ms, 2537.86MB |

H=256 T=150 U=20 V=5000

| N | compact fwd-bwd | gather fwd-bwd |
|---|---|---|
| 16 | 298.333ms, 1739.04MB | 302.025ms, 3144.54MB |
| 32 | 491.745ms, 3483.68MB | 601.328ms, 6271.13MB |
| 64 | 995.477ms, 7406.63MB | 1.187s, 12535.87MB |
| 128 | 2.460s, 14027.48MB | OOM |

Generally, we care more about the fwd-bwd measurement. In summary, compact mode always consumes considerably less memory than gather mode. However, in terms of execution time, when V is small (so the joint-net savings from compact are small) and N is large (so the packing overhead grows), gather mode is significantly faster than compact. When V is >=~1K, compact mode is a bit faster than gather mode, in addition to using much less memory.
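The trade-off above could be distilled into a simple heuristic; the function below is a hypothetical sketch based only on these benchmark tables (the 1K threshold is the rough crossover observed above, not anything exposed by the library):

```python
def choose_mode(vocab_size: int, vocab_threshold: int = 1000) -> str:
    """Hypothetical rule of thumb from the benchmarks in this thread:
    compact wins (time and memory) once the vocabulary is large enough
    that joint-net savings outweigh the packing overhead; for small
    vocabularies, gather is faster (though compact still saves memory).
    """
    return "compact" if vocab_size >= vocab_threshold else "gather"
```

For example, this rule picks compact for the V=5000 and V=4232 runs and gather for the V=28 runs, matching the timing results in the tables.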

I also ran some benchmarks on real ASR datasets.

aishell-1 dev:

mean(T)=113, mean(U)=14, V=4232 (Chinese characters)
| N | compact fwd-bwd | gather fwd-bwd |
|---|---|---|
| 16 | 325.944ms, 1379.11MB | 561.944ms, 3860.13MB |
| 32 | 603.264ms, 3136.27MB | 934.786ms, 6416.03MB |
| 64 | 1.355s, 6845.54MB | 2.222s, 15801.37MB |
| 128 | 3.768s, 13292.62MB | OOM |

librispeech dev:

mean(T)=169, mean(U)=31, V=1024 (English BPE)
| N | compact fwd-bwd | gather fwd-bwd |
|---|---|---|
| 16 | 437.417ms, 2570.03MB | 1.408s, 11807.10MB |
| 32 | 504.782ms, 2139.11MB | 985.302ms, 8123.03MB |
| 64 | 1.799s, 8893.65MB | OOM |
| 128 | 5.173s, 17869.89MB | OOM |

In these test cases on real datasets, since my vocab sizes are ~1k or larger, compact mode performs quite impressively compared to gather mode.

More detailed benchmark logs: benchmark_results.tar.gz

1ytic commented 1 year ago

@maxwellzh thank you! I made a new package release 0.7.0.