nutterb / HydeNet

Hybrid Decision Networks in R
Other
23 stars 3 forks source link

writeJagsModel() and cpt nodes #64

Closed jarrod-dalton closed 9 years ago

jarrod-dalton commented 9 years ago

The error message is from a call to writeNetworkModel() but the problem is apparently in writeJagsModel().

# Load HydeNet and declare the structure of the blackjack network.
# Each `+`-separated term has the form `node | parent1*parent2*...`; the nodes
# named on the right of `|` are the ones the node on the left depends on
# (presumably HydeNet's parent/child formula syntax -- confirm against the
# HydeNetwork documentation).  Node definitions are supplied afterwards
# through setNode() / setNodeModels().
library(HydeNet)
net <- HydeNetwork(~ initialAces | card1*card2
                   + initialPoints | card1*card2
                   + highUpcard | dealerUpcard
                   + hit1 | initialPoints*highUpcard
                   + acesAfterCard3 | initialAces*card3
                   + pointsAfterCard3 | card1*card2*card3*acesAfterCard3
                   + hit2 | pointsAfterCard3*highUpcard
                   + acesAfterCard4 | acesAfterCard3*card4
                   + pointsAfterCard4 | card1*card2*card3*card4*acesAfterCard4
                   + hit3 | pointsAfterCard4*highUpcard
                   + acesAfterCard5 | acesAfterCard4*card5
                   + pointsAfterCard5 | card1*card2*card3*card4*card5*acesAfterCard5
                   + playerFinalPoints | initialPoints*hit1*pointsAfterCard3
                                         *hit2*pointsAfterCard4*hit3*pointsAfterCard5
                   + dealerOutcome | dealerUpcard
                   + payoff | playerFinalPoints*dealerOutcome)

#####################################
# Random Variable Nodes
#####################################

# Note: The calls to 'setNode' below are going to generate messages
#       that validation has been ignored.  'setNode' assumes that when
#       you pass a character string as a parameter definition, you 
#       are passing JAGS code.  See the Validation section of the 
#       'setNode' documentation for more information.
#       After the first call, we'll turn off validation

# Probability vector shared by every card node.  It has 10 categories in the
# order 2, 3, ..., 9, ten-valued cards (10-K), Ace: eight entries of 1/13,
# then 4/13 for the ten-valued group, then 1/13 for the Ace.
cardProbs  <- c(rep(1/13,8), 4/13, 1/13)  # probs. for 2, 3, ..., 9, (10-K), A

# card1..card5 are each defined as a "dcat" node using cardProbs.
# Only the first call leaves validation on; every later call passes
# validate=FALSE.
net <- setNode(net, card1, nodeType="dcat",  pi=vectorProbs(p=cardProbs, card1))
net <- setNode(net, card2, nodeType="dcat",  pi=vectorProbs(p=cardProbs, card2),
               validate=FALSE)
net <- setNode(net, card3, nodeType="dcat",  pi=vectorProbs(p=cardProbs, card3),
               validate=FALSE)
net <- setNode(net, card4, nodeType="dcat",  pi=vectorProbs(p=cardProbs, card4),
               validate=FALSE)
net <- setNode(net, card5, nodeType="dcat",  pi=vectorProbs(p=cardProbs, card5),
               validate=FALSE)

# The dealer's face-up card uses the same "dcat" definition and probabilities.
net <- setNode(net, dealerUpcard, nodeType="dcat",
               pi=vectorProbs(p=cardProbs, dealerUpcard),
               validate=FALSE)

#Note: node dealerOutcome will be defined below, following some discussion 
#      about its conditional probability distribution.

#####################################
# Deterministic Nodes
#####################################
# highUpcard: 1 when dealerUpcard is in a category above 8, otherwise 0.
# With the cardProbs ordering (2, ..., 9, (10-K), A), categories 9 and 10 are
# the ten-valued cards and the Ace.
net <- setNode(net, highUpcard,     "determ", define=fromFormula(),
               nodeFormula = highUpcard ~ ifelse(dealerUpcard > 8, 1, 0))

# initialAces: how many of the first two cards are in category 10
# (the Ace position in the cardProbs ordering).
net <- setNode(net, initialAces,    "determ", define=fromFormula(),
               nodeFormula = initialAces ~ ifelse(card1==10,1,0) + ifelse(card2==10,1,0))

# acesAfterCard3/4/5: running Ace count; each adds 1 to the previous count
# when the newly dealt card is in category 10.
net <- setNode(net, acesAfterCard3, "determ", define=fromFormula(),
               nodeFormula = acesAfterCard3 ~ initialAces + ifelse(card3==10,1,0))

net <- setNode(net, acesAfterCard4, "determ", define=fromFormula(),
               nodeFormula = acesAfterCard4 ~ acesAfterCard3 + ifelse(card4==10,1,0))

net <- setNode(net, acesAfterCard5, "determ", define=fromFormula(),
               nodeFormula = acesAfterCard5 ~ acesAfterCard4 + ifelse(card5==10,1,0))

# initialPoints: sum of the first two card categories plus 2.
# Presumably category k is worth k + 1 points (category 1 is a "2",
# category 10 / Ace counts as 11), hence +1 per card -- confirm.
net <- setNode(net, initialPoints, "determ", define=fromFormula(),
               nodeFormula = initialPoints ~ card1+card2+2)

# pointsAfterCard3: player's total after three cards.
# Let total = card1 + card2 + card3 + 3.  The value depends on acesAfterCard3:
#   3 Aces -> 13
#   2 Aces -> total - 10
#   1 Ace  -> total - 10 if total > 22, otherwise total
#   else   -> total
# Presumably each subtraction of 10 re-counts one Ace as 1 instead of 11 --
# confirm.
# NOTE(review): the one-Ace threshold is "> 22", not "> 21" -- confirm this
# is intended.
net <- setNode(net, pointsAfterCard3, "determ", define=fromFormula(),
               nodeFormula = pointsAfterCard3 ~
                 ifelse(acesAfterCard3 == 3,
                    13,
                    ifelse(acesAfterCard3 == 2,
                       card1 + card2 + card3 + 3 - 10,
                       ifelse(acesAfterCard3 == 1,
                          ifelse(card1 + card2 + card3 + 3 > 22,
                             card1 + card2 + card3 + 3 - 10,
                             card1 + card2 + card3 + 3),
                          card1 + card2 + card3 + 3
                       )
                    )
                 )
              )

# pointsAfterCard4: player's total after four cards.
# Let total = card1 + card2 + card3 + card4 + 4.  The value depends on
# acesAfterCard4:
#   4 Aces     -> 14
#   3 Aces     -> total - 30 if total > 38, otherwise total - 20
#   1-2 Aces   -> total - 10 if total > 22, otherwise total
#                 (this branch is reached for any remaining count > 0)
#   no Aces    -> total
# Presumably each subtraction of 10 re-counts one Ace as 1 instead of 11 --
# confirm.
net <- setNode(net, pointsAfterCard4, "determ", define=fromFormula(),
               nodeFormula = pointsAfterCard4 ~
                 ifelse(acesAfterCard4 == 4,
                    14,
                    ifelse(acesAfterCard4 == 3,
                       ifelse(card1 + card2 + card3 + card4 + 4 > 38,
                          card1 + card2 + card3 + card4 + 4 - 30,
                          card1 + card2 + card3 + card4 + 4 - 20
                       ),
                       ifelse(acesAfterCard4 > 0,
                          ifelse(card1 + card2 + card3 + card4 + 4 > 22,
                                 card1 + card2 + card3 + card4 + 4 - 10,
                                 card1 + card2 + card3 + card4 + 4
                          ),
                          card1 + card2 + card3 + card4 + 4
                       )
                    )
                 )
              )

# pointsAfterCard5: player's total after five cards.
# Let total = card1 + card2 + card3 + card4 + card5 + 5.  The value depends
# on acesAfterCard5:
#   5 Aces  -> 15
#   4 Aces  -> total - 40 if total > 51, otherwise total - 30
#   3 Aces  -> total - 30 if total > 51, otherwise total - 20
#   2 Aces  -> total - 20 if total > 31, otherwise total - 10
#   1 Ace   -> total - 10 if total > 22, otherwise total
#              (this branch is reached for any remaining count > 0)
#   no Aces -> total
# Presumably each subtraction of 10 re-counts one Ace as 1 instead of 11 --
# confirm.
# NOTE(review): the 4-Ace and 3-Ace branches both test "> 51" -- confirm the
# thresholds are intended to be identical.
net <- setNode(net, pointsAfterCard5, "determ", define=fromFormula(),
               nodeFormula = pointsAfterCard5 ~ 
                 ifelse(acesAfterCard5 == 5,
                    15,
                    ifelse(acesAfterCard5 == 4,
                       ifelse(card1 + card2 + card3 + card4 + card5 + 5 > 51,
                          card1 + card2 + card3 + card4 + card5 + 5 - 40,
                          card1 + card2 + card3 + card4 + card5 + 5 - 30
                        ),
                       ifelse(acesAfterCard5 == 3,
                          ifelse(card1 + card2 + card3 + card4 + card5 + 5 > 51,
                             card1 + card2 + card3 + card4 + card5 + 5 - 30,
                             card1 + card2 + card3 + card4 + card5 + 5 - 20
                          ),
                          ifelse(acesAfterCard5 == 2,
                             ifelse(card1 + card2 + card3 + card4 + card5 + 5 > 31,
                                card1 + card2 + card3 + card4 + card5 + 5 - 20,
                                card1 + card2 + card3 + card4 + card5 + 5 - 10
                             ),
                             ifelse(acesAfterCard5 > 0,
                                ifelse(card1 + card2 + card3 + card4 + card5 + 5 > 22,
                                   card1 + card2 + card3 + card4 + card5 + 5 - 10,
                                   card1 + card2 + card3 + card4 + card5 + 5
                                ),
                                card1 + card2 + card3 + card4 + card5 + 5
                             )
                          )
                       )
                    )
                 )
              )

# playerFinalPoints: the point total at the first decision where the hit
# indicator is 0:
#   hit1 == 0 -> initialPoints
#   hit2 == 0 -> pointsAfterCard3
#   hit3 == 0 -> pointsAfterCard4
#   otherwise -> pointsAfterCard5
net <- setNode(net, playerFinalPoints, "determ", define=fromFormula(),
               nodeFormula = playerFinalPoints ~ 
                 ifelse(hit1 == 0,
                    initialPoints,
                    ifelse(hit2 == 0,
                           pointsAfterCard3,
                           ifelse(hit3 == 0, pointsAfterCard4, pointsAfterCard5)
                    )
                 )
              )
# Load the BJDealer dataset and build dealerOutcome.cpt from it, modelling
# dealerOutcome given dealerUpcard with the rows weighted by
# BJDealer$probability (presumably a conditional probability table, going by
# the cpt() name -- confirm against the HydeNet documentation).
data(BJDealer)
dealerOutcome.cpt <- cpt(dealerOutcome ~ dealerUpcard,
                         data = BJDealer,
                         wt   = BJDealer$probability)
# Display the table rounded to 3 decimal places.
round(dealerOutcome.cpt,3)
# Attach the fitted object to the network as the model for its node.
net <- setNodeModels(net, dealerOutcome.cpt)
# This is the call that raises the "argument is of length zero" error
# reported in this issue.
writeNetworkModel(net, pretty=TRUE)
Error in if (network$nodeFitter[[node_str]] == "cpt") { : 
  argument is of length zero
jarrod-dalton commented 9 years ago

I think I fixed it, but you should review. The problem was that we were doing logic on the nodeFitter list elements but sometimes they were NULL.

So I just put a short-circuit logical AND (`&&`) in there to first ensure that the nodeFitter element is populated before checking its equivalence to "cpt".

  #*** Type 'dcat'
  else if (network$nodeType[[node_str]] == "dcat"){
    if (length(network$nodeFitter[[node_str]]) > 0 && network$nodeFitter[[node_str]] == "cpt"){
nutterb commented 9 years ago

I changed length(network$nodeFitter[[node_str]]) > 0 to !is.null(network$nodeFitter[[node_str]]), but otherwise, this checks out.