hakaru-dev / hakaru

A probabilistic programming language
BSD 3-Clause "New" or "Revised" License

Simplify output regression #84

Closed cscherrer closed 7 years ago

cscherrer commented 7 years ago

Have a look at lda.hk and lda_simp2.hk here. A previous version of simplify produced efficient code for this program (see lda_simp.hk), but the output of the current simplify contains nested weights and a call to betaFunc, neither of which appeared previously.

I'll try to bisect to determine the point at which the result changes.
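git bisect is a binary search over commit history. The Python sketch below shows the idea. It assumes a linear history where the first commit is known good, the last is known bad, and the test is monotone (once bad, always bad). `first_bad` and `is_bad` are hypothetical names. Here `is_bad` would stand for a check such as "simplifying lda.hk no longer yields lda_simp.hk".

```python
from typing import Callable, Sequence


def first_bad(commits: Sequence[str], is_bad: Callable[[str], bool]) -> str:
    """Return the first commit for which is_bad holds.

    Assumes commits[0] is good, commits[-1] is bad, and badness is
    monotone along the (linear) history.
    """
    # Invariant: commits[lo] is good and commits[hi] is bad.
    lo, hi = 0, len(commits) - 1
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if is_bad(commits[mid]):
            hi = mid  # regression is at mid or earlier
        else:
            lo = mid  # regression is after mid
    return commits[hi]


if __name__ == "__main__":
    history = [f"c{i}" for i in range(10)]
    # Toy predicate: pretend the regression was introduced at c6.
    print(first_bad(history, lambda c: int(c[1:]) >= 6))
```

Each step halves the candidate range, so finding the culprit among n commits takes about log2(n) runs of the test. That is why bisecting is cheap even when every test run means a full simplify.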

cscherrer commented 7 years ago

Result of git bisect:

6e3af0f7e16093ba1f07c8337b19fe67813ce7fa is the first bad commit
commit 6e3af0f7e16093ba1f07c8337b19fe67813ce7fa
Author: yuriy <toporoy@mcmaster.ca>
Date:   Wed May 24 22:12:37 2017 -0400

    More aggressive reduction of Partitions

:040000 040000 d8fcea790d744d3003dc4184af5ba84393813a7f 4c2e3a05f17e10ba2aca828a8619af777ebc19fa M  maple
yuriy0 commented 7 years ago

It turned out that 6e3af0f7e was only one of the problems! That change wasn't all that important, so I've reverted it for now (cf5d2ac7b64a660696db59d87bc84b2f98771eb6). In the future it will be replaced by something slightly more conservative, though I don't know what that will look like yet.

Some parts of the Maple code (apparently hack_Beta) seem to rely on receiving sum and product rather than the inert Sum and Product. Inserting a subs right before hack_Beta doesn't fix it either, so the issue may be subtler than I think. 84dcb747fd fixes this issue by ensuring that Sum/Product are always converted to sum/product in the very early stages of simplification.

Stranger still, the above doesn't even work for Sum! Some time ago, Sum was evaluated with a call to value before anything else was done to the program being simplified. That call to value was replaced by eval_for_Simplify, so that a particular integral occurring in the input to Simplify would evaluate properly. For some reason, eval_for_Simplify doesn't do the job, even though, as far as I can tell, the programs are identical whether value or eval_for_Simplify is called. This regression is addressed by 78438752f364acae5439358e7d427213225887ac.

The last point is the strangest, for the following reason: the first two regressions make simplification fail badly (producing betaFunc), but without the last change we get the following program. It is very similar to the 'expected' output, except that it contains a 'Plate' which shouldn't be there.

(fn topic_prior array(prob):
  (fn word_prior array(prob):
    (fn numDocs nat:
      (fn w array(nat):
        (fn doc array(nat):
          (fn z array(nat):
            (fn wordUpdate nat:
              (match (wordUpdate < size(w)):
                true:
                 weight(((product d from 0 to numDocs:
                           (product iT from 0 to size(topic_prior):
                             (product j from 0 to (summate dF from 0 to size(w):
                                                    (match (dF == wordUpdate):
                                                      true: 0
                                                      false:
                                                       (match ((d == doc[dF]) && (iT == z[dF])):
                                                         true: 1
                                                         false: 0))):
                               (nat2prob(j) + topic_prior[iT]))))
                          /
                         (product d from 0 to numDocs:
                           (product iT from 0 to (summate dF from 0 to size(w):
                                                   (match (dF == wordUpdate):
                                                     true: 0
                                                     false:
                                                      (match (d == doc[dF]):
                                                        true: 1
                                                        false: 0))):
                             (nat2prob(iT)
                               +
                              (summate dF from 0 to size(topic_prior): topic_prior[dF]))))),
                        xsx <~ plate k of size(topic_prior):
                                plate i of int2nat((nat2int(size(word_prior)) - 1)):
                                 beta((summate dF from (i + 1) to size(word_prior): word_prior[dF]),
                                      word_prior[i])
                        categorical(array zNewv of size(topic_prior):
                                     real2prob((prob2real((product d from 0 to size(w):
                                                            (product j from 0 to w[d]:
                                                              xsx[(match (d == wordUpdate):
                                                                    true: zNewv
                                                                    false: z[d])][j])))
                                                 *
                                                (product d from 0 to size(w):
                                                  (match ((w[d] + 1) == size(word_prior)):
                                                    true: 1
                                                    false:
                                                     (1
                                                       +
                                                      (prob2real(xsx[(match (d == wordUpdate):
                                                                       true: zNewv
                                                                       false: z[d])][w[d]])
                                                        *
                                                       (-1)))))
                                                 *
                                                prob2real((product d from 0 to numDocs:
                                                            (match ((d == doc[wordUpdate]) &&
                                                                    not(((nat2int(numDocs) - 1)
                                                                          <
                                                                         nat2int(doc[wordUpdate])))):
                                                              true:
                                                               (nat2prob((summate dF from 0 to size(w):
                                                                           (match (dF
                                                                                    ==
                                                                                   wordUpdate):
                                                                             true: 0
                                                                             false:
                                                                              (match ((d
                                                                                        ==
                                                                                       doc[dF]) &&
                                                                                      (zNewv
                                                                                        ==
                                                                                       z[dF])):
                                                                                true: 1
                                                                                false: 0))))
                                                                 +
                                                                topic_prior[zNewv])
                                                              false: 1)))
                                                 *
                                                prob2real(recip((product d from 0 to numDocs:
                                                                  (match (not(((nat2int(numDocs)
                                                                                 -
                                                                                1)
                                                                                <
                                                                               nat2int(doc[wordUpdate]))) &&
                                                                          (d == doc[wordUpdate])):
                                                                    true:
                                                                     (nat2prob((summate dF from 0 to size(w):
                                                                                 (match (dF
                                                                                          ==
                                                                                         wordUpdate):
                                                                                   true: 0
                                                                                   false:
                                                                                    (match (d
                                                                                             ==
                                                                                            doc[dF]):
                                                                                      true: 1
                                                                                      false: 0))))
                                                                       +
                                                                      (summate dF from 0 to size(topic_prior):
                                                                        topic_prior[dF]))
                                                                    false: 1))))
                                                 /
                                                (product d from 0 to size(w):
                                                  (summate dF from 0 to size(word_prior):
                                                    (prob2real((product j from 0 to dF:
                                                                 xsx[(match (d == wordUpdate):
                                                                       true: zNewv
                                                                       false: z[d])][j]))
                                                      *
                                                     (match ((dF + 1) == size(word_prior)):
                                                       true: 1
                                                       false:
                                                        (1
                                                          +
                                                         (prob2real(xsx[(match (d == wordUpdate):
                                                                          true: zNewv
                                                                          false: z[d])][dF])
                                                           *
                                                          (-1)))))))))))
                false: reject. measure(nat)))))))))

lda.hk has also been added as a Maple unit test (8d230e0384e84ce2bae9203ef9aed4df793a9a75 and 378781087ffa921141b41ffb3cf93c9e9b22a7e8).

yuriy0 commented 7 years ago

lda.hk once again simplifies to lda_simp.hk, so I'll close this. Thanks to everyone for the help tracking this down! There is still work to be done. At least two different problems (probably three, maybe more!) are at play, and they are fairly low-level, whereas this issue describes a high-level problem. In other words, those problems may deserve their own issues.

As an aside, the latest commit that includes all of this work (as of writing, 4d50b3a8b0bbd80390729a83222487ddc17bc768) is a good candidate for the new tip of the stable branch. It is very likely that master will soon be highly unstable again.