numba / numba

NumPy aware dynamic Python compiler using LLVM
https://numba.pydata.org/
BSD 2-Clause "Simplified" License

iter(Dict()) shows odd behavior + iter(List()) breaks #7427

Open DannyWeitekamp opened 3 years ago

DannyWeitekamp commented 3 years ago

Using Numba 0.54.0, the following minimal reproduction shows unusual behavior when calling iter() on an instance of Dict(). It seems like iter() instances are replacing each other on instantiation.

from numba import njit, i8
from numba.typed import Dict, List

@njit(cache=True)
def make_dicts(n=3,m=3):
    l = List()
    for j in range(1,n+1):
        d = Dict.empty(i8,i8)
        for i in range(1,m+1):
            d[i*j] = i+10
        l.append(d)
    return l

@njit(cache=True)
def test_iters(iterables):
    iters = [iter(iterable) for iterable in iterables]
    out = [next(it) for it in iters]
    print(":", out)

print(make_dicts()) # [{1: 11, 2: 12, 3: 13}, {2: 11, 4: 12, 6: 13}, {3: 11, 6: 12, 9: 13}]
test_iters.py_func(make_dicts()) # Python : [1,2,3]
test_iters(make_dicts()) # Numba : [3,6,9] <- wrong, seems to reuse the same iterator

And for good measure, this also fails:

@njit(cache=True)
def make_lists(n=3,m=3):
    l = List()
    for j in range(1,n+1):
        l2 = List.empty_list(i8)
        for i in range(1,m+1):
            l2.append(i*j)
        l.append(l2)
    return l

print(make_lists()) # [[1, 2, 3], [2, 4, 6], [3, 6, 9]]

test_iters.py_func(make_lists()) # Python : [1,2,3]
test_iters(make_lists()) # Numba : Error

esc commented 3 years ago

@DannyWeitekamp thank you for reporting this. I can reproduce. It would probably make sense to invest the time to reduce the size of the reproducer. It's hard to work out what is going on exactly.

esc commented 3 years ago

I have modified the reproducer as follows:

from numba import njit
from numba.typed import List, Dict

# this is the target
its = [{1: 11, 2: 12, 3: 13},
       {4: 14, 5: 15, 6: 16},
       {7: 17, 8: 18, 9: 19}]

# convert to typed containers
typed_its = List()
for d in its:
    nd = Dict()
    for k,v in d.items():
        nd[k] = v
    typed_its.append(nd)

@njit(cache=False)
def test_iters(iterables):
    iters = [iter(iterable) for iterable in iterables]
    out = [next(it) for it in iters]
    print(":", out)

test_iters.py_func(its) # Python : [1,4,7]
test_iters(typed_its) # Numba : [7,8,9] <- wrong

This makes it a bit clearer what may be going on. It seems that next(it) is iterating through the last dictionary in the list. The Python code shows what it should do. I initially thought it might be an issue with how the dict keys are ordered, but the code doesn't seem to iterate over the first dictionaries at all.
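The observed [7,8,9] is exactly what you would get if every list entry referred to one single iterator over the last dictionary. A minimal plain-Python sketch of that assumption (no Numba involved; the aliasing is simulated by hand and is not taken from Numba's implementation):

```python
# Data from the reproducer above, as plain Python dicts.
its = [{1: 11, 2: 12, 3: 13},
       {4: 14, 5: 15, 6: 16},
       {7: 17, 8: 18, 9: 19}]

# Expected behavior: one independent iterator per dict.
independent = [iter(d) for d in its]
expected_out = [next(it) for it in independent]
print(expected_out)  # [1, 4, 7]

# Hypothetical aliasing: every list entry refers to the iterator created
# for the last dict, so each next() call advances the same shared state.
last = iter(its[-1])
aliased = [last for _ in its]
aliased_out = [next(it) for it in aliased]
print(aliased_out)  # [7, 8, 9]
```

This only shows that the symptom is consistent with shared iterator state; it does not establish the root cause.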

esc commented 3 years ago

I am labelling this as a potential bug since I can reproduce it, but the root cause must still be triaged.

stuartarchibald commented 3 years ago

I think this could be a smaller reproducer:

from numba import njit

@njit
def test_iters():
    iterables = [{1: 11}, {7: 17}]
    iters = [iter(iterable) for iterable in iterables]
    it = iters[0]
    return next(it)

print(test_iters.py_func()) #1
print(test_iters()) #7

Given the contents, I'd guess it's more likely a problem with the iterators than with the containers.
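Because this reproducer returns a value instead of printing it, it can be turned into a regression check that compares the jitted result against the pure-Python original. A sketch, assuming `dispatcher` is a Numba dispatcher created with `@njit` (which exposes the wrapped function as `.py_func`); the helper name `check_against_py_func` is hypothetical:

```python
def check_against_py_func(dispatcher, *args):
    """Compare a jitted function's result with its pure-Python original.

    Assumes `dispatcher` is a Numba dispatcher (e.g. returned by @njit),
    which exposes the undecorated function as `.py_func`.
    """
    expected = dispatcher.py_func(*args)  # run the original Python code
    actual = dispatcher(*args)            # run the compiled version
    if actual != expected:
        raise AssertionError(
            f"{dispatcher.py_func.__name__}: jitted returned {actual!r}, "
            f"Python returned {expected!r}"
        )
    return actual
```

With the reproducer above, `check_against_py_func(test_iters)` would raise on the behavior reported here (jitted 7 vs Python 1) and pass once the two agree.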

DannyWeitekamp commented 3 years ago

This behaves the same in a Python function and an njitted function:

from numba import njit

@njit(cache=True)
def test_iters_no_list():
    it0 = iter({1: 11, 7: 17})
    it1 = iter({2: 12, 5: 15})
    print(next(it0), next(it1))
test_iters_no_list() # 1 2
test_iters_no_list.py_func() # 1 2

My guess is that either 1) it has something to do with the loop instantiating all of the iterators in the same memory space because they share an intermediate variable, or 2) adding the iterators to the list is causing some issue.

It seems to work fine when adding things one at a time:

@njit(cache=True)
def test_iters_unrolled_add_to_list():
    l = []
    it0 = iter({1: 11, 7: 17})
    it1 = iter({2: 12, 5: 15})
    l.append(it0)
    l.append(it1)

    print(next(l[0]), next(l[1]))

test_iters_unrolled_add_to_list() # 1 2
test_iters_unrolled_add_to_list.py_func() # 1 2

But it breaks when the iterators are instantiated inside a loop:

@njit(cache=True)
def test_iters_add_to_list():
    l = []
    for i in range(2):
        l.append(iter({1+i: 11+i, 7+i: 17+i}))

    print(next(l[0]), next(l[1]))

test_iters_add_to_list() # 2 8
test_iters_add_to_list.py_func() # 1 2

Poking through the code, one thing that struck me as odd is that the model for DictIteratorType doesn't have a meminfo pointing to an underlying payload. I'm wondering whether this means that when a new iterator is instantiated, the variable that temporarily holds it stores it as a value instead of a reference, so all of the previous iterators get overwritten by the new ones.

Another detail: this code doesn't seem to work if [] is replaced with List().
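The value-versus-reference hypothesis above can be illustrated with a pure-Python analogy. This is only a model under that assumption, not Numba's actual data model; `IterState`, `build_with_reused_slot` and `build_with_fresh_state` are hypothetical names. Reusing one storage slot for every loop iteration reproduces the reported `2 8`, while allocating fresh state per iteration gives the expected `1 2`.

```python
class IterState:
    """Hypothetical stand-in for an iterator's state: a key snapshot plus a cursor."""

    def __init__(self):
        self.keys = []
        self.pos = 0

    def reset(self, d):
        # Re-point this same object at a new dict (like overwriting a reused slot).
        self.keys = list(d)
        self.pos = 0
        return self

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos >= len(self.keys):
            raise StopIteration
        key = self.keys[self.pos]
        self.pos += 1
        return key


def build_with_reused_slot(dicts):
    # One slot is reused on every loop iteration; the list ends up holding
    # that same slot each time, so later iterations overwrite earlier ones.
    slot = IterState()
    return [slot.reset(d) for d in dicts]


def build_with_fresh_state(dicts):
    # A new state object per iteration; each list entry owns its own state.
    return [IterState().reset(d) for d in dicts]


# Same dicts as test_iters_add_to_list: {1: 11, 7: 17} and {2: 12, 8: 18}.
dicts = [{1 + i: 11 + i, 7 + i: 17 + i} for i in range(2)]

shared = build_with_reused_slot(dicts)
shared_out = (next(shared[0]), next(shared[1]))
print(*shared_out)  # 2 8  (same as the jitted output reported above)

owned = build_with_fresh_state(dicts)
owned_out = (next(owned[0]), next(owned[1]))
print(*owned_out)  # 1 2  (same as py_func)
```

The analogy reproduces the symptom, but it cannot confirm that missing meminfo is the actual cause; that still needs to be checked against Numba's lowering of the loop.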

github-actions[bot] commented 3 years ago

This issue is marked as stale as it has had no activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with any updates and confirm that this issue still needs to be addressed.

gmarkall commented 2 years ago

Removing stale label, removing needtriage so it doesn't go stale again.