covartech / PRT

Pattern Recognition Toolbox for MATLAB
http://covartech.github.io/
MIT License

Pass userData to output data set in cross-validation or binaryToMaryOneVsAll #57

Open · cratto opened this issue 7 years ago

cratto commented 7 years ago

I wrote a new prtClass in which metadata is written out to DataSet.userData in the runAction() method.

However, when I call crossValidate() on that prtClass or use it as a base classifier in prtClassBinaryToMaryOneVsAll, the userData is not saved to the output.

I would like userData to be saved to the DataSet returned by either prtClass.crossValidate() or prtClassBinaryToMaryOneVsAll.runAction(). Is this an easy internal PRT fix?

patrickkwang commented 7 years ago

For both of these, it's not obvious what the desired result is. crossValidate() calls runAction() for each fold but currently returns only the userData produced by the first fold. prtClassBinaryToMaryOneVsAll calls runAction() once per class and discards the userData entirely. We think the best solution is to produce a struct vector in which each element comes from one cross-validation fold or from one of the binary classification runs. This is implemented in the latest commit, 85b6087b2917014530096c4f65c78c15189f966a.
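As a rough illustration of the struct-vector layout described above, the sketch below collects one userData struct per run into a single struct array using plain MATLAB. It does not use PRT internals and is not the code from the commit; `foldUserData` and the field `a` are hypothetical placeholders for whatever a classifier's runAction() writes. It assumes every run returns a struct with the same field names, since `vertcat` errors on mismatched fields.

```matlab
% Hypothetical sketch: gather per-run userData into one struct vector.
% foldUserData stands in for the userData each runAction() call produces;
% this is plain MATLAB, not PRT's actual implementation.
nFolds = 3;
foldUserData = cell(nFolds, 1);
for k = 1:nFolds
    % Placeholder metadata; every run must return the same field names
    foldUserData{k} = struct('a', k);
end

% Concatenate vertically: one struct element per fold or binary run
combined = vertcat(foldUserData{:});

disp(size(combined));   % 3-by-1 struct array
disp([combined.a]);     % metadata from each run, in fold order
```

With this layout, `combined(k)` holds the metadata from fold k (or from the k-th binary classifier in the one-vs-all case).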

Here's my example script. You'll need to add the attached class definition to your path. Change the extension to .m (apparently GitHub doesn't like m-files).

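```matlab
% Cross-validation: runAction() is called once per fold
ds = prtDataGenUnimodal;
clsUserdata = prtClassAddUserdata;            % attached class, saved as a .m file on the path
keys = round(rand(ds.nObservations,1));       % random 0/1 fold keys, i.e. two folds
results = clsUserdata.crossValidate(ds,keys);
results.userData                              % with the fix: one struct element per fold

%%
% One-vs-all: runAction() is called once per binary classifier
dsMulti = prtDataGenMary;
clsUserdata = prtClassAddUserdata;
multiCls = prtClassBinaryToMaryOneVsAll('baseClassifier',clsUserdata);
results = multiCls.rt(dsMulti);               % train and run on the same data
results.userData                              % with the fix: one struct element per binary run
```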

prtClassAddUserdata.txt

cratto commented 7 years ago

Thanks Patrick, this works fine for my purposes. Will this be merged into master at some point?

cratto commented 6 years ago

Revisiting this because it's no longer working for me.

It appears the struct written to userData on the first fold is dropped at line 242 of prtAction.run, `dsOut = postRunProcessing(self, dsIn, dsOut);`. Before this line runs, dsOut.userData is a struct with one field, 'a'; after it runs, dsOut.userData is an empty struct.

If you let cross-validation continue to line 380 of prtAction.crossValidate, `dsOut = dsOut.acquireNonDataAttributesFrom(dsIn);`, dsOut.userData is a 2x1 empty struct with no fields before that line runs but a 1x1 struct afterward.

cratto commented 6 years ago

I did a git reset --soft 85b6087 on my local copy of devel and still had the same problem.

The script works after re-cloning the devel branch and doing a git reset --hard 85b6087. Are there changes made in later commits that may have undone this?

patrickkwang commented 6 years ago

https://github.com/covartech/PRT/commit/690bdeee119571d49ebae45bf07353c2111b88dd

You're right. We had offsetting bugs. One of them was fixed in https://github.com/covartech/PRT/commit/8d685811f75f07a562ef25fe908aad947a869375, breaking our userData handling. I have fixed the other.

patrickkwang commented 6 years ago

See #66 (@peterTorrione). @cratto, does it make sense to use observationInfo instead of userData for your application?