leeper / prediction

Tidy, Type-Safe 'prediction()' Methods
https://cran.r-project.org/package=prediction

type = "probs" for a nnet::multinom() model? #23

Closed arcruz0 closed 6 years ago

arcruz0 commented 6 years ago

Hello, thanks for the great package!

I'm probably missing something, but I cannot find a way to use type = "probs" (or something that does the same in prediction()) for nnet::multinom() models. Here's a reprex:

model_reprex <- nnet::multinom(Species ~ Sepal.Width + Petal.Width, data = iris)

What I want is the following (the predicted probabilities for each possible response):

predict(model_reprex, newdata = data.frame(Sepal.Width = 3, Petal.Width = 1.2), type = "probs")

      setosa   versicolor    virginica 
1.114197e-06 9.992983e-01 7.005591e-04 

However, with prediction() I can only get:

prediction(model_reprex, at = list(Sepal.Width = 3, Petal.Width = 1.2))

Modal prediction (of 1 level) for 150 observations: 
 at(Sepal.Width) at(Petal.Width)      value
               3             1.2 versicolor

Thanks in advance!

leeper commented 6 years ago

It's all in the data frame returned by prediction(); the print method just doesn't show it in full by default:

> prediction(model_reprex, at = list(Sepal.Width = 3, Petal.Width = 1.2))
Modal prediction (of 1 level) for 150 observations: 
 at(Sepal.Width) at(Petal.Width)      value
               3             1.2 versicolor
> str(.Last.value)
Classes ‘prediction’ and 'data.frame':  150 obs. of  11 variables:
 $ Sepal.Length  : num  5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
 $ Sepal.Width   : num  3 3 3 3 3 3 3 3 3 3 ...
 $ Petal.Length  : num  1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
 $ Petal.Width   : num  1.2 1.2 1.2 1.2 1.2 1.2 1.2 1.2 1.2 1.2 ...
 $ Species       : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
 $ fitted.class  : Factor w/ 3 levels "setosa","versicolor",..: 2 2 2 2 2 2 2 2 2 2 ...
 $ Pr(setosa)    : num  1.11e-06 1.11e-06 1.11e-06 1.11e-06 1.11e-06 ...
 $ Pr(versicolor): num  0.999 0.999 0.999 0.999 0.999 ...
 $ Pr(virginica) : num  0.000701 0.000701 0.000701 0.000701 0.000701 ...
 $ fitted        : num  1.11e-06 1.11e-06 1.11e-06 1.11e-06 1.11e-06 ...
 $ se.fitted     : num  NA NA NA NA NA NA NA NA NA NA ...
 - attr(*, "at")='data.frame':  1 obs. of  2 variables:
  ..$ Sepal.Width: num 3
  ..$ Petal.Width: num 1.2
 - attr(*, "model.class")= chr  "multinom" "nnet"
 - attr(*, "type")= chr NA
 - attr(*, "category")= chr "Pr(setosa)"

The class probabilities are the columns labeled Pr(...). The fitted column, by default, contains the predicted probability of the factor's reference category (here setosa, as recorded in the "category" attribute); this can be changed by setting the category argument in prediction(). The fitted.class column gives you the most likely class.
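For reference, here is a minimal sketch of pulling those probabilities out of the returned data frame so they look like the predict(..., type = "probs") result. It assumes the Pr(...) column naming shown in the str() output above. The names p, p_df, prob_cols, and probs are illustrative only and not part of the package, and trace = FALSE is added just to silence the fitting log.

```r
library(nnet)        # multinom()
library(prediction)  # prediction()

# Same model as in the reprex; trace = FALSE only suppresses the iteration log
model_reprex <- multinom(Species ~ Sepal.Width + Petal.Width,
                         data = iris, trace = FALSE)

p <- prediction(model_reprex, at = list(Sepal.Width = 3, Petal.Width = 1.2))

# Drop the "prediction" class so ordinary data.frame subsetting applies
p_df <- as.data.frame(p)

# Class-probability columns are the ones named "Pr(<level>)"
prob_cols <- grep("^Pr\\(", names(p_df), value = TRUE)

# Both predictors are fixed by `at`, so every row carries the same
# probabilities; the first row is enough
probs <- unlist(p_df[1, prob_cols])
print(probs)
```

Because both predictors are held fixed here, taking a single row loses nothing. If only some predictors were set through at, the probabilities would vary by row, and it would make more sense to keep the whole p_df[, prob_cols] block.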

arcruz0 commented 6 years ago

Wow, that was really simple... thanks a lot!