viralogic / py-enumerable

A Python module used for interacting with collections of objects using LINQ syntax
MIT License

Unexpected behavior when using iterator as the input data #22

Closed: lutecki closed this issue 4 years ago

lutecki commented 5 years ago

First of all, thank you for providing this great package!

The following code produces output that I didn't expect:

from py_linq import Enumerable as en

def my_iter():
    for i in range(10):
        yield i

data = my_iter()        
a = en(data)

low = a.where(lambda x: x < 5)
high = a.where(lambda x: x >= 5)

for l, h in zip(low, high):
    print("PAIR:", l, h)

This produces:

PAIR: 0 5

I was expecting this:

PAIR: 0 5
PAIR: 1 6
PAIR: 2 7
PAIR: 3 8
PAIR: 4 9

I believe this has something to do with the caching mechanism in __iter__() and the fact that the same iterator is being consumed in parallel by different consumers. I'm using Python 3.6 and py-linq 1.0.1 (installed via pip).
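A minimal sketch of the suspected cause in plain Python, without py_linq. It assumes the problem is that both consumers pull from one shared generator, so every item one filter rejects is gone for the other. `itertools.tee` is shown only as one standard-library way to give each consumer an independent view of the source; it is not what py_linq does.

```python
from itertools import tee


def my_iter():
    for i in range(10):
        yield i


# One generator shared by two lazy filters, as in the report above.
shared = my_iter()
low = filter(lambda x: x < 5, shared)
high = filter(lambda x: x >= 5, shared)
shared_pairs = list(zip(low, high))
print(shared_pairs)  # [(0, 5)]

# tee() buffers the source so each filter sees every item.
low_src, high_src = tee(my_iter(), 2)
low = filter(lambda x: x < 5, low_src)
high = filter(lambda x: x >= 5, high_src)
teed_pairs = list(zip(low, high))
print(teed_pairs)  # [(0, 5), (1, 6), (2, 7), (3, 8), (4, 9)]
```

In the shared case, `high` skips 1 to 4 while searching for its first match, and `low` then skips 6 to 9 and runs out, which is exactly the single `PAIR: 0 5` line.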

viralogic commented 5 years ago

Hey, thanks for the feedback! Glad you like the library.

I will have a look at this issue when I get a spare moment. In the meantime, you may want to have a look at the .zip function and see if that helps. My guess is that you are correct and it is a caching issue, so the zip function may be a workaround for you, and it would also let you write all that code in fewer lines!

https://viralogic.github.io/py-enumerable/zip

query = Enumerable(range(10)).where(lambda x: x < 5).zip(
    Enumerable(range(10)).where(lambda y: y >= 5),
    lambda result: "PAIR {0} {1}".format(result[0], result[1]))

Should do the trick. Note that I haven't actually tested this code; it's just off the top of my head.
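The suggestion sidesteps the problem because each Enumerable wraps its own `range(10)`, and a range object can be iterated any number of times, whereas a generator is exhausted after one pass. A quick plain-Python check of that difference (no py_linq involved; `my_iter` is the generator from the original report):

```python
def my_iter():
    for i in range(10):
        yield i


r = range(10)
g = my_iter()

# A range is re-iterable: each pass starts again from 0.
range_passes = (list(r), list(r))

# A generator is a one-shot iterator: the second pass finds it exhausted.
gen_passes = (list(g), list(g))

print(len(range_passes[1]), len(gen_passes[1]))  # 10 0

# Independent sources, as in the suggested query, pair up as expected.
low = filter(lambda x: x < 5, range(10))
high = filter(lambda x: x >= 5, range(10))
independent_pairs = list(zip(low, high))
print(independent_pairs)  # [(0, 5), (1, 6), (2, 7), (3, 8), (4, 9)]
```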

If you have any issues or questions, don't hesitate to ask.

lutecki commented 5 years ago

Hi,

I can probably use your suggestion for the zip() example above, but the library itself iterates over the data in a nested way, e.g. in join(), which calls itertools.product(). This snippet:

from py_linq import Enumerable as en

class Val(object):

    def __init__(self, number, power):
        self.number = number
        self.power = power

    def __str__(self):
        return "VAL {0}: {1}".format(self.number, self.power)

def powers_of_2():
    for i in range(10):
        yield Val(i, 2 ** i)

def powers_of_10():
    for i in range(10):
        yield Val(i, 10 ** i)

en2 = en(list(powers_of_2()))
en10 = en(list(powers_of_10()))

joined = en2.join(en10, lambda x: x.number, lambda x: x.number)

for j in joined:
    print("Joined", j[0], j[1])

gives me:

Joined VAL 0: 1 VAL 0: 1
Joined VAL 1: 2 VAL 1: 10
Joined VAL 2: 4 VAL 2: 100
Joined VAL 3: 8 VAL 3: 1000
Joined VAL 4: 16 VAL 4: 10000
Joined VAL 5: 32 VAL 5: 100000
Joined VAL 6: 64 VAL 6: 1000000
Joined VAL 7: 128 VAL 7: 10000000
Joined VAL 8: 256 VAL 8: 100000000
Joined VAL 9: 512 VAL 9: 1000000000

as expected, but if I change

en2 = en(list(powers_of_2()))
en10 = en(list(powers_of_10()))

to

en2 = en(powers_of_2())
en10 = en(powers_of_10())

the result is empty. I believe there will always be a problem with the caching as long as there is nested iteration and self._data is an iterator rather than a collection. I don't think there is a good way around it, but you could at least add some kind of lock when __iter__() starts yielding from an iterator and raise an error if a nested call is made. I've tried this; it works for the version above and raises an error when no conversion to list is done:

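    def __iter__(self):
        # Refuse nested iteration when the backing data is a one-shot
        # iterator (no __len__). self.itering must be set to 0 in __init__.
        if self.itering > 0 and not hasattr(self._data, "__len__"):
            raise ValueError("Nested iteration prohibited.")
        self.itering += 1
        cache = []
        try:
            for element in self._data:
                cache.append(element)
                yield element
            # Swap the exhausted iterator for the cached list.
            self._data = cache
        finally:
            # Release the guard even if the consumer stops early.
            self.itering -= 1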

I don't know if this is a good solution, but at least it doesn't fail silently. Regards.
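A self-contained sketch of the proposed guard, runnable without py_linq. `GuardedEnumerable` and `nested_pairs` are hypothetical names for this illustration only: the class models just the caching `__iter__` with the nested-iteration check, not the real `Enumerable`, and it wraps the counter update in `try`/`finally` so an iteration that stops early still releases the guard.

```python
class GuardedEnumerable:
    """Hypothetical wrapper modeling only the cached __iter__ plus guard."""

    def __init__(self, data):
        self._data = data
        self.itering = 0  # number of passes currently in progress

    def __iter__(self):
        # One-shot iterators (no __len__) cannot be re-scanned safely.
        if self.itering > 0 and not hasattr(self._data, "__len__"):
            raise ValueError("Nested iteration prohibited.")
        self.itering += 1
        cache = []
        try:
            for element in self._data:
                cache.append(element)
                yield element
            self._data = cache  # later passes read the cached list
        finally:
            self.itering -= 1


def nested_pairs(source):
    # Hypothetical naive join: rescans `source` once per outer element.
    return [(a, b) for a in source for b in source if a == b]


# List-backed: nested passes are fine.
print(nested_pairs(GuardedEnumerable([0, 1, 2])))  # [(0, 0), (1, 1), (2, 2)]

# Generator-backed: the inner pass raises instead of silently seeing nothing.
try:
    nested_pairs(GuardedEnumerable(n for n in range(3)))
except ValueError as exc:
    print("generator-backed:", exc)
```

In this sketch any overlapping pass counts as nested, so the interleaved zip() from the first example, modeled as two filter() calls over one generator-backed GuardedEnumerable, would also raise rather than drop pairs.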

viralogic commented 5 years ago

@lutecki

This appears to be an issue with the where method specifically in the Python 3 implementation:

Using your submitted code in Python 2.7.14, the output is:

('PAIR:', 0, 5)
('PAIR:', 1, 6)
('PAIR:', 2, 7)
('PAIR:', 3, 8)
('PAIR:', 4, 9)

while in Python 3 I get the same output as you submitted:

PAIR: 0 5

I confirmed this by removing the where function call in Python 3 and then zipping the 2 collections together.

from py_linq import Enumerable
def low_iter():
    for i in range(5):
        yield i

def high_iter():
    for k in range(5):
        yield k + 5

low = Enumerable(low_iter())
high = Enumerable(high_iter())
for l, h in zip(low, high):
    print("PAIR", l, h)

PAIR 0 5
PAIR 1 6
PAIR 2 7
PAIR 3 8
PAIR 4 9

The where function in the Python 3 implementation uses the built-in filter, whereas the Python 2 implementation uses itertools.ifilter. itertools.ifilter was removed in Python 3, and the Python 3 built-in filter is supposed to behave the same way as itertools.ifilter does in Python 2, returning a lazy iterator.
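In Python 3 the built-in `filter` is lazy and single-pass, as `itertools.ifilter` was in Python 2: the predicate only runs as items are requested, and the filter object can be consumed once. A quick check of both properties (`is_low` and `calls` are illustrative names, not part of py_linq):

```python
calls = []


def is_low(x):
    calls.append(x)  # record every predicate call
    return x < 5


f = filter(is_low, range(10))
calls_before_iteration = len(calls)  # 0: creating the filter runs nothing

first_pass = list(f)
second_pass = list(f)
print(calls_before_iteration)  # 0
print(first_pass)   # [0, 1, 2, 3, 4]
print(second_pass)  # []: the filter object is already exhausted
print(len(calls))   # 10
```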

Not sure how to solve this problem at this point, but will try some experiments to see.

viralogic commented 4 years ago

This fix will be part of version 1.2.