pbreheny / visreg

Visualization of regression functions
http://pbreheny.github.io/visreg/

visreg2d for multinom? #73

Open kwhkim opened 4 years ago


Hi,

Great package!

I wish visreg2d worked with multinomial regression models (such as fits from nnet::multinom). Here is an example of what I have in mind:

library(mvtnorm)
library(ggplot2)

set.seed(1)  # make the simulated data reproducible

# Three classes drawn from bivariate normals with a shared covariance matrix
mu1 <- c(-2, 2)
mu2 <- c(0, 0)
mu3 <- c(2, -2)
sigma <- matrix(c(2, 1.4, 1.4, 2), 2, 2)

dat1 <- rmvnorm(50, mu1, sigma)
dat2 <- rmvnorm(50, mu2, sigma)
dat3 <- rmvnorm(50, mu3, sigma)

dat <- as.data.frame(rbind(dat1, dat2, dat3))
colnames(dat) <- c("x", "y")
dat$class <- factor(c(rep("A", nrow(dat1)),
                      rep("B", nrow(dat2)),
                      rep("C", nrow(dat3))))

head(dat)

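library(nnet)
library(visreg)

fitMUL <- multinom(class ~ x + y, data = dat)
summary(fitMUL)

# visreg2d(fitMUL, "x", "y")  # This does not work for multinom fits!

# visreg() itself works, but only one predictor at a time:
visreg(fitMUL, "x", collapse = TRUE,
       overlay = TRUE,
       ylab = "Probability",
       ylim = c(0, 1))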

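# Manual alternative: predict on a regular grid of (x, y) values.
# predict() on a multinom fit returns the predicted class by default.
newdata <- expand.grid(x = seq(-4, 4, 0.1), y = seq(-4, 4, 0.1))
predy <- predict(fitMUL, newdata = newdata)

ggdat <- cbind(newdata, predy = predy)
ggplot(ggdat, aes(x = x, y = y, col = predy)) + geom_point()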
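Until visreg2d supports multinom fits, one possible workaround is to build the per-class surfaces by hand. The sketch below is my own assumption, not part of visreg's API: it reruns the same simulation setup (with a seed, so the draws differ from the unseeded code above), asks `predict(..., type = "probs")` for the fitted probability of every class on a grid, and draws one heat map per class with `facet_wrap()`, roughly what visreg2d's image plot shows for a single response. The names `grid`, `probs`, `K` and `long` are hypothetical.

```r
library(mvtnorm)
library(nnet)
library(ggplot2)

set.seed(1)

# Same simulation setup as above: three classes, shared covariance
sigma <- matrix(c(2, 1.4, 1.4, 2), 2, 2)
dat <- as.data.frame(rbind(rmvnorm(50, c(-2, 2), sigma),
                           rmvnorm(50, c(0, 0), sigma),
                           rmvnorm(50, c(2, -2), sigma)))
colnames(dat) <- c("x", "y")
dat$class <- factor(rep(c("A", "B", "C"), each = 50))
fitMUL <- multinom(class ~ x + y, data = dat, trace = FALSE)

# Regular grid over both predictors (hypothetical name: grid)
grid <- expand.grid(x = seq(-4, 4, by = 0.1), y = seq(-4, 4, by = 0.1))

# type = "probs" returns an n x K matrix with one column of fitted
# probabilities per outcome level (here A, B, C)
probs <- predict(fitMUL, newdata = grid, type = "probs")

# Long format, one row per (grid point, class). as.vector() stacks the
# columns of probs in order, so each class label repeats nrow(grid) times.
K <- ncol(probs)
long <- data.frame(x = rep(grid$x, times = K),
                   y = rep(grid$y, times = K),
                   class = factor(rep(colnames(probs), each = nrow(grid))),
                   prob = as.vector(probs))

# One probability surface per class, on a common 0-1 colour scale
p <- ggplot(long, aes(x = x, y = y, fill = prob)) +
  geom_tile() +
  facet_wrap(~ class) +
  scale_fill_gradient(low = "white", high = "darkblue", limits = c(0, 1)) +
  coord_equal() +
  labs(fill = "Probability")
print(p)
```

Fixing the colour limits to 0 to 1 keeps the three panels directly comparable, which a collapsed visreg2d plot for multinomial models would presumably also want.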