Merging #1488 into master will decrease coverage by 0.17%. The diff coverage is 80.95%.
Coverage Diff (master vs. #1488):

| | master | #1488 | +/- |
|---|---|---|---|
| Coverage | 97.65% | 97.48% | -0.18% |
| Files | 56 | 56 | |
| Lines | 2093 | 2109 | +16 |
| Branches | 336 | 340 | +4 |
| Hits | 2044 | 2056 | +12 |
| Misses | 22 | 25 | +3 |
| Partials | 27 | 28 | +1 |

| Impacted Files | Coverage Δ | |
|---|---|---|
| snorkel/labeling/model/label_model.py | 94.81% <80.95%> (-1.02%) | :arrow_down: |
@henryre I can't seem to tag a reviewer for this PR, but let me know if you have any questions/comments/concerns. Also attempting to tag @plison for an extra pair of eyes.
Edit: it seems like codecov is failing because I'm not properly testing the logging statements for when `max(scores) == -1` and `n_iter == max_n_iter`. I've been looking for another test on the label model that induces a case where `max(scores) == -1` and can't seem to definitively find one. @henryre, would you be able to point me in the right direction? Does such a test exist? Alternatively, if it's not within scope for this PR, we can just remove the logging.
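For reference, here's roughly the kind of test I've been sketching. Note that the `n_iter` keyword only exists on this branch, and the synthetic input and asserted warning text below are guesses rather than verified behavior:

```python
# Rough sketch only: n_iter is the option proposed in this PR (not in released
# snorkel), and there is no guarantee this input actually drives max(scores) == -1.
import logging

import numpy as np
from snorkel.labeling.model import LabelModel


def test_no_valid_permutation_warns(caplog) -> None:
    cardinality = 20
    # Small synthetic label matrix: 100 data points, 3 labeling functions,
    # each voting for one of 20 classes (no abstains).
    L = np.random.randint(0, cardinality, size=(100, 3))
    label_model = LabelModel(cardinality=cardinality, verbose=True, n_iter=1)
    with caplog.at_level(logging.WARNING):
        label_model.fit(L, n_epochs=10)
    # With a single sampled permutation, the symmetry-breaking search can exhaust
    # n_iter without finding a valid score, which is the branch the logging covers.
    messages = [record.getMessage().lower() for record in caplog.records]
    assert any("permutation" in message for message in messages)  # message text assumed
```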
@dataframing to start: thank you so much for pointing out this issue in such a detailed and thorough way, and for this PR! We owe contributors like you big time for this kind of help and support :)
Also, apologies for some lack of coordination on our end. We have a planned update that will (at least partially) address this issue in a more principled way than random sampling, i.e. by relying on some additional default assumptions applied to the underlying mathematical model that yields valid permutations of the learned weights. We will be trying to get to this soon, and so I'd probably hold off on further development here for now. I'll then circle back and see if there's a way to merge our solution with some of the work done here. Thanks!
@ajratner not a problem! It was a fun exercise. I'm looking forward to the update, and I'll be on the lookout for other ways to contribute. Thanks to the team for their hard work on this! 👍
@dataframing @ajratner miscommunication was on me. Thanks for the effort on this!
I've been experimenting with @dataframing's code, as I am working on a classification problem with 18 distinct classes (so enumerating all possible permutations is obviously not feasible). Just wanted to say that the code runs, but I noticed a bug at the end of the (proposed) `_break_col_permutation_symmetry` method, namely that the `mu` parameter is modified (lines 866-868 of `label_model.py`) even when a valid μ value isn't found. This is a bad idea, as in that case the `mu` parameter will be transformed into a random permutation that is much worse than the initial `mu` value. I think a `return` should be added on line 865 to avoid changing the `mu` parameter when no solution is found.

(Just wanted to mention this in case someone else tries that code and stumbles upon the same problem.)
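Concretely, the guard I have in mind looks something like this (a standalone sketch with hypothetical names, not the actual code from the PR branch):

```python
# Illustrative sketch of the early-return guard; variable names are hypothetical
# and this is not the code from the PR branch.
import logging
from typing import List

import numpy as np


def apply_best_permutation(
    mu: np.ndarray, permutations: List[np.ndarray], scores: np.ndarray
) -> np.ndarray:
    """Return mu with columns reordered by the best-scoring permutation, or unchanged."""
    if scores.size == 0 or scores.max() == -1:
        # No sampled permutation produced a valid score: leave mu untouched rather
        # than applying an arbitrary permutation that is likely worse than the original.
        logging.warning("No valid column permutation found; leaving mu unchanged.")
        return mu
    best = permutations[int(np.argmax(scores))]
    return mu[:, best]
```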
I made a comment a moment ago, but I think it's wrong. Yes, I think if we're only sampling a fraction of the space of all permutations, we should make sure not to update the `mu` instance unless we have a valid (non -1) entry in `scores`.
@ajratner I'm on the lookout for this also, as I just moved from 4 classes to 12 and I noticed the LabelModel does not seem to finish training! What is a workaround? CUDA or use the Majority Voter?
Description of proposed changes
This PR addresses #1486, where we found that a one-unit increase in target cardinality (`k`) results in an exponential (~8-10x) increase in runtime, due to attempting to evaluate all `k!` permutations of our labeling functions while searching for an ideal μ in `LabelModel._break_col_permutation_symmetry`. This makes using `LabelModel` a non-starter for multi-class classification tasks with a cardinality greater than ~8.

To address this issue, we instead implement a randomized search policy over all `k!` permutations, parameterized only by `n_iter`, the maximum number of distinct permutations we'd like to consider before early stopping.
Considerations

In general, there are a few questions we should be aware of:

- Should we derive some mapping `cardinality -> number of iterations`? Since `cardinality -> # permutations` is exponential, we'd need to linearize by applying the logarithm. But I didn't bother to try and craft a function, since it seemed more complicated than it was worth.
- The random search samples from all `k!` permutations. It may re-generate a previous permutation, but the final sample of permutations is unique.
- I considered using `itertools.permutations` to generate our samples, but found it a bit difficult to reconcile with randomly sampling along the entire space. We could chain the returned generator with `itertools.islice`, but I'm not 100% sure what the implication would be of selecting a contiguous sequence of `n_iter` permutations and claiming it's "random." Thoughts welcome here! (A sketch of this alternative follows the list.)
- The implementation instead calls `np.random.permutation(k)` and memoizes each prior sample's hash in a set. That way, we're not generating a set of `n_iter` tuples of length 17.
- I've defaulted `n_iter` in the `LabelModelConfig` to 1 million, which gives us an amortized performance that's between cardinalities 9 and 10. This can be set by the user at model initialization time, but it's worth considering whether that's a proper default or if we should be more conservative.
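For comparison, the `itertools`-based alternative mentioned above would look roughly like this; the slice it yields is simply the first `n_iter` permutations in lexicographic order, which is why it's hard to call the result random:

```python
# Sketch of the rejected itertools-based alternative: islice takes a contiguous,
# lexicographically ordered prefix of the permutation stream, not a random sample.
import itertools
from typing import List, Tuple


def first_n_permutations(k: int, n_iter: int) -> List[Tuple[int, ...]]:
    return list(itertools.islice(itertools.permutations(range(k)), n_iter))


# For k = 20 and n_iter = 100, every permutation returned above keeps its first
# fifteen entries fixed at (0, 1, ..., 14); only the tail positions ever vary.
```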
Related issue(s)

Fixes #1486.
Test plan
I've added a test, `test_label_model_large_multiclass`, under `TestLabelModelAdvanced`, which is marked as part of the complex test suite. It attempts to run a `LabelModel` with `cardinality = 20`, which would take a long time without a capped number of iterations. Instead, I've capped the number of iterations to 100, so it finishes pretty quickly (a few seconds, modulo local computing resources).

To test, run:
Anecdotally, I've also run the example from #1486 with `10 <= N_CLASSES <= 13`, and the runtime is constant with respect to the provided `n_iter`. At `n_iter: int = 500_000`, here are examples of runs on my local machine:

Note: the runtime still increases a bit with every increase in cardinality, but that's because each sampled permutation is one item larger, which adds up over `n_iter` samples. Anecdotally, as `k` increases, the time it takes to sample `n_iter` unique permutations decreases (fewer and fewer opportunities for duplicate sampling), but the core search for an ideal μ increases in runtime (larger matrices being generated and multiplied).

Checklist
Need help on these? Just ask!
Run `tox -e complex` and/or `tox -e spark` if appropriate.