Closed: rudi77 closed this issue 4 years ago.
PFI requires your model's predictor to be cast to an `ISingleFeaturePredictionTransformer`. Typically you'll need to take the predictor out of the resulting model via `.LastTransformer`, and then I would suggest casting it to `ISingleFeaturePredictionTransformer<object>`. So something like this:
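```csharp
// Take the trainer (last step) out of the AutoML pipeline and view it as a single-feature predictor.
var predictor = (result.BestRun.Model as TransformerChain<ITransformer>).LastTransformer as ISingleFeaturePredictionTransformer<object>;
```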
I don't remember whether it's necessary to take `.LastTransformer` from the AutoML result, though, so you might be able to skip that step and cast directly.
Another cast that would generally work, but that I wouldn't recommend for your use case, can be found in the following sample:
In that case we're casting the predictor to a `BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>`, which is an `ISingleFeaturePredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>`. That cast works there because we know beforehand which type of predictor to expect. With AutoML you typically won't know which predictor you'll get, which is why I suggest the first cast 😄
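Putting the pieces together, here is a minimal end-to-end sketch, assuming the `Microsoft.ML` and `Microsoft.ML.AutoML` NuGet packages and the AutoML experiment API (`CreateBinaryClassificationExperiment` / `Execute`) that the question appears to use. The `ModelInput` class, its `Signal` and `Noise` columns, and the synthetic data are hypothetical, for illustration only. Two details go beyond the answer above and are assumptions based on the ML.NET PFI samples: the code only unwraps `.LastTransformer` when the model actually is a `TransformerChain<ITransformer>`, and it passes PFI the output of `model.Transform(trainData)` rather than the raw `trainData`, because the predictor reads the featurized column produced by the earlier pipeline steps.

```csharp
using System;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.AutoML;
using Microsoft.ML.Data;

// Hypothetical input schema for illustration: a boolean label and two float features.
public class ModelInput
{
    public bool Label { get; set; }
    public float Signal { get; set; }
    public float Noise { get; set; }
}

public static class Program
{
    public static void Main()
    {
        var mlContext = new MLContext(seed: 0);

        // Hypothetical synthetic data: Label depends only on Signal; Noise is unrelated.
        var random = new Random(0);
        var rows = Enumerable.Range(0, 500).Select(_ =>
        {
            float signal = (float)random.NextDouble();
            float noise = (float)random.NextDouble();
            return new ModelInput { Label = signal > 0.5f, Signal = signal, Noise = noise };
        }).ToList();
        IDataView trainData = mlContext.Data.LoadFromEnumerable(rows);

        // Run a short AutoML binary classification experiment.
        var experiment = mlContext.Auto().CreateBinaryClassificationExperiment(maxExperimentTimeInSeconds: 10);
        var result = experiment.Execute(trainData, labelColumnName: "Label");
        ITransformer model = result.BestRun.Model;

        // Unwrap the last step if the model is a chain; otherwise try the model itself.
        ITransformer last = model is TransformerChain<ITransformer> chain ? chain.LastTransformer : model;

        // ISingleFeaturePredictionTransformer<out TModel> is covariant, so a predictor over any
        // reference-type model parameters can be viewed as ISingleFeaturePredictionTransformer<object>.
        var predictor = last as ISingleFeaturePredictionTransformer<object>
            ?? throw new InvalidOperationException("The best model's last transformer is not a single-feature predictor.");

        // The predictor expects featurized data, so run the raw data through the full model first.
        IDataView transformedData = model.Transform(trainData);

        ImmutableArray<BinaryClassificationMetricsStatistics> permutationMetrics =
            mlContext.BinaryClassification.PermutationFeatureImportance(
                predictor, transformedData, labelColumnName: "Label", permutationCount: 10);

        // One entry per slot of the predictor's feature column, in slot order. Each value is the
        // change in a metric after permuting that slot; a more negative AUC change suggests a
        // more important feature, so ascending order lists the most important slots first.
        var ranking = permutationMetrics
            .Select((metrics, index) => (index, meanDeltaAuc: metrics.AreaUnderRocCurve.Mean))
            .OrderBy(entry => entry.meanDeltaAuc);

        foreach (var (featureSlot, deltaAuc) in ranking)
            Console.WriteLine($"Feature slot {featureSlot}: mean change in AUC = {deltaAuc:F4}");
    }
}
```

The results come back in feature-slot order, so mapping a slot index back to an input column name depends on how AutoML assembled the feature vector; this sketch only ranks slots by their mean change in AUC.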
I will close this issue now. Please feel free to reopen this if you're still having problems with it. Thanks.
I am using the AutoML API to find the best binary classification model for a given dataset. As a next step, I want to retrieve the most important features with the `BinaryClassification.PermutationFeatureImportance` method, as shown in the code below. The problem is that this method expects an `ISingleFeaturePredictionTransformer` as its first argument, but `result.BestRun.Model` is of type `ITransformer`, and I don't know how to cast or convert the `ITransformer` into the expected type. So far I haven't found any samples for this case. Any suggestions on how I can compute the permutation feature importance?

```csharp
var result = experiment.Execute(trainData, LabelColumnName, progressHandler: progressHandler);
var permutations = MlContext.BinaryClassification
    .PermutationFeatureImportance(result.BestRun.Model, trainData, LabelColumnName, permutationCount: 10);
```