mila-iqia / fuel

A data pipeline framework for machine learning
MIT License
867 stars 268 forks

Add `copy` flag that enables a deep copy of the entire stream pipeline. #362

Open petrbel opened 8 years ago

petrbel commented 8 years ago

The flag introduced in this commit can be used to split a stream into multiple branches and use each branch independently. `AbstractDataStream` now accepts an optional `copy` flag, which is `False` by default. When it is set to `True`, the `DataIterator` created in `AbstractDataStream.get_epoch_iterator` always receives its own deep copy of the data stream.
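The effect of such a flag can be illustrated without Fuel. The following is only a minimal sketch under stated assumptions: `ToyStream`, `ToyIterator` and `copy_on_iterate` are hypothetical names made up for illustration and are not Fuel's classes or the code in this commit. It shows how handing each iterator a deep copy keeps two consumers from advancing the same internal position.

```python
import copy


class ToyIterator:
    """Hypothetical iterator that pulls items from a stream-like object."""

    def __init__(self, stream):
        self.stream = stream

    def __iter__(self):
        return self

    def __next__(self):
        return self.stream.get_data()


class ToyStream:
    """Hypothetical stand-in for a stream with internal position state."""

    def __init__(self, data, copy_on_iterate=False):
        self.data = data
        self.position = 0
        self.copy_on_iterate = copy_on_iterate

    def get_data(self):
        if self.position >= len(self.data):
            raise StopIteration
        item = self.data[self.position]
        self.position += 1
        return item

    def get_epoch_iterator(self):
        # With the flag set, every iterator works on its own deep copy,
        # so consumers no longer share the `position` counter.
        source = copy.deepcopy(self) if self.copy_on_iterate else self
        return ToyIterator(source)


# Two iterators over one shared stream steal items from each other.
shared = ToyStream([0, 1, 2, 3])
shared_pairs = list(zip(shared.get_epoch_iterator(),
                        shared.get_epoch_iterator()))

# Two iterators over independent copies stay aligned.
copied = ToyStream([0, 1, 2, 3], copy_on_iterate=True)
copied_pairs = list(zip(copied.get_epoch_iterator(),
                        copied.get_epoch_iterator()))

print(shared_pairs)
print(copied_pairs)
```

In this toy version the deep copy duplicates the data as well as the position, which is the cost of copying the entire pipeline rather than only its iteration state.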

In addition, `Merge` now accepts `**kwargs` and forwards them to `AbstractDataStream` for compatibility.

The following code demonstrates the unexpected behavior, which can be avoided by passing `copy=True`. It splits a stream with two sources into two independent streams and then merges them back into a single stream. The expected behavior is that the resulting stream yields the same data as the original stream. However, this is not the case.

import pprint
import numpy as np

from fuel.streams import DataStream
from fuel.datasets import IndexableDataset
from fuel.transformers import FilterSources, Merge
from fuel.schemes import SequentialScheme

def print_stream(data_stream, msg):
    print(msg)
    for d in data_stream.get_epoch_iterator(as_dict=True):
        pprint.pprint(d, width=1)
        print('')
    print('----------------------')

if __name__ == '__main__':
    num_examples = 10

    x = np.tile(np.arange(num_examples), (3, 1)).T
    y = np.tile(np.arange(num_examples) + num_examples, (2, 1)).T

    dataset = IndexableDataset({'x': x, 'y': y})
    stream = DataStream(dataset=dataset, iteration_scheme=SequentialScheme(batch_size=3, examples=dataset.num_examples))
    print_stream(stream, 'After stream creation:')

    stream = FilterSources(stream, ('x', 'y'))
    print_stream(stream, 'After innocent filter:')

    x_only = FilterSources(stream, ('x',))
    y_only = FilterSources(stream, ('y',))
    stream = Merge([x_only, y_only], x_only.sources + y_only.sources)
    print_stream(stream, 'After merge:')

dmitriy-serdyuk commented 8 years ago

It looks like a bug. So the reason is that `get_data` is called twice, right?

petrbel commented 8 years ago

@dmitriy-serdyuk Exactly, in this case twice. Originally I had several consecutive split-merges, which caused `get_data` to be called even more than twice.
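The mechanism discussed here can be reproduced in plain Python. This is a rough sketch, not Fuel code: the generator `make_batches` and the two branch expressions are hypothetical stand-ins for the upstream stream and the two `FilterSources` branches. Because both branches pull from the same upstream iterator, every merged item triggers two upstream reads, and the `x` and `y` values come from different batches.

```python
def make_batches():
    """Hypothetical upstream stream: yields one dict per batch."""
    for i in range(4):
        yield {'x': i, 'y': i + 10}


# A single upstream iterator shared by both branches.
upstream = make_batches()

# Each branch keeps one source but advances the shared iterator itself.
x_branch = (batch['x'] for batch in upstream)
y_branch = (batch['y'] for batch in upstream)

# Merging reads upstream twice per item: once for x, once for y.
merged = list(zip(x_branch, y_branch))
print(merged)
```

Here `x` from batch 0 ends up paired with `y` from batch 1, and only half as many merged items come out as there were batches. With more chained split-merges, each extra branch adds another upstream read per item.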

dmitriy-serdyuk commented 8 years ago

Is it possible just to copy the request iterator? It should be sufficient.

And can you add a test?
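The request-iterator suggestion above could look roughly like the following. This sketch makes several assumptions: `request_iterator` and `fetch` are hypothetical helpers, not Fuel functions, and `itertools.tee` is used here as one way to give each branch its own view of the requests. Whether Fuel's request iterators can be duplicated this way is not confirmed in this thread.

```python
import itertools

# Toy data standing in for the dataset's two sources.
data = {'x': list(range(10)), 'y': list(range(10, 20))}


def request_iterator(num_examples, batch_size):
    """Hypothetical sequential scheme: yields lists of example indices."""
    for start in range(0, num_examples, batch_size):
        yield list(range(start, min(start + batch_size, num_examples)))


def fetch(source, request):
    """Hypothetical lookup of one source for the requested indices."""
    return [data[source][i] for i in request]


# Each branch gets its own iterator over the same requests,
# while the underlying data is shared instead of deep-copied.
requests_x, requests_y = itertools.tee(request_iterator(10, 3))

merged = [(fetch('x', rx), fetch('y', ry))
          for rx, ry in zip(requests_x, requests_y)]
print(merged[0])
print(len(merged))
```

Both branches now see the same request for every batch, so the merged `x` and `y` values line up again without duplicating the data. A regression test could assert exactly that: the merged output equals what the unsplit stream produces.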