I don't think the batched function is doing what we want it to do. What is its intended output?
Test case:
import itertools
from typing import Iterator

def batched(it: Iterator, n: int):
    assert n >= 1
    for x in it:
        yield itertools.chain((x,), itertools.islice(it, n - 1))

lis = [i for i in range(10)]
n = 5
for b in batched(lis, n):
    print(b)
    for item in b:
        print(item)
Output:
<itertools.chain object at 0x793e86d9d870>
0
0
1
2
3
<itertools.chain object at 0x793e86d9e050>
1
0
1
2
3
<itertools.chain object at 0x793e86d9d540>
2
0
1
2
3
<itertools.chain object at 0x793e86d9cc10>
3
0
1
2
3
<itertools.chain object at 0x793e86d9e740>
4
0
1
2
3
<itertools.chain object at 0x793e86d9d2d0>
5
0
1
2
3
<itertools.chain object at 0x793e86d9f7c0>
6
0
1
2
3
<itertools.chain object at 0x793ebd8bf520>
7
0
1
2
3
<itertools.chain object at 0x793e86c607c0>
8
0
1
2
3
<itertools.chain object at 0x793e86c603d0>
9
0
1
2
3
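One possible explanation, assuming the intent is to split the input into consecutive chunks of n items: the parameter is annotated `it: Iterator`, so the function seems to expect a single-pass iterator, but the test case passes a list. With a list, `for x in it` and `itertools.islice(it, n - 1)` each create their own fresh iterator, so every islice starts again at index 0. The sketch below reruns the same test with the list wrapped in `iter()` (the `from_list`/`from_iter` names are just for illustration):

```python
import itertools
from typing import Iterator


def batched(it: Iterator, n: int):
    # Same body as the function in the test case: it relies on `it` being a
    # single-pass iterator so islice() continues where the for-loop left off.
    assert n >= 1
    for x in it:
        yield itertools.chain((x,), itertools.islice(it, n - 1))


lis = list(range(10))

# Passing the list directly: the for-loop and islice() each call iter(lis)
# and start from index 0, so every "batch" is x followed by 0..n-2.
from_list = [list(b) for b in batched(lis, 5)]

# Wrapping the list in iter() gives one shared iterator, so islice() consumes
# the next n - 1 items and the outer loop resumes right after them.
from_iter = [list(b) for b in batched(iter(lis), 5)]

print(from_list[:2])  # [[0, 0, 1, 2, 3], [1, 0, 1, 2, 3]]
print(from_iter)      # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
```

If that is the intended contract, the function works as long as callers pass an iterator rather than a list or other re-iterable container.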
https://github.com/facebookresearch/dlrm/blob/b631a99cf5ff320272d18a776cda85b5207bdf19/torchrec_dlrm/dlrm_main.py#L368
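Even with a real iterator, the batches are lazy: each chain is only correct if it is fully consumed before the next batch is requested. I haven't checked whether the caller in dlrm_main.py always does that, so this is only a possible second pitfall. The sketch below shows what happens when the chains are collected first, and contrasts it with a hypothetical eager variant (`batched_eager` is an illustrative name, not something from the repository) that materializes each batch as a tuple. On Python 3.12+, the standard library's `itertools.batched(iterable, n)` also yields tuples in the same way.

```python
import itertools
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def batched(it: Iterator, n: int):
    # The lazy version from the test case (unchanged).
    assert n >= 1
    for x in it:
        yield itertools.chain((x,), itertools.islice(it, n - 1))


def batched_eager(iterable: Iterable[T], n: int) -> Iterator[Tuple[T, ...]]:
    # Hypothetical eager variant: each batch is materialized as a tuple before
    # it is yielded, so batches can be stored or consumed in any order.
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)  # one shared iterator, even when given a list
    while batch := tuple(itertools.islice(it, n)):
        yield batch


# Collecting the lazy chains before reading them: the outer for-loop drains
# the shared iterator one item at a time, so each chain holds only its x.
stored = [list(b) for b in list(batched(iter(range(10)), 5))]

eager = list(batched_eager(range(10), 5))
eager_uneven = list(batched_eager([1, 2, 3, 4, 5, 6, 7], 3))

print(stored)        # [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]
print(eager)         # [(0, 1, 2, 3, 4), (5, 6, 7, 8, 9)]
print(eager_uneven)  # [(1, 2, 3), (4, 5, 6), (7,)]
```

The trade-off is memory: the eager variant holds one full batch at a time, while the lazy version avoids that but depends on the caller's consumption order.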