facebookresearch / fairseq2

FAIR Sequence Modeling Toolkit 2
https://facebookresearch.github.io/fairseq2/
MIT License

Unexpected behavior when using repeat() with sample()/round_robin() #604

Closed. syleshfb closed this issue 3 months ago.

syleshfb commented 3 months ago

Describe the bug: The data pipeline output is cut short when sample() or round_robin() is combined with repeat().

Describe how to reproduce:

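from fairseq2.data import DataPipeline, read_sequence

# Two infinite pipelines: repeat() without an argument repeats each sequence indefinitely.
pipelines = [
    read_sequence([1, 2, 3, 4]).repeat().and_return(),
    read_sequence([5, 6, 7, 8]).repeat().and_return(),
]

# Randomly interleave elements from the two infinite pipelines.
pipeline = DataPipeline.sample(pipelines).and_return()

for example in pipeline:
    print(example)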

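Describe the expected behavior: The pipeline should keep yielding elements from both sources indefinitely, e.g. 1, 5, 2, ... Actual output: only 1 is printed, after which iteration stops.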
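To make the expected behavior concrete, here is a minimal pure-Python model of round-robin interleaving over two infinite sources. It assumes `itertools.cycle` is a fair stand-in for `read_sequence(...).repeat()`; `round_robin_model` is a hypothetical helper for illustration, not fairseq2's implementation. Because both inputs are infinite, the model never ends on its own, so the caller takes a bounded prefix with `itertools.islice`.

```python
from itertools import cycle, islice
from typing import Iterable, Iterator, Sequence, TypeVar

T = TypeVar("T")


def round_robin_model(sources: Sequence[Iterable[T]]) -> Iterator[T]:
    """Yield one element from each source in turn (hypothetical helper)."""
    iterators = [iter(s) for s in sources]
    while True:
        for it in iterators:
            try:
                item = next(it)
            except StopIteration:
                # Simplifying assumption: stop as soon as any source runs out.
                # With infinite sources this branch is never reached.
                return
            yield item


# cycle() plays the role of read_sequence(...).repeat() here.
sources = [cycle([1, 2, 3, 4]), cycle([5, 6, 7, 8])]
first_ten = list(islice(round_robin_model(sources), 10))
print(first_ten)  # [1, 5, 2, 6, 3, 7, 4, 8, 1, 5]
```

With infinite inputs, the combined stream is infinite as well, which is the behavior the reproduction above expects from both sample() and round_robin().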

Additional Context:

The cause is the way sample() and round_robin() handle infinite data sources:

https://github.com/facebookresearch/fairseq2/blob/7ece73ca39ded191d3774d33b318383e07d753e1/native/src/fairseq2n/data/sample_data_source.cc#L175-L176

https://github.com/facebookresearch/fairseq2/blob/7ece73ca39ded191d3774d33b318383e07d753e1/native/src/fairseq2n/data/round_robin_data_source.cc#L137-L138
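For sample(), a similar sketch picks a source at random for every element, weighted by an assumed `weights` list. It is pure Python and hypothetical (`sample_model`, `weights`, and `seed` are illustrative names, not the linked C++ code). It shows the termination rule the issue expects: when every source is infinite, sampling from them never runs out, so the combined stream should not end after the first element.

```python
import random
from itertools import cycle, islice
from typing import Iterable, Iterator, Optional, Sequence, TypeVar

T = TypeVar("T")


def sample_model(
    sources: Sequence[Iterable[T]],
    weights: Optional[Sequence[float]] = None,
    seed: Optional[int] = None,
) -> Iterator[T]:
    """Yield elements drawn from randomly chosen sources (hypothetical helper)."""
    iterators = [iter(s) for s in sources]
    if weights is None:
        weights = [1.0] * len(iterators)  # uniform weights by default
    rng = random.Random(seed)  # seeded RNG for reproducible draws
    while True:
        # Choose one source index according to the weights.
        (idx,) = rng.choices(range(len(iterators)), weights=weights, k=1)
        try:
            item = next(iterators[idx])
        except StopIteration:
            # Simplifying assumption: stop when the chosen source runs out.
            # With infinite sources this branch is never reached.
            return
        yield item


sources = [cycle([1, 2, 3, 4]), cycle([5, 6, 7, 8])]
first_twenty = list(islice(sample_model(sources, seed=0), 20))
print(first_twenty)  # 20 elements; each source's elements appear in their original order
```

The exact interleaving depends on the seed, but the length of the prefix is limited only by `islice`, not by the sources.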