VowpalWabbit / vowpal_wabbit

Vowpal Wabbit is a machine learning system which pushes the frontier of machine learning with techniques such as online, hashing, allreduce, reductions, learning2search, active, and interactive learning.
https://vowpalwabbit.org

CSOAA in Python not returning probabilities #3331

Closed sergeyf closed 3 years ago

sergeyf commented 3 years ago

Hello,

Thanks for the great work on VW over the many years!

I'm using the Python package with the csoaa option to deal with a multilabel problem. It seems to train fine, but I can't get it to return probabilities. This is how I trained without any errors:

from vowpalwabbit import pyvw
from sklearn.utils import shuffle  # assumed source of the shuffle() used below

vw = pyvw.vw(quiet=True, csoaa=y_train.shape[1], b=26, ngram="d2", probabilities=True)
passes = 1
for n in range(passes):
    print(n)
    if n >= 1:
        examples = shuffle(Xy_train)
    else:
        examples = Xy_train
    for idx, example in enumerate(examples):
        vw.learn(example)

Where Xy_train has things like

Xy_train[0]
Out: '1:1.0 2:1.0 3:1.0 4:0.0 5:1.0 6:1.0 7:1.0 8:1.0 9:1.0 10:1.0 11:1.0 12:1.0 13:1.0 14:1.0 15:1.0 16:1.0 17:1.0 18:1.0 19:1.0 20:1.0 21:1.0 22:1.0 23:1.0 24:1.0 |text a transient based real time scheduling algorithm in fms'

At test time, I have data like:

Xy_val[0]
Out: '1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |text covid associated ards treated with dexamethasone codex study design and rationale for a randomized trial'

And when I call vw.predict(Xy_val[0]), no probabilities are returned, just a single integer. How can one get probabilities out of this? I've tried vw.predict(Xy_val[0], i) for various values of i with no luck.

Thank you.

sergeyf commented 3 years ago

I also tried it with the sklearn interface, where the bug is more obvious:

from vowpalwabbit.sklearn_vw import VWMultiClassifier

vw = VWMultiClassifier(
    convert_to_vw=False, passes=1, csoaa=y_train.shape[1], b=26, ngram="d2", probabilities=True, l2=0.001
)
vw.fit(Xy_train)

vw.predict_proba(Xy_val[0:1])

You get:

Out: 
array([[22., 22., 22., 22., 22., 22., 22., 22., 22., 22., 22., 22., 22.,
        22., 22., 22., 22., 22., 22., 22., 22., 22., 22., 22.]])
bassmang commented 3 years ago

Hi Sergey,

Thank you for reporting this bug and all of the documentation you provided. csoaa currently does not support the probabilities flag. We'll work on adding this functionality this week and get back to you.

Thanks!

sergeyf commented 3 years ago

Wonderful, thank you!

bassmang commented 3 years ago

Hi @sergeyf, we've looked into your issue further and come to the conclusion that the --probabilities flag doesn't quite make sense in the context of --csoaa. To address your issue, I've put in a fix to add the --probabilities flag to --multilabel_oaa instead. This should be sufficient for your example above, since each class has a weight of either 0 or 1. Here is an example of how you could use --multilabel_oaa with the --probabilities flag:

from vowpalwabbit import pyvw
Xy_train = ['0,1,2 |text a transient based real time scheduling algorithm in fms']
Xy_train.append('1,2 |text a transient based')
vw = pyvw.vw(quiet=True, multilabel_oaa=3, probabilities=True, loss_function='logistic')
passes = 1
for n in range(passes):
    for idx, example in enumerate(Xy_train):
        vw.learn(example)

Xy_val = "|text a"
vw.predict(Xy_val)

Output: [0.2910301685333252, 0.3539791405200958, 0.35499072074890137]
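As a side note, one way to turn these per-label logistic scores into a multilabel prediction is to threshold them. A minimal sketch, using the scores printed above and an arbitrary illustrative threshold (not part of VW's API):

```python
# Per-label logistic scores as returned by vw.predict above.
probs = [0.2910301685333252, 0.3539791405200958, 0.35499072074890137]

# Keep labels whose probability exceeds a chosen threshold
# (0.3 here is an arbitrary illustrative choice).
threshold = 0.3
predicted_labels = [i for i, p in enumerate(probs) if p > threshold]
print(predicted_labels)  # -> [1, 2]
```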

sergeyf commented 3 years ago

That is perfect, thank you for the fix!

bassmang commented 3 years ago

Hi Sergey, this fix has just been merged. Here's a description of the new functionality:

  1. Using --probabilities with csoaa will throw an error.
  2. Using --link logistic with --probabilities will not make a difference in oaa, csoaa_ldf, and multilabel_oaa. In oaa and csoaa_ldf, this flag will be removed if seen, and the linking logic is done within the reduction (not in scorer.cc). In multilabel_oaa, the --link logistic flag will be added if not seen, and the linking logic occurs in scorer.cc.
  3. multilabel_oaa outputs probabilities as logistic scores -- no normalization!
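Since the per-label scores are independent sigmoid outputs (point 3), they will not generally sum to 1. If a proper distribution over labels is wanted, one can normalize them by hand; a minimal sketch with illustrative values (not from a real run):

```python
# Hypothetical raw logistic scores from a multilabel_oaa + --probabilities
# prediction (illustrative values only).
scores = [0.29, 0.35, 0.36]

# Each score is an independent sigmoid output, so the list need not sum to 1.
total = sum(scores)

# Dividing by the sum yields a normalized distribution over the labels.
normalized = [s / total for s in scores]
print(normalized)  # entries now sum to 1
```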

Please let me know if you have any questions.