google-research / albert

ALBERT: A Lite BERT for Self-supervised Learning of Language Representations
Apache License 2.0

Uneven sampling if number of tfrecords > cycle_length #177

Open illuminascent opened 4 years ago

illuminascent commented 4 years ago

The input function defined in run_pretraining.py builds a repeating, randomly shuffled sequence of tfrecord filenames and opens only a limited number of file handles (cycle_length) at a time, following that sequence.

This is a problem when the number of files is greater than the maximum number of open handles: a new set of files is not opened until a certain number of samples has been read from the current ones. The result is severe selection bias, which causes the loss to suddenly jump up or down during pretraining.

This behavior can be confirmed with a toy dataset. 10000 integers generated with np.arange() were evenly divided into 10 parts, and each part was shuffled internally before being written to its own tfrecord file. Here's the value of each new sample drawn from the dataset at each training step (step on the x axis, value on the y axis):

[Figure: parallel_interleave_original]
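The mechanism can also be reproduced without TensorFlow. The sketch below is a simplified pure-Python model, not the tf.data implementation, and it rests on several assumptions: deterministic round-robin reading with block_length=1 (sloppy=True may reorder elements but still reads only cycle_length files at a time), repeat() + shuffle() on the filenames approximated by a fresh permutation per pass, the trailing shuffle(buffer_size=100) omitted, and a hypothetical cycle_length of 4. All helper names are illustrative. Since file i holds the values [i * 1000, (i + 1) * 1000), as in the toy dataset, value // 1000 identifies the file a sample came from.

```python
import itertools
import random


def shard(index, size=1000, seed=0):
    """Toy stand-in for one tfrecord: integers [index*size, (index+1)*size), shuffled."""
    values = list(range(index * size, (index + 1) * size))
    random.Random(seed + index).shuffle(values)  # fixed per file, like data on disk
    return values


def filename_stream(num_files, rng):
    """Endless file indices: a fresh permutation per pass (approximates repeat + shuffle)."""
    while True:
        order = list(range(num_files))
        rng.shuffle(order)
        yield from order


def limited_interleave(num_files, cycle_length, rng):
    """Round-robin over only `cycle_length` open files (block_length=1).

    When an open file is exhausted, its slot is refilled with the next name
    from the stream, as interleave does with a finished input element.
    """
    names = filename_stream(num_files, rng)
    slots = [iter(shard(next(names))) for _ in range(cycle_length)]
    while True:
        for i in range(cycle_length):
            value = next(slots[i], None)
            if value is None:  # file exhausted: open the next file in this slot
                slots[i] = iter(shard(next(names)))
                value = next(slots[i])
            yield value


def files_per_window(stream, window=100, windows=50, size=1000):
    """Number of distinct files contributing to each run of `window` consecutive samples."""
    return [len({v // size for v in itertools.islice(stream, window)})
            for _ in range(windows)]


counts = files_per_window(limited_interleave(num_files=10, cycle_length=4,
                                             rng=random.Random(42)))
print(max(counts))  # 4: no window ever sees more than cycle_length of the 10 files
```

In this model, every window of 100 consecutive samples is drawn from only 4 of the 10 value ranges until those files run out.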

I would suggest always setting cycle_length to len(file_lists) and moving d = d.repeat() after the parallel interleave mapping. This solved the problem in my case, at least.
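Under the same simplifying assumptions (a pure-Python model with round-robin reading and toy files of 1000 shuffled integers; all helper names hypothetical), the suggested ordering can be sketched as: shuffle the filenames, read from all of them at once, and only then repeat.

```python
import itertools
import random


def shard(index, size=1000, seed=0):
    """Toy stand-in for one tfrecord: integers [index*size, (index+1)*size), shuffled."""
    values = list(range(index * size, (index + 1) * size))
    random.Random(seed + index).shuffle(values)
    return values


def full_interleave_then_repeat(num_files, rng):
    """Shuffle names, read ALL files round-robin (cycle_length == num_files), then repeat."""
    missing = object()  # padding marker in case files have unequal lengths
    while True:  # this outer loop plays the role of repeat() placed after the interleave
        order = list(range(num_files))
        rng.shuffle(order)  # like filenames.shuffle(buffer_size=filecount)
        rows = itertools.zip_longest(*(shard(i) for i in order), fillvalue=missing)
        for row in rows:
            yield from (v for v in row if v is not missing)


stream = full_interleave_then_repeat(num_files=10, rng=random.Random(42))
counts = [len({v // 1000 for v in itertools.islice(stream, 100)}) for _ in range(50)]
print(min(counts))  # 10: every window of 100 samples draws from all 10 files
```

Because repeat() sits outside the interleave, each pass starts with a new file order but always keeps every file open together.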

Here's how it looks afterwards:

[Figure: parallel_interleave_cycle_extended_repetition_loop_changed]

Code to reproduce the above results:

import os
import tensorflow as tf
import numpy as np
import collections
import glob
import matplotlib.pyplot as plt
import seaborn as sns
from create_pretraining_data import create_int_feature

# TF 1.x API (tf.contrib, tf.python_io); eager mode lets us iterate over the dataset directly.
tf.enable_eager_execution()

def prep(count=5):
    # Remove tfrecords left over from a previous run.
    for f in glob.glob("test*.tfrecord"):
        os.remove(f)
    for i in range(count):
        # File i holds the integers [i * 1000, (i + 1) * 1000), shuffled within the file.
        data = np.arange(i * 1000, (i + 1) * 1000)
        np.random.shuffle(data)
        # The with-block closes (and flushes) the writer once the file is complete.
        with tf.python_io.TFRecordWriter("test%s.tfrecord" % i) as writer:
            for x in data:
                features = collections.OrderedDict()
                features["x"] = create_int_feature([int(x)])
                tf_example = tf.train.Example(features=tf.train.Features(feature=features))
                writer.write(tf_example.SerializeToString())

def sample():
    filenames = glob.glob("test*.tfrecord")
    filecount = len(filenames)
    filenames = tf.data.Dataset.from_tensor_slices(tf.constant(filenames))
    filenames = filenames.shuffle(buffer_size=filecount)
    # Suggested fix, part 1: read every file at once (cycle_length == number of files).
    d = filenames.apply(
        tf.contrib.data.parallel_interleave(
            tf.data.TFRecordDataset,
            sloppy=True,
            cycle_length=filecount))
    # Suggested fix, part 2: repeat() after the interleave, so every pass reshuffles
    # the filenames and reopens all files together.
    d = d.repeat()
    # A 100-element buffer only mixes nearby samples, so real randomness requires
    # all files to be read simultaneously.
    d = d.shuffle(buffer_size=100)

    samples = list()

    for example in d:
        # Decode the serialized tf.train.Example and read its single int64 feature "x".
        parsed = tf.train.Example.FromString(example.numpy())
        samples.append(parsed.features.feature["x"].int64_list.value[0])
        if len(samples) > 40000:
            break

    f = plt.figure()
    f.canvas.set_window_title('')
    # Step on the x axis, sample value on the y axis.
    g = sns.scatterplot(x=list(range(len(samples))), y=samples, color="#6a20e3", s=5)
    g.set_title("samples")
    plt.grid(which="both", ls="-")
    plt.show()

prep(10)
sample()
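The comment on d.shuffle(buffer_size=100) follows from how a shuffle buffer works: it is filled with buffer_size elements, and each output is a random element of the buffer whose slot is then refilled with the next input. The pure-Python model below (buffer_shuffle is a hypothetical helper, not the tf.data implementation) checks the consequence: output t can never come from further ahead than input t + buffer_size - 1, so a 100-element buffer cannot mix in records from files the interleave has not opened yet.

```python
import random


def buffer_shuffle(stream, buffer_size, rng):
    """Model of a shuffle buffer: fill it, then emit a random slot and refill that slot."""
    it = iter(stream)
    buffer = []
    for value in it:  # initial fill
        buffer.append(value)
        if len(buffer) == buffer_size:
            break
    for value in it:  # steady state: one random output per new input
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        buffer[idx] = value
    rng.shuffle(buffer)  # input exhausted: drain the rest in random order
    yield from buffer


out = list(buffer_shuffle(range(10000), buffer_size=100, rng=random.Random(0)))
# Output t only ever comes from inputs 0 .. t + 99 (look-ahead bounded by the buffer).
print(all(v <= t + 99 for t, v in enumerate(out)))  # True
```

In this model, a larger buffer only widens the look-ahead linearly; reading all files at once is what places every file inside that window.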