HazyResearch / flyingsquid

More interactive weak supervision with FlyingSquid
Apache License 2.0

Prediction Time #6

Closed · dmitra79 closed this issue 4 years ago

dmitra79 commented 4 years ago

Hello,

This is a question rather than an issue. I tried applying FlyingSquid to a dataset with 65M instances and 3 weak labels, using the model structures from the tutorials. The single-node model took ~2.5s to train, and the sequential (3-node) model took ~8s. However, when I tried to get predictions with `preds = label_model.predict(L_train)`, it ran for a long time (~20 minutes) without completing. Does this behavior make sense? What could cause it?

Thanks!

DanFu09 commented 4 years ago

There are two small places in the inference code that haven't quite been fully optimized for large arrays (https://github.com/HazyResearch/flyingsquid/blob/master/flyingsquid/label_model.py#L1273 and https://github.com/HazyResearch/flyingsquid/blob/master/flyingsquid/label_model.py#L1321 if you're curious). The first one hits NumPy's scalability constraints, and the second is a list comprehension in a loop -- both will slow down inference on large datasets.

Luckily, inference is embarrassingly parallel, so a simple wrapper around a parallel for loop should get your inference times down to something more reasonable (I'm running this on a server with 56 CPUs):

import numpy as np
from multiprocessing import Pool

chunk_size = 100000

def parallel_preds(x):
    # Run inference on the x-th chunk of the label matrix
    return label_model.predict(L_train[x * chunk_size:(x + 1) * chunk_size])

# Ceiling division so a final partial chunk isn't dropped
num_chunks = -(-L_train.shape[0] // chunk_size)
with Pool(56) as p:
    all_preds = p.map(parallel_preds, range(num_chunks))

preds = np.concatenate(all_preds)

dmitra79 commented 4 years ago

Thank you! This worked great - returned results in 2 minutes (with 16 processes)