praeclarum / webgpu-torch

Tensor computation with WebGPU acceleration
MIT License
576 stars 15 forks

Lazy Evaluation of Tensors #7

Closed: praeclarum closed this 1 year ago

praeclarum commented 1 year ago

This is a PR to fix the performance problem this library has so far suffered from.

It turns out that binding newly created buffers to WebGPU commands is expensive. You really want to reuse buffers, not just to keep memory use down but also to keep the GPU happy.
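To illustrate the reuse idea, here is a minimal sketch of a size-keyed buffer pool. It is not code from this PR: `PooledBuffer` and `BufferPool` are hypothetical names, and a plain object stands in for a real WebGPU buffer so the example runs without a GPU.

```typescript
// Hypothetical stand-in for a GPU buffer; a real version would wrap a GPUBuffer.
interface PooledBuffer {
  readonly id: number;
  readonly byteSize: number;
}

class BufferPool {
  // Free buffers, bucketed by exact byte size.
  private free = new Map<number, PooledBuffer[]>();
  private nextId = 0;
  // Number of buffers that actually had to be allocated.
  created = 0;

  acquire(byteSize: number): PooledBuffer {
    const bucket = this.free.get(byteSize);
    const reused = bucket?.pop();
    if (reused !== undefined) return reused; // reuse instead of allocating
    this.created++;
    return { id: this.nextId++, byteSize };
  }

  release(buffer: PooledBuffer): void {
    const bucket = this.free.get(buffer.byteSize);
    if (bucket) bucket.push(buffer);
    else this.free.set(buffer.byteSize, [buffer]);
  }
}

// Usage: the second acquire gets the released buffer back.
const examplePool = new BufferPool();
const firstBuffer = examplePool.acquire(1024);
examplePool.release(firstBuffer);
const secondBuffer = examplePool.acquire(1024); // same object as firstBuffer
```

A pool like this only helps if something calls `release` promptly, which is the hard part when buffer lifetimes are tied to garbage collection.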

I tried a couple tricks:

The heap allocator was the best, but the GC just wasn't freeing buffers fast enough for big computations.

Finally, I decided that the only way to make sure that buffers could be reused efficiently was to first compute the data flow graph, translate it to SSA form, and then use that to check the liveness of buffers. It's not so hard, but it is only worth it if the compute graph has more than a few nodes.
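The liveness idea can be sketched roughly as follows. This is a simplified illustration, not the implementation in this PR: `GraphNode` and `assignBuffers` are hypothetical names, nodes are assumed to be already in execution order and in SSA form (each value is written exactly once), and every value is assumed to need a buffer of the same size.

```typescript
// One SSA instruction: `out` is defined exactly once, `inputs` are read.
interface GraphNode {
  out: string;
  inputs: string[];
}

// Returns a buffer index for each value, reusing buffers whose values are dead.
function assignBuffers(nodes: GraphNode[], keepAlive: string[]): Map<string, number> {
  // Liveness: index of the last node that reads each value.
  const lastUse = new Map<string, number>();
  nodes.forEach((n, i) => n.inputs.forEach(v => lastUse.set(v, i)));
  // Requested results must outlive every node.
  keepAlive.forEach(v => lastUse.set(v, nodes.length));

  const assignment = new Map<string, number>();
  const freeBuffers: number[] = [];
  let bufferCount = 0;

  nodes.forEach((n, i) => {
    // Pick the output buffer first so it never aliases an input of this node.
    const buf = freeBuffers.pop() ?? bufferCount++;
    assignment.set(n.out, buf);
    // Inputs read for the last time here can give their buffers back.
    new Set(n.inputs).forEach(v => {
      const b = assignment.get(v);
      if (b !== undefined && lastUse.get(v) === i) freeBuffers.push(b);
    });
    // A value that is never read and not requested is dead immediately.
    if (!lastUse.has(n.out)) freeBuffers.push(buf);
  });
  return assignment;
}

// Usage: a chain of four unary ops on an external input `x`.
const chainAssignment = assignBuffers(
  [
    { out: "a", inputs: ["x"] },
    { out: "b", inputs: ["a"] },
    { out: "c", inputs: ["b"] },
    { out: "d", inputs: ["c"] },
  ],
  ["d"],
);
// a -> 0, b -> 1, c -> 0, d -> 1: four values share two buffers.
```

With only one or two nodes there is nothing to reuse, which matches the point that the analysis pays off only for larger graphs.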

To enable growing large graphs, I have made tensors lazy by default. They only compute their value on demand - when toArrayAsync() is called or when the storage property is accessed. This is a big breaking change from how PyTorch works, but the performance gains are worth it.
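As a rough picture of what lazy-by-default means, here is a toy CPU-only sketch. Only the names `toArrayAsync()` and `storage` come from the description above; the `LazyTensor` class, its `Float32Array` storage, the `fromArray` helper, and the `evaluations` counter are assumptions made for illustration and do not reflect the library's actual types.

```typescript
type UnaryOp = "neg" | "sigmoid";

class LazyTensor {
  // Counts how many ops were actually executed (for demonstration only).
  static evaluations = 0;
  private cached: Float32Array | null = null;

  private constructor(
    private readonly source: Float32Array | null,
    private readonly op: UnaryOp | null,
    private readonly input: LazyTensor | null,
  ) {}

  static fromArray(values: number[]): LazyTensor {
    return new LazyTensor(Float32Array.from(values), null, null);
  }

  // Ops only record a graph node; nothing is computed yet.
  neg(): LazyTensor {
    return new LazyTensor(null, "neg", this);
  }
  sigmoid(): LazyTensor {
    return new LazyTensor(null, "sigmoid", this);
  }

  // Accessing storage forces evaluation of this node and its dependencies.
  get storage(): Float32Array {
    if (this.cached === null) this.cached = this.evaluate();
    return this.cached;
  }

  async toArrayAsync(): Promise<number[]> {
    return Array.from(this.storage);
  }

  private evaluate(): Float32Array {
    if (this.source !== null) return this.source;
    LazyTensor.evaluations++;
    const x = this.input!.storage;
    return this.op === "neg"
      ? x.map(v => -v)
      : x.map(v => 1 / (1 + Math.exp(-v)));
  }
}

// Usage: building the graph runs no ops; reading the result does.
const lazyResult = LazyTensor.fromArray([1, -2, 3]).neg().neg();
const opsBeforeRead = LazyTensor.evaluations; // still 0 here
const lazyValues = lazyResult.storage;        // both neg ops run now
```

Because the whole chain is known before anything runs, an evaluator has the full graph available for analyses such as buffer liveness.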

Eager evaluation benchmark

32 GB of memory

| Benchmark | Time (ms) | Intel(R) Xeon(R) W-2150B CPU @ 3.00GHz | NVIDIA GeForce RTX 3090 | Error |
| --- | --- | --- | --- | --- |
| unary 1d(1, 'neg') | 0.149 | 0.001 | 0.003 | |
| unary 1d(1, 'sigmoid') | 0.146 | 0.001 | 0.003 | |
| unary 1d(729, 'neg') | 0.155 | 0.001 | 0.003 | |
| unary 1d(729, 'sigmoid') | 0.147 | 0.002 | 0.003 | |
| unary 1d(2187, 'neg') | 0.164 | 0.002 | 0.003 | |
| unary 1d(2187, 'sigmoid') | 0.153 | 0.004 | 0.003 | |
| unary 1d(59049, 'neg') | 0.204 | 0.013 | 0.008 | |
| unary 1d(59049, 'sigmoid') | 0.170 | 0.058 | 0.008 | |
| unary 1d(177147, 'neg') | 0.773 | 0.073 | 0.017 | |
| unary 1d(177147, 'sigmoid') | 0.168 | 0.177 | 0.017 | |
| unary 1d(531441, 'neg') | 2.123 | 0.409 | 0.051 | |
| unary 1d(531441, 'sigmoid') | 1.264 | 0.732 | 0.050 | |
| unary 1d(1594323, 'neg') | 3.654 | 1.488 | 0.166 | |
| unary 1d(1594323, 'sigmoid') | 3.030 | 2.344 | 0.157 | |

Lazy evaluation benchmark

4 GB of memory

| Benchmark | Time (ms) | Intel(R) Xeon(R) W-2150B CPU @ 3.00GHz | NVIDIA GeForce RTX 3090 | Error |
| --- | --- | --- | --- | --- |
| unary 1d(1, 'neg') | 0.030 | 0.001 | 0.003 | |
| unary 1d(1, 'sigmoid') | 0.029 | 0.001 | 0.003 | |
| unary 1d(729, 'neg') | 0.029 | 0.001 | 0.003 | |
| unary 1d(729, 'sigmoid') | 0.029 | 0.002 | 0.003 | |
| unary 1d(2187, 'neg') | 0.029 | 0.002 | 0.003 | |
| unary 1d(2187, 'sigmoid') | 0.029 | 0.004 | 0.003 | |
| unary 1d(59049, 'neg') | 0.031 | 0.013 | 0.008 | |
| unary 1d(59049, 'sigmoid') | 0.031 | 0.058 | 0.008 | |
| unary 1d(177147, 'neg') | 0.036 | 0.073 | 0.017 | |
| unary 1d(177147, 'sigmoid') | 0.038 | 0.177 | 0.017 | |
| unary 1d(531441, 'neg') | 0.048 | 0.409 | 0.051 | |
| unary 1d(531441, 'sigmoid') | 0.050 | 0.732 | 0.050 | |
| unary 1d(1594323, 'neg') | 0.089 | 1.488 | 0.166 | |
| unary 1d(1594323, 'sigmoid') | 0.089 | 2.344 | 0.157 | |