Multithreaded apply should only be used internally, as a non-negligible set of operations can lead to deadlocks.
As of now, it is not meant to speed up tensordict operations themselves (i.e., reduce TD overhead) but to speed up the execution of any function passed to `apply` that can benefit from being executed asynchronously over the leaves.
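For illustration, here is a minimal usage sketch relying on the private `_multithread_apply_nest` entry point exercised in the benchmark below; the per-leaf function should release the GIL (e.g. `pin_memory`, I/O, device copies) for the thread pool to pay off:

```python
import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.rand(2048, 2048), "b": {"c": torch.rand(2048, 2048)}},
    batch_size=[],
)
# Run the per-leaf function over a thread pool. This only helps when the
# function releases the GIL (pin_memory does; most pure-Python funcs don't).
pinned = td._multithread_apply_nest(
    lambda x: x.pin_memory(), checked=True, num_threads=4
)
```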
Some benchmarks:
```python
import torch
from tensordict import TensorDict
from torch.utils.benchmark import Timer

# torch.set_num_threads(16)

# Build a TensorDict with 20 levels of nesting, each level holding
# 50 tensors of shape (2048, 2048).
d = {}
sub = d
for _ in range(20):
    for i in range(50):
        sub[str(i)] = torch.rand(2048, 2048)
    sub["nested"] = {}
    sub = sub["nested"]
td = TensorDict(d, batch_size=[])

# Sanity check: the multithreaded apply leaves the content unchanged.
assert (td._multithread_apply_nest(lambda x: x, num_threads=16) == td).all()

# pin_memory releases the GIL, so it benefits from threaded execution.
func = lambda x: x.pin_memory()
print(Timer("td._fast_apply(func)", globals=globals()).adaptive_autorange())
print(Timer("td._multithread_apply_nest(func, checked=True, num_threads=4)", globals=globals()).adaptive_autorange())
print(Timer("td._multithread_apply_nest(func, checked=True, num_threads=8)", globals=globals()).adaptive_autorange())
print(Timer("td._multithread_apply_nest(func, checked=True, num_threads=16)", globals=globals()).adaptive_autorange())
```
Results:
```
td._fast_apply(func)
  Median: 1.74 s
  IQR:    0.02 s (1.73 to 1.75)
  4 measurements, 1 runs per measurement, 1 thread

td._multithread_apply_nest(func, checked=True, num_threads=4)
  Median: 498.21 ms
  IQR:    80.43 ms (455.67 to 536.10)
  20 measurements, 1 runs per measurement, 1 thread
  WARNING: Interquartile range is 16.1% of the median measurement.
           This could indicate system fluctuation.

td._multithread_apply_nest(func, checked=True, num_threads=8)
  Median: 303.86 ms
  IQR:    24.10 ms (295.45 to 319.54)
  4 measurements, 1 runs per measurement, 1 thread

td._multithread_apply_nest(func, checked=True, num_threads=16)
  Median: 220.13 ms
  IQR:    13.43 ms (214.42 to 227.85)
  4 measurements, 1 runs per measurement, 1 thread
```

With a subsequent transfer to CUDA chained after the apply:

```
td._fast_apply(func).to('cuda')
  Median: 1.73 s
  IQR:    0.02 s (1.72 to 1.73)
  4 measurements, 1 runs per measurement, 1 thread

td._multithread_apply_nest(func, checked=True, num_threads=4).to('cuda')
  Median: 864.55 ms
  IQR:    80.37 ms (817.05 to 897.42)
  9 measurements, 1 runs per measurement, 1 thread

td._multithread_apply_nest(func, checked=True, num_threads=8).to('cuda')
  Median: 680.82 ms
  IQR:    19.13 ms (679.74 to 698.87)
  5 measurements, 1 runs per measurement, 1 thread

td._multithread_apply_nest(func, checked=True, num_threads=16).to('cuda')
  Median: 608.28 ms
  IQR:    12.52 ms (603.43 to 615.95)
  4 measurements, 1 runs per measurement, 1 thread
```
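The `.to('cuda')` rows chain a device transfer after the pinned copy. As an aside, pinning is what makes an asynchronous host-to-device copy possible; a minimal sketch (not part of the benchmark above, and assuming `to` accepts `non_blocking` as in the regular TensorDict API):

```python
# Once the leaves are pinned, the host-to-device copy can be issued
# asynchronously with non_blocking=True.
pinned = td._multithread_apply_nest(
    lambda x: x.pin_memory(), checked=True, num_threads=8
)
td_cuda = pinned.to("cuda", non_blocking=True)
torch.cuda.synchronize()  # wait for the async copies before using td_cuda
```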