Open cratto opened 7 years ago
For both of these, it's not obvious what the desired result is. crossValidate() calls runAction() for each fold, but currently yields the userData produced for the first fold only. prtClassBinaryToMaryOneVsAll runs runAction() for each class, and discards the userData entirely. We think the best solution is to produce a struct vector where each element comes from one cross-validation fold or from one of the binary classification runs. This was implemented in the latest commit 85b6087b2917014530096c4f65c78c15189f966a.
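To illustrate the struct-vector idea in isolation, here is a minimal sketch in plain MATLAB that does not call PRT at all. The per-fold structs are made up for illustration: the field name a only mirrors the one-field struct mentioned elsewhere in this thread, and the variable names (nFolds, foldUserData) are hypothetical. It is not the code from the commit, just one way the aggregation could look.

```matlab
% Hypothetical stand-in for the userData each fold would produce.
% The field name 'a' and its values are illustrative only.
nFolds = 2;
foldUserData = cell(nFolds, 1);
for iFold = 1:nFolds
    foldUserData{iFold} = struct('a', iFold * 10);
end

% Stack the per-fold structs into an nFolds-by-1 struct vector.
% vertcat requires every fold to produce the same field names.
userData = vertcat(foldUserData{:});

% Element k holds the userData from fold k.
disp(size(userData));
disp(userData(2).a);
```

With this layout, userData(k) would be read as the result of fold k (or of binary classifier k in the one-vs-all case), which assumes each run writes a struct with identical fields.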
Here's my example script. You'll need to add the attached class definition to your path. Change the extension to .m (apparently GitHub doesn't like m-files).
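% Case 1: userData through crossValidate()
ds = prtDataGenUnimodal;
clsUserdata = prtClassAddUserdata;        % the attached class definition
keys = round(rand(ds.nObservations,1));   % random 0/1 fold key per observation
results = clsUserdata.crossValidate(ds,keys);
results.userData                          % no semicolon, so the value is displayed
%%
% Case 2: userData through prtClassBinaryToMaryOneVsAll
dsMulti = prtDataGenMary;
clsUserdata = prtClassAddUserdata;
multiCls = prtClassBinaryToMaryOneVsAll('baseClassifier',clsUserdata);
results = multiCls.rt(dsMulti);
results.userData                          % no semicolon, so the value is displayed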
Thanks Patrick, this works fine for my purposes. Will this be merged into master at some point?
Revisiting this, because it's no longer working for me.
It appears the struct written to userData on the first fold is dropped at line 242 of prtAction.run: dsOut = postRunProcessing(self, dsIn, dsOut); Before running this line, dsOut.userData is a struct with one field, 'a'. After running it, dsOut.userData is an empty struct.
If you let cross-validation run until it gets to prtAction.crossValidate, line 380: dsOut = dsOut.acquireNonDataAttributesFrom(dsIn); Before running the line, dsOut.userData is a 2x1 empty struct with no fields, but after the line it is 1x1.
I did a git reset --soft 85b6087 on my local copy of devel and still had the same problem.
The script works after re-cloning the devel branch and doing a git reset --hard 85b6087. Are there changes made in later commits that may have undone this?
https://github.com/covartech/PRT/commit/690bdeee119571d49ebae45bf07353c2111b88dd
You're right. We had offsetting bugs. One of them was fixed in https://github.com/covartech/PRT/commit/8d685811f75f07a562ef25fe908aad947a869375, breaking our userData handling. I have fixed the other.
See #66 (@peterTorrione). @cratto, does it make sense to use observationInfo instead of userData for your application?
I wrote a new prtClass in which metadata is written out to DataSet.userData in the runAction() method.
However, when I call crossValidate() on that prtClass or use it as a base classifier in prtClassBinaryToMaryOneVsAll, the userData is not saved to the output.
I would like userData to be saved to the DataSet that comes out of either prtClass.crossValidate() or prtClassBinaryToMaryOneVsAll.runAction(). Is this an easy internal PRT fix?