snorkel-team / snorkel

A system for quickly generating training data with weak supervision
https://snorkel.org
Apache License 2.0

Capped-iteration LabelModel permutation sampling. #1488

Closed · dataframing closed this pull request 5 years ago

dataframing commented 5 years ago

Description of proposed changes

This PR addresses #1486, where we found that a one-unit increase in target cardinality (k) results in a roughly 8-10x increase in runtime, because LabelModel._break_col_permutation_symmetry attempts to evaluate all k! permutations of our labeling functions while searching for an ideal μ, and going from k to k+1 multiplies that count by k+1. This factorial growth makes LabelModel a non-starter for multi-class classification tasks with a cardinality greater than ~8.

To address this issue, we instead implement a randomized search policy over the k! permutations, parameterized only by n_iter, the maximum number of distinct permutations we'd like to consider before early-stopping.
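For concreteness, here is a minimal sketch of the capped sampling policy; the function name, defaults, and structure are illustrative rather than the PR's exact code:

```python
import math

import numpy as np


def sample_permutations(k, n_iter=500_000, seed=123):
    """Sample up to n_iter distinct permutations of range(k)."""
    # Never ask for more distinct permutations than the k! that exist,
    # otherwise the loop below could never terminate.
    target = min(n_iter, math.factorial(k))
    rng = np.random.default_rng(seed)
    seen = set()
    while len(seen) < target:
        # The set silently drops duplicate draws, so the loop ends once
        # `target` distinct permutations have been collected.
        seen.add(tuple(rng.permutation(k)))
    return [np.array(p) for p in seen]
```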

Considerations

In general, there are a few questions we should be aware of:

Related issue(s)

Fixes #1486.

Test plan

I've added a test, test_label_model_large_multiclass, under TestLabelModelAdvanced, which is marked as part of the complex test suite. It runs a LabelModel with cardinality = 20, which would take a very long time without a cap on the number of permutations considered. With the cap set to 100 iterations, it finishes pretty quickly (a few seconds, modulo local computing resources).
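For reference, here is a rough sketch of the setup such a test exercises; the data and parameters below are hypothetical, not the test's actual fixture:

```python
import numpy as np

from snorkel.labeling.model import LabelModel

# Hypothetical high-cardinality setup in the spirit of #1486:
# 1000 data points, 3 labeling functions, each voting one of 20 classes.
np.random.seed(123)
L = np.random.randint(0, 20, size=(1000, 3))

label_model = LabelModel(cardinality=20, verbose=False)
# With the permutation search capped, this finishes in seconds instead of
# grinding through all 20! column permutations.
label_model.fit(L, n_epochs=100, seed=123)
```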

To test, run:

tox            # General test suite.
tox -e complex # Complex test suite.

Anecdotally, I've also run the example from #1486 with 10 <= N_CLASSES <= 13, and for a fixed n_iter the runtime stays roughly constant across cardinalities. At n_iter: int = 500_000, here are example runs on my local machine:

Sampling from all possible permutations...: 100%|██████████| 500000/500000 [00:43<00:00, 11510.61it/s]
Searching for ideal μ amongst candidates...: 100%|██████████| 500000/500000 [03:13<00:00, 2586.52it/s]
Fitting label model with 10 classes took 238.014 seconds.
-------------------
Sampling from all possible permutations...: 100%|██████████| 500000/500000 [00:40<00:00, 12370.89it/s]
Searching for ideal μ amongst candidates...: 100%|██████████| 500000/500000 [03:22<00:00, 2473.01it/s]
Fitting label model with 11 classes took 243.921 seconds.
-------------------
Sampling from all possible permutations...: 100%|██████████| 500000/500000 [00:40<00:00, 12325.28it/s]
Searching for ideal μ amongst candidates...: 100%|██████████| 500000/500000 [03:39<00:00, 2276.67it/s]
Fitting label model with 12 classes took 261.595 seconds.
-------------------
Sampling from all possible permutations...: 100%|██████████| 500000/500000 [00:39<00:00, 12528.03it/s]
Searching for ideal μ amongst candidates...: 100%|██████████| 500000/500000 [04:00<00:00, 2075.09it/s]
Fitting label model with 13 classes took 282.288 seconds.
-------------------

Note: the runtime still increases slightly with each increase in cardinality, but that's because each sampled permutation is one element larger, which adds up over n_iter samples. Anecdotally, as k increases, the time it takes to sample n_iter unique permutations decreases (there are fewer and fewer opportunities for duplicate sampling), while the core search for an ideal μ takes longer (larger matrices are being generated and multiplied).
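A quick back-of-the-envelope check of that duplicate-sampling claim, using a birthday-problem style estimate (this is an illustration, not a measurement from the PR):

```python
import math


def expected_duplicates(n_draws, k):
    # Drawing n_draws permutations uniformly from the k! possible ones,
    # the expected number of distinct values is
    # k! * (1 - (1 - 1/k!) ** n_draws); the remainder are duplicates.
    n_perms = math.factorial(k)
    expected_distinct = n_perms * (1 - (1 - 1 / n_perms) ** n_draws)
    return n_draws - expected_distinct


for k in (10, 11, 12, 13):
    print(f"k={k}: ~{expected_duplicates(500_000, k):,.0f} duplicate draws")
```

At k = 10 this predicts tens of thousands of duplicate draws out of 500,000, while at k = 13 it predicts only a few dozen, which matches the observation that sampling gets cheaper as k grows.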


codecov[bot] commented 5 years ago

Codecov Report

Merging #1488 into master will decrease coverage by 0.17%. The diff coverage is 80.95%.

@@            Coverage Diff             @@
##           master    #1488      +/-   ##
==========================================
- Coverage   97.65%   97.48%   -0.18%     
==========================================
  Files          56       56              
  Lines        2093     2109      +16     
  Branches      336      340       +4     
==========================================
+ Hits         2044     2056      +12     
- Misses         22       25       +3     
- Partials       27       28       +1
| Impacted Files | Coverage Δ |
|---|---|
| snorkel/labeling/model/label_model.py | 94.81% <80.95%> (-1.02%) :arrow_down: |

dataframing commented 5 years ago

@henryre I can't seem to tag a reviewer for this PR, but let me know if you have any questions/comments/concerns. Also attempting to tag @plison for an extra pair of eyes.

Edit: it seems like codecov is failing because I'm not properly testing the logging statements for when max(scores) == -1 and n_iter == max_n_iter. I've been looking for another test on the label model that induces a case where max(scores) == -1 and can't seem to definitively find one — @henryre would you be able to point me in the right direction? Does such a test exist? Alternatively, if it's not within scope for this PR, we can just remove the logging.

ajratner commented 5 years ago

@dataframing to start: thank you so much for pointing out this issue in such a detailed and thorough way, and for this PR! We owe contributors like you big time for this kind of help and support :)

Also, apologies for some lack of coordination on our end. We have a planned update that will (at least partially) address this issue in a more principled way than random sampling, i.e., by relying on some additional default assumptions in the underlying mathematical model that yield valid permutations of the learned weights. We will be trying to get to this soon, so I'd probably hold off on further development here for now. I'll then circle back and see if there's a way to merge our solution with some of the work done here. Thanks!

dataframing commented 5 years ago

@ajratner not a problem! It was a fun exercise. I'm looking forward to the update, and I'll be on the lookout for other ways to contribute. Thanks to the team for their hard work on this! 👍

henryre commented 5 years ago

@dataframing @ajratner miscommunication was on me. Thanks for the effort on this!

plison commented 5 years ago

I've been experimenting with @dataframing's code, as I am working on a classification problem with 18 distinct classes (so enumerating all 18! possible permutations is obviously not feasible). Just wanted to say that the code runs, but I noticed a bug at the end of the (proposed) _break_col_permutation_symmetry method: the mu parameter is modified (lines 866-868 of label_model.py) even when a valid μ value isn't found. This is a bad idea, since in that case the mu parameter will be transformed into a random permutation that is much worse than the initial mu value. I think a return should be added on line 865 to avoid changing the mu parameter when no solution is found.
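Concretely, the shape of the guard I have in mind looks something like this (a simplified paraphrase, not the exact code in label_model.py):

```python
import numpy as np


def apply_best_permutation(mu, permutations, scores):
    # Simplified paraphrase of the tail of _break_col_permutation_symmetry.
    if max(scores) == -1:
        # No sampled permutation was valid: return mu unchanged instead of
        # overwriting it with an arbitrary (and much worse) permutation.
        return mu
    best = permutations[int(np.argmax(scores))]
    return mu[:, best]  # permute columns only when a valid candidate exists
```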

(Just wanted to mention this in case someone else tries that code and stumbles upon the same problem.)

dataframing commented 5 years ago

I made a comment a moment ago, but I think it was wrong. Yes, I think if we're only sampling a fraction of the space of all permutations, we should make sure not to update the mu instance unless we have a valid (non -1) entry in scores.

eggie5 commented 5 years ago

@ajratner I'm on the lookout for this also, as I just moved from 4 classes to 12 and noticed the LabelModel does not seem to finish training! What is a workaround? CUDA, or using the MajorityLabelVoter?
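For anyone landing here with the same problem, the majority-vote fallback looks roughly like this (the label matrix below is hypothetical):

```python
import numpy as np

from snorkel.labeling.model import MajorityLabelVoter

# Hypothetical label matrix: 1000 points, 3 labeling functions, 12 classes.
L_train = np.random.randint(0, 12, size=(1000, 3))

# Stopgap while LabelModel is slow at high cardinality: plain majority vote.
majority_model = MajorityLabelVoter(cardinality=12)
preds_train = majority_model.predict(L=L_train)
```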