hakaru-dev / hakaru

A probabilistic programming language
BSD 3-Clause "New" or "Revised" License

Some part of the simplifier still relies on receiving `sum` instead of `Sum` #88

Status: Open. Opened by yuriy0 7 years ago.

yuriy0 commented 7 years ago

It is not yet known which part of the simplifier this is. I have searched every occurrence of `sum` and cannot find one in a negative position that isn't also accompanied by `Sum`. This leads me to believe that some function, quite possibly a Maple function (i.e. part of Maple proper rather than Hakaru), recognizes `sum` but not `Sum`.

Currently this is handled by always evaluating sums before calling any part of the simplifier. As we saw in #87, passing a non-inert `sum` (and probably `product` and `int`) to certain Maple functions exposes a Maple bug, so we would like the simplifier to deal only with `Sum` and never with `sum` (unless we expect a sum to evaluate, in which case the `sum` should disappear).
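The invariant wanted here (the simplifier only ever sees inert `Sum`, never active `sum`) can be modelled outside Maple. The Python below is a toy, language-neutral sketch, not Hakaru or Maple code: `Call`, `make_inert`, `active_heads` and `check_inert` are hypothetical names, and the only assumption borrowed from Maple is its pairing of active and inert names (`sum`/`Sum`, `product`/`Product`, `int`/`Int`). It rewrites active heads into inert ones and adds a boundary check, so that once the up-front evaluation is removed, any stage that hands an active form to the next one fails loudly with that stage's name instead of silently changing the output.

```python
from dataclasses import dataclass
from typing import Tuple, Union

# Hypothetical toy AST used only to model the invariant; it is NOT
# how Hakaru's Maple simplifier represents expressions.
@dataclass(frozen=True)
class Call:
    head: str                 # function name, e.g. "sum" or "Sum"
    args: Tuple["Expr", ...]  # sub-expressions or atoms

Expr = Union[Call, str, int]

# Active names mapped to their inert counterparts (Maple's naming convention).
INERT = {"sum": "Sum", "product": "Product", "int": "Int"}

def make_inert(e: Expr) -> Expr:
    """Recursively rewrite every active head into its inert form."""
    if isinstance(e, Call):
        return Call(INERT.get(e.head, e.head),
                    tuple(make_inert(a) for a in e.args))
    return e

def active_heads(e: Expr) -> set:
    """Return the set of active heads (sum/product/int) still present in e."""
    if not isinstance(e, Call):
        return set()
    found = {e.head} if e.head in INERT else set()
    for a in e.args:
        found |= active_heads(a)
    return found

def check_inert(stage: str, e: Expr) -> Expr:
    """Boundary check: raise, naming the stage, if an active form leaked in."""
    leaked = active_heads(e)
    if leaked:
        raise ValueError(f"{stage}: active forms {sorted(leaked)} reached the simplifier")
    return e

# Usage: normalize once, then check at each (hypothetical) stage boundary.
expr = Call("*", (Call("sum", ("x", "i")),
                  Call("Product", (Call("int", ("f",)), "j"))))
inert = check_inert("after normalization", make_inert(expr))
print(sorted(active_heads(expr)))  # ['int', 'sum']
```

In a real pipeline the check would sit between simplifier stages; wherever it first fires is where an active form is being (re)introduced.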

To observe the incorrect behaviour, change line 136 from above to `Sum`=`sum` and try to simplify `examples/lda2.hk`. The output will be:

(fn topic_prior array(prob):
  (fn word_prior array(prob):
    (fn numDocs nat:
      (fn w array(nat):
        (fn doc array(nat):
          (fn z array(nat):
            (fn wordUpdate nat:
              (match (wordUpdate < size(w)):
                true:
                 weight((((product d from 0 to int2nat((nat2int(size(topic_prior))
                                                         -
                                                        1)):
                            betaFunc((summate dP from (d + 1) to size(topic_prior):
                                       topic_prior[dP]),
                                     topic_prior[d]))
                           **
                          (nat2real(numDocs) * (-1)))
                          *
                         (product d from 0 to numDocs:
                           (product i from 0 to int2nat((nat2int(size(topic_prior)) - 1)):
                             betaFunc(((summate dP from (i + 1) to size(topic_prior):
                                         topic_prior[dP])
                                        +
                                       nat2prob((summate dP from 0 to size(w):
                                                  (match (not(((nat2int((match (dP == wordUpdate):
                                                                          true: 0
                                                                          false: z[dP]))
                                                                 -
                                                                1)
                                                                <
                                                               nat2int(i))) &&
                                                          (d == doc[dP])):
                                                    true: 1
                                                    false: 0)))),
                                      topic_prior[i])))),
                        xsl <~ plate d of numDocs:
                                plate i of int2nat((nat2int(size(topic_prior)) - 1)):
                                 beta(((summate dP from (i + 1) to size(topic_prior):
                                         topic_prior[dP])
                                        +
                                       nat2prob((summate dP from 0 to size(w):
                                                  (match (not(((nat2int((match (dP == wordUpdate):
                                                                          true: 0
                                                                          false: z[dP]))
                                                                 -
                                                                1)
                                                                <
                                                               nat2int(i))) &&
                                                          (d == doc[dP])):
                                                    true: 1
                                                    false: 0)))),
                                      topic_prior[i])
                        weight(recip((product d from 0 to numDocs:
                                       ((summate dP from 0 to 1:
                                          (product j from 0 to dP: xsl[d][j]))
                                         **
                                        nat2real((summate dP from 0 to size(w):
                                                   (match (d == doc[dP]):
                                                     true: 1
                                                     false: 0)))))),
                               xsj <~ plate k of size(topic_prior):
                                       plate i of int2nat((nat2int(size(word_prior)) - 1)):
                                        beta((summate x0 from (i + 1) to size(word_prior):
                                               word_prior[x0]),
                                             word_prior[i])
                               categorical(array zNewh of size(topic_prior):
                                            real2prob((prob2real((product d from 0 to size(w):
                                                                   (product j from 0 to (match (d
                                                                                                 ==
                                                                                                wordUpdate):
                                                                                          true:
                                                                                           zNewh
                                                                                          false:
                                                                                           0):
                                                                     xsl[doc[d]][((match (d
                                                                                           ==
                                                                                          wordUpdate):
                                                                                    true: 0
                                                                                    false: z[d])
                                                                                   +
                                                                                  j)])))
                                                        *
                                                       (product d from 0 to size(w):
                                                         (match (((match (d == wordUpdate):
                                                                    true: zNewh
                                                                    false: z[d])
                                                                   +
                                                                  1)
                                                                  ==
                                                                 size(topic_prior)):
                                                           true: 1
                                                           false:
                                                            (1
                                                              +
                                                             (prob2real(xsl[doc[d]][(match (d
                                                                                             ==
                                                                                            wordUpdate):
                                                                                      true: zNewh
                                                                                      false: z[d])])
                                                               *
                                                              (-1)))))
                                                        *
                                                       prob2real((product d from 0 to size(w):
                                                                   (product j from 0 to w[d]:
                                                                     xsj[(match (d == wordUpdate):
                                                                           true: zNewh
                                                                           false: z[d])][j])))
                                                        *
                                                       (product d from 0 to size(w):
                                                         (match ((w[d] + 1) == size(word_prior)):
                                                           true: 1
                                                           false:
                                                            (1
                                                              +
                                                             (prob2real(xsj[(match (d
                                                                                     ==
                                                                                    wordUpdate):
                                                                              true: zNewh
                                                                              false: z[d])][w[d]])
                                                               *
                                                              (-1)))))
                                                        /
                                                       (product d from 0 to size(w):
                                                         (summate x0 from 0 to size(word_prior):
                                                           (prob2real((product j from 0 to x0:
                                                                        xsj[(match (d
                                                                                     ==
                                                                                    wordUpdate):
                                                                              true: zNewh
                                                                              false: z[d])][j]))
                                                             *
                                                            (match ((x0 + 1) == size(word_prior)):
                                                              true: 1
                                                              false:
                                                               (1
                                                                 +
                                                                (prob2real(xsj[(match (d
                                                                                        ==
                                                                                       wordUpdate):
                                                                                 true: zNewh
                                                                                 false: z[d])][x0])
                                                                  *
                                                                 (-1))))))))))))
                false: reject. measure(nat)))))))))
ccshan commented 7 years ago

Would one way to debug the difference `sum` makes be to run with `printlevel:=1000` and compare the two traces using `diff -i`?
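The point of `diff -i` here is that it ignores case, so trace lines differing only by `sum` versus `Sum` compare equal and only genuine divergences in behaviour remain. A minimal Python sketch of the same comparison, using the standard `difflib.SequenceMatcher`; `diff_ignore_case` is a hypothetical helper, and the two traces are invented placeholders, not real Maple `printlevel` output.

```python
import difflib
from typing import Iterator, List

def diff_ignore_case(a: List[str], b: List[str]) -> Iterator[str]:
    """Yield lines that differ between two traces, ignoring case like `diff -i`.

    Lines are matched on their lowercased form, so a line that differs only
    by `sum` vs `Sum` counts as equal; the original text is what is yielded.
    """
    matcher = difflib.SequenceMatcher(
        None, [s.lower() for s in a], [s.lower() for s in b], autojunk=False)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            continue
        for line in a[i1:i2]:
            yield "- " + line
        for line in b[j1:j2]:
            yield "+ " + line

# Invented placeholder traces (NOT real Maple printlevel output).
with_sum = ["call A(sum(x))", "call B", "return 1"]
with_Sum = ["call A(Sum(x))", "call B", "return 2"]
for line in diff_ignore_case(with_sum, with_Sum):
    print(line)  # prints "- return 1" then "+ return 2"
```

For real traces, each run's output would be saved to a file and read with `open(path).read().splitlines()`. As with `diff -i`, lowercasing also hides any other case-only differences, which is acceptable when the only intended difference between the runs is `sum` versus `Sum`.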