hakaru-dev / hakaru

A probabilistic programming language
BSD 3-Clause "New" or "Revised" License

testKernel failure #94

Closed ccshan closed 7 years ago

ccshan commented 7 years ago
### Failure in: 6:RoundTrip:6:2:testKernel     
haskell/Tests/TestTools.hs:92
expected:
(fn x3 real: 
  x2 <~ normal(x3, nat2prob(1))
  x1 = (match (nat2prob(1)
                < 
               exp((negate(nat2real(1))
                     / 
                    nat2real(50)
                     * 
                    (x2 - x3)
                     * 
                    (x2 + x3)))): 
         true: nat2prob(1)
         false: 
          exp((negate(nat2real(1)) / nat2real(50) * (x2 - x3) * (x2 + x3))))
  x0 <~ x0 <~ categorical([x1,
                           real2prob((prob2real(nat2prob(1)) - prob2real(x1)))])
        return [true, false][x0]
  return (match x0: 
           true: x2
           false: x3))
but got:
(fn x3 real: 
  weight(nat2prob(2),
         x27 <~ lebesgue((prob2real(∞) * int2real(-1)), prob2real(∞))
         weight(1/2,
                (match not((nat2prob(1)
                             < 
                            (exp(((x27 ^ 2) * -1/50))
                              * 
                             exp(((x3 ^ 2) * prob2real((1/50))))))): 
                  true: 
                   weight((sqrt(nat2prob(2))
                            / 
                           sqrt(pi)
                            * 
                           exp(((x27 ^ 2) * -13/25))
                            * 
                           exp((x27 * x3))
                            * 
                           exp(((x3 ^ 2) * -12/25))
                            / 
                           2),
                          return x27)
                  false: reject. measure(real))) <|> 
         weight(1/2,
                (match not((nat2prob(1)
                             < 
                            (exp(((x27 ^ 2) * -1/50))
                              * 
                             exp(((x3 ^ 2) * prob2real((1/50))))))): 
                  true: 
                   weight(real2prob((prob2real(sqrt(nat2prob(2)))
                                      * 
                                     (prob2real((exp(((x27 ^ 2) * -1/50))
                                                  * 
                                                 exp(((x3 ^ 2) * prob2real((1/50))))))
                                       + 
                                      int2real(-1))
                                      * 
                                     prob2real(exp(((x27 ^ 2) * -1/2)))
                                      * 
                                     prob2real(exp((x27 * x3)))
                                      * 
                                     prob2real(exp(((x3 ^ 2) * -1/2)))
                                      * 
                                     prob2real(recip(sqrt(pi)))
                                      * 
                                     -1/2)),
                          return x3)
                  false: reject. measure(real)))))
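
For orientation, the expected program can be read as one accept/reject step: propose x2 from normal(x3, 1), keep it with probability x1, otherwise stay at x3. Below is a minimal Python sketch of that reading. The name kernel_step and the use of Python's random module are illustrative assumptions, not part of Hakaru.

```python
import math
import random

def kernel_step(x3, rng):
    """One step of the expected program; returns (result, proposal)."""
    # x2 <~ normal(x3, 1)
    x2 = rng.gauss(x3, 1.0)
    # x1 = match 1 < exp(-1/50 * (x2 - x3) * (x2 + x3)): true: 1; false: exp(...)
    ratio = math.exp(-(1.0 / 50.0) * (x2 - x3) * (x2 + x3))
    x1 = 1.0 if 1.0 < ratio else ratio
    # x0 <~ categorical([x1, 1 - x1]), mapped to [true, false]
    accept = rng.random() < x1
    # return (match x0: true: x2; false: x3)
    return (x2 if accept else x3), x2

rng = random.Random(0)
result, proposal = kernel_step(0.5, rng)
```

Every result is either the proposal or the unchanged x3, which is the shape the roundtripped program should keep.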

yuriy0 commented 7 years ago

The simplification of this program has changed slightly; for reference, here is the new result:

(fn x3 real: 
  x27 <~ lebesgue((prob2real(∞) * (-1)), prob2real(∞))
  (match not((1
               < 
              (exp(((x27 ^ 2) * -1/50)) * exp(((x3 ^ 2) / 50))))): 
    true: 
     weight((sqrt(2)
              / 
             sqrt(pi)
              * 
             exp(((x27 ^ 2) * -13/25))
              * 
             exp((x27 * x3))
              * 
             exp(((x3 ^ 2) * -12/25))
              / 
             2),
            return x27) <|> 
     weight(real2prob((prob2real(sqrt(2))
                        * 
                       (prob2real((exp(((x27 ^ 2) * -1/50)) * exp(((x3 ^ 2) / 50)))) - 1)
                        * 
                       prob2real(exp(((x27 ^ 2) * -1/2)))
                        * 
                       prob2real(exp((x27 * x3)))
                        * 
                       prob2real(exp(((x3 ^ 2) * -1/2)))
                        * 
                       prob2real(recip(sqrt(pi)))
                        * 
                       -1/2)),
            return x3)
    false: reject. measure(real)))

It now contains the superposition of two different weights which have different bodies.

My issue now is that I don't know what should happen with the lebesgue. It starts off as a normal, and the weight of that normal is pushed in and combined with other weights in the expression. I don't see how we could eliminate the lebesgue altogether here (the impression I got today was that this lebesgue should be eliminated).

Can someone offer some guidance as to the desired output? Perhaps we just want to recognize normal instead of lebesgue? I think pulling the common factor of exp(-x3^2/2)*exp(-x27^2/2)*exp(x27*x3) out of the Partition would accomplish that.

JacquesCarette commented 7 years ago

This is actually quite tricky. Easier to see what's happening on the Maple side (i.e. an integral of a partition).

There are two ways to attack this:

  1. If the condition can be solved for x27 (Maple 18 cannot, but maybe 2017 can), then the integral can be pushed inwards into both branches.
  2. In either case, if the ratio of the 2 weights is independent of x27, then that can be pulled out. In other words, if the weights are a*w(x27) and b*w(x27), where a and b are free of x27, w(x27) should indeed be pulled out of the <|> and the partition.

In this case, 2 does not in fact hold (the expression in the second branch has an exp(...) - 1 factor in it that spoils everything).
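
A quick numeric check of why option 2 fails, as a sketch in Python. The two functions transcribe the weights from the simplified program above; the function names are illustrative assumptions.

```python
import math

def weight_x27(x27, x3):
    # Weight of the `return x27` summand.
    return (math.sqrt(2) / math.sqrt(math.pi)
            * math.exp(x27**2 * -13/25)
            * math.exp(x27 * x3)
            * math.exp(x3**2 * -12/25)
            / 2)

def weight_x3(x27, x3):
    # Weight of the `return x3` summand (the real2prob argument).
    e = math.exp(x27**2 * -1/50) * math.exp(x3**2 / 50)
    return (math.sqrt(2) * (e - 1)
            * math.exp(x27**2 * -1/2)
            * math.exp(x27 * x3)
            * math.exp(x3**2 * -1/2)
            / math.sqrt(math.pi)
            * -1/2)

def ratio(x27, x3):
    return weight_x3(x27, x3) / weight_x27(x27, x3)

# Same x3, two different values of x27.
r1 = ratio(1.0, 0.5)
r2 = ratio(2.0, 0.5)
```

Working the algebra by hand, the ratio reduces to exp((x27^2 - x3^2)/50) - 1, so r1 and r2 differ and no single factor a/b free of x27 exists.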

ccshan commented 7 years ago

Not only do we want to recognize a normal instead of lebesgue; we want to recognize the same x2 <~ normal(x3, 1) that went in (i.e., we want to roundtrip it).

yuriy0 commented 7 years ago

> If the condition can be solved for x27 (Maple 18 cannot, but maybe 2017 can), then the integral can be pushed inwards into both branches.

It does indeed succeed in solving the condition (eventually, after a bit of preprocessing - in particular, a call to KB:-try_improve_exp). The result is

Or(And(0 < x3, x25 < -x3), And(0 < x3, x3 < x25), And(x25 < x3, x3 < 0), And(x3 < 0, -x3 < x25))

and this is thrown away, because it is so complex (bit of a lie: the actual result is even worse, but that's what we would get if we kept it and continued simplifying). But even if we kept this, split it into four pieces (or two with a bit more simplification), 'unflattened' the Partition so that it became nested Partitions with the conditions mentioning x25 innermost, pushed down the integral, and moved the conditions like x25 < x3 into the bounds of that integral, that would still be recognized as Lebesgue(-infinity,x3).

I realize the above result is not solved for x25; it is the result of solve with no variable arguments. Calling solve for the integration variable gives no solutions. I also tried it with SemiAlgebraic, which gives essentially the same solution, except that some of the steps mentioned above are already performed (the conditions mentioning x25 are placed innermost).
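
To see what "two pieces with a bit more simplification" could look like, here is a small Python sketch comparing the solved condition with the compact form |x3| < |x25|. The compact form is my own restatement, not something taken from the Maple output, and the function names are illustrative.

```python
def solved_condition(x3, x25):
    # The Or(And(...), ...) result of solve, transcribed literally.
    return ((0 < x3 and x25 < -x3) or (0 < x3 and x3 < x25)
            or (x25 < x3 and x3 < 0) or (x3 < 0 and -x3 < x25))

def compact_condition(x3, x25):
    # Assumed restatement: the proposal is farther from 0 than x3.
    return abs(x3) < abs(x25)

grid = range(-5, 6)
mismatches = [(x3, x25) for x3 in grid for x25 in grid
              if x3 != 0
              and solved_condition(x3, x25) != compact_condition(x3, x25)]
```

On this grid the two agree whenever x3 is nonzero. They differ on the line x3 = 0, where none of the four And clauses can hold.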

> Not only do we want to recognize a normal instead of lebesgue; we want to recognize the same x2 <~ normal(x3, 1) that went in (i.e., we want to roundtrip it).

> if the weights are a*w(x27) and b*w(x27), where a and b are free of x27, w(x27) should indeed be pulled out of the <|> and the partition.

I can manage to recognize a Gaussian by manually pulling out a weight of exp(-1/2*x2^2) * exp(x2*x3) * exp(-1/2*x3^2) * 2^(-1/2) * (Pi^(-1/2)) after the call to improve. Call this weight W.

Here the condition of "free of x27" would cause the needed weight to not be pulled out. I think we need a stronger simplification - basically a factorization of Partition (Partition being essentially a sum of products of 'indicators' should lend itself well to this).

W is the largest weight which is common to both summands, which is seemingly a nice property, because we would not have to think too hard about which weight exactly to pull out - just pull out the largest weight possible.

Currently, Partition:-Simpl pushes products and sums inwards because recursive calls to Simpl expect to see a top-level Partition; this is needed for Simpl:-flatten to do its job. And in this case, we do perform a non-trivial flattening. More importantly, not pushing down weights (i.e. accomplishing flattening by some other method) wouldn't work, because the weight we need outermost is already distributed over a sum of Partitions. That sum of Partitions becomes a Partition of sums (due to Partition:-PProd) and the common weight is left distributed over two summands inside the Partition (w0 * h(x3) + w1 * h(x2), where W is present in both w0 and w1.) My initial thought was to prevent the needed weight from being pushed down in the first place, but that doesn't seem feasible, as it would have to be prevented in at least two known places (sum_assuming and Partition:-Simpl) and possibly very many other unknown places (for example, simplify_factor_assuming seems to push parts of the weight not 'obviously' present in both summands down).

At this point, it seemed that it should be easy to factor that expression - but I haven't figured out how. It seems the problem is that it isn't syntactically obvious that W appears in both w0 and w1. w0 contains W verbatim, but w1 contains exp(-13/25*x2^2) * exp(-12/25*x3^2). factor doesn't get it. I'm now looking for a Maple simplifier which can compute this factorization - perhaps that is factor combined with an appropriate normalization which allows it to see that W appears in both summands.

To summarize: is there an easy way to get Maple to factor

-(1/2)*2^(1/2)*(exp(-(1/50)*x25^2)*exp((1/50)*x3^2)-1)*exp(-(1/2)*x25^2)*exp(x25*x3)*exp(-(1/2)*x3^2)*h1(x3)/Pi^(1/2)+(1/2)*2^(1/2)*exp(-(13/25)*x25^2)*exp(x25*x3)*exp(-(12/25)*x3^2)*h1(x25)/Pi^(1/2)

into

(1/2)*exp(-(1/2)*x25^2)*exp(x25*x3)*exp(-(1/2)*x3^2)*2^(1/2)*(-(exp(-(1/50)*x25^2)*exp((1/50)*x3^2)-1)*h1(x3)+exp(-(1/50)*x25^2)*exp((1/50)*x3^2)*h1(x25))/Pi^(1/2)

(by easy, I mean without having to write code to identify the common factor W).
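
As a sanity check that the two expressions above really are equal, here is a Python sketch that evaluates both at a few points. The function names and the sample h1 are illustrative assumptions; h1 stands in for the arbitrary integrand.

```python
import math

def unfactored(x25, x3, h1):
    e = math.exp(-(1/50)*x25**2) * math.exp((1/50)*x3**2)
    return (-(1/2) * 2**(1/2) * (e - 1)
            * math.exp(-(1/2)*x25**2) * math.exp(x25*x3)
            * math.exp(-(1/2)*x3**2) * h1(x3) / math.pi**(1/2)
            + (1/2) * 2**(1/2)
            * math.exp(-(13/25)*x25**2) * math.exp(x25*x3)
            * math.exp(-(12/25)*x3**2) * h1(x25) / math.pi**(1/2))

def factored(x25, x3, h1):
    e = math.exp(-(1/50)*x25**2) * math.exp((1/50)*x3**2)
    return ((1/2) * math.exp(-(1/2)*x25**2) * math.exp(x25*x3)
            * math.exp(-(1/2)*x3**2) * 2**(1/2)
            * (-(e - 1) * h1(x3) + e * h1(x25)) / math.pi**(1/2))

def h1(t):
    # Arbitrary test integrand, chosen only for this check.
    return t * t + 1.0

points = [(0.3, -1.2), (1.5, 0.7), (-2.0, 0.1)]
agree = all(math.isclose(unfactored(a, b, h1), factored(a, b, h1),
                         rel_tol=1e-9, abs_tol=1e-12)
            for a, b in points)
```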

With this factorization performed, we would get

Bind(Gaussian(x3, 1), x25, 
   Partition(-(1/50)*x25^2+(1/50)*x3^2 <= 0, 
                Msum(Weight(exp(-(1/50)*x25^2)   *exp((1/50)*x3^2), Ret(x25)), 
                     Weight(1-exp(-(1/50)*x25^2) *exp((1/50)*x3^2), Ret(x3))), 
             0 < -(1/50)*x25^2+(1/50)*x3^2, 
                     Msum()))

which I think is about the best we can do.
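
The reason pulling out W yields Gaussian(x3, 1) can be checked numerically: W is exactly the normal(x3, 1) density at x25, and the two Msum weights that remain add up to 1. Here is a Python sketch, with illustrative function names.

```python
import math

def normal_pdf(x, mean, stdev):
    z = (x - mean) / stdev
    return math.exp(-0.5 * z * z) / (stdev * math.sqrt(2 * math.pi))

def common_weight(x25, x3):
    # W, transcribed from the comment above (with x2 renamed to x25).
    return (math.exp(-1/2 * x25**2) * math.exp(x25 * x3)
            * math.exp(-1/2 * x3**2) * 2**(-1/2) * math.pi**(-1/2))

def msum_weights(x25, x3):
    # The weights left inside the Msum after W is factored out.
    e = math.exp(-(1/50)*x25**2) * math.exp((1/50)*x3**2)
    return e, 1 - e

w = common_weight(1.3, 0.4)
p = normal_pdf(1.3, 0.4, 1.0)
keep, stay = msum_weights(1.3, 0.4)
```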

JacquesCarette commented 7 years ago

If we take ww to be what you call W, then here is one way:

  1. for each exp(rat * monomial), convert that to fresh-name-of(monomial) ^ rat.
  2. simplify
  3. undo transformation from 1

In this example, you get

zz := eval(ww, [exp(-1/2*x25^2)=y1^(-1/2), exp(-1/2*x3^2)=y2^(-1/2), 
  exp(-1/50*x25^2)=y1^(-1/50), exp(1/50*x3^2) = y2^(1/50), 
  exp(-12/25*x3^2)=y2^(-12/25), exp(x25*x3)=y3, exp(-13/25*x25^2)=y1^(-13/25)]);

which gives

-(1/2)*2^(1/2)*(y2^(1/50)/y1^(1/50)-1)*y3*h1(x3)
 /(y1^(1/2)*y2^(1/2)*Pi^(1/2))+(1/2)*2^(1/2)*y3*h1(x25)/(y1^(13/25)*y2^(12/25)*Pi^(1/2))

then simplify gives

-(1/2)*2^(1/2)*y3*(-h1(x3)*y1^(1/50)+h1(x3)*y2^(1/50)-h1(x25)*y2^(1/50))/(y1^(13/25)*y2^(1/2)*Pi^(1/2))

which is close enough to what you want. Interestingly, simplify here works better than factor!

In general, you might want to assume that all your fresh names are >0 (since they are all exp(real)). In this case, that changes nothing, but in general it might.
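
One way to read the result above: once each exp is a power of a fresh name, the common factor is the smallest exponent of each name across the summands. The Python sketch below reproduces that bookkeeping with exact fractions. It is an assumption about why the output looks the way it does, not a description of how Maple's simplify works, and the dictionaries only track the monomial part of each summand.

```python
from fractions import Fraction as F

# Exponents of the fresh names, with y1 = exp(x25^2), y2 = exp(x3^2),
# y3 = exp(x25*x3), read off the substituted expression zz.
term_x3 = {"y1": F(-1, 2), "y2": F(-1, 2), "y3": F(1)}      # the h1(x3) summand
term_x25 = {"y1": F(-13, 25), "y2": F(-12, 25), "y3": F(1)}  # the h1(x25) summand

def common_factor(*terms):
    # Smallest exponent of each name over all summands.
    names = set().union(*terms)
    return {n: min(t.get(n, F(0)) for t in terms) for n in sorted(names)}

def cofactor(term, common):
    # What remains of a summand after dividing by the common factor.
    return {n: term.get(n, F(0)) - e for n, e in common.items()}

common = common_factor(term_x3, term_x25)
left_x3 = cofactor(term_x3, common)
left_x25 = cofactor(term_x25, common)
```

The common factor comes out as y3 / (y1^(13/25) * y2^(1/2)), leaving y1^(1/50) on the h1(x3) summand and y2^(1/50) on the h1(x25) summand, which matches the simplify output.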

JacquesCarette commented 7 years ago

In general, you can deal with exp(polynomial) by expanding it into a product of exp(monomial).
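
A small Python sketch of that expansion, using the exponent of the first weight as the polynomial. The helper names are illustrative assumptions.

```python
import math

def exp_of_polynomial(coeffs, monomials):
    # exp(c1*m1 + c2*m2 + ...)
    return math.exp(sum(c * m for c, m in zip(coeffs, monomials)))

def product_of_exp_monomials(coeffs, monomials):
    # exp(c1*m1) * exp(c2*m2) * ...
    return math.prod(math.exp(c * m) for c, m in zip(coeffs, monomials))

x25, x3 = 1.1, 0.7
coeffs = [-13/25, 1.0, -12/25]
monomials = [x25**2, x25 * x3, x3**2]

whole = exp_of_polynomial(coeffs, monomials)
split = product_of_exp_monomials(coeffs, monomials)
```

Each exp(c*m) factor can then be replaced by a fresh name raised to the power c, as in the substitution above.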

yuriy0 commented 7 years ago

I found something that sort of does what we want - it is convert(.., exp). The latest output is

fn x3 real:
weight
  (sqrt(26/1) * exp(x3 ^ 2 * (-649/1300)) * (5/26),
   x27 <~ normal(x3 * (+25/26), sqrt(26/1) * (5/26))
   if x27 ^ 2 * (-1/50) + x3 ^ 2 * (+1/50) <= +0/1:
     weight
       (real2prob
          (prob2real(exp(x3 ^ 2 * (+12/25)) * exp(x27 ^ 2 * (+1/50)))
           + prob2real(exp(x3 ^ 2 * (+1/2))) * (-1/1)),
        return x3) <|>
     weight(exp(x3 ^ 2 * (+1/2)), return x27)
   else: reject. measure(real))

Unfortunately, this doesn't roundtrip the normal; the weight it pulls out is not quite the correct one for that.

> which is close enough to what you want.

Also unfortunately, that isn't quite close enough either. The normal still doesn't roundtrip; we get

Weight(-(5/26)*26^(1/2)*exp(-(1/52)*x3^2), Bind(Gaussian((25/26)*x3, (5/26)*26^(1/2)), x25, ...))

This seems to be entirely non-trivial; recognizing the correct weight to roundtrip the normal seems to require remembering what the original normal was, but that will hardly ever make sense since we eliminate integrals so often, and even when it does (in this case) the final term is so different from the original that I can't possibly see how to use the original information to deduce which weight it is we would like to factor out.
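
To make the mismatch concrete, completing the square on the exponent that gets pulled out shows where (25/26)*x3 and exp(-(1/52)*x3^2) come from, and why W itself gives normal(x3, 1) with nothing left over. Here is a Python sketch with exact fractions; the helper name and its parameterization are assumptions made for illustration.

```python
from fractions import Fraction as F

def complete_square(a, b, c):
    """Rewrite -a*x^2 + b*x3*x + c*x3^2 as -(x - m*x3)^2/(2*v) + r*x3^2.

    Returns (m, v, r): mean coefficient, variance, leftover coefficient.
    """
    v = 1 / (2 * a)       # matching the x^2 coefficient: -1/(2v) = -a
    m = b * v             # matching the x*x3 coefficient: m/v = b
    r = c + b * b * v / 2  # what is left on x3^2 after the square
    return m, v, r

# Exponent of the factor simplify found: -13/25*x25^2 + x25*x3 - 1/2*x3^2
found = complete_square(F(13, 25), F(1), F(-1, 2))

# Exponent of W: -1/2*x25^2 + x25*x3 - 1/2*x3^2
wanted = complete_square(F(1, 2), F(1), F(-1, 2))
```

The first call gives mean (25/26)*x3, variance 25/26 (standard deviation (5/26)*26^(1/2)) and a leftover exp(-(1/52)*x3^2). The second gives mean x3, variance 1 and no leftover.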

JacquesCarette commented 7 years ago

In the absence of a strong reason to roundtrip this particular measure in its 'simplified' form, I would say the current result is good enough. It does look really hard to 'guess' the right coordinates to use that would make this roundtrip in its nicest form (which seems to me to be what is going on).

yuriy0 commented 7 years ago

> It does look really hard to 'guess' the right coordinates to use

Indeed. I have a vague belief that this should be fixed in the recognizer, since it is in a better position to 'know what it wants'.

ccshan commented 7 years ago

Is it too much to ask for tests/RoundTrip/testKernel.0.hk to roundtrip (I mean for it to become something of the form fn x3 real: x2 <~ normal(x3, 1) ...)? In the output of toLO on this measure, the normal(x3, 1) density is nicely delineated: outside an inert Sum and inside an inert Int. When we started moving from SLO to NewSLO back in 1578, we decided that simplification should not try to be too clever about reordering integrals (including summations).

yuriy0 commented 7 years ago

> we decided that simplification should not try to be too clever about reordering integrals (including summations).

So that turns out to be the solution (not reordering weights over integrals in the first place, as opposed to trying to guess where they should go after they've been mangled)... I'm not sure why I didn't think of it before. It seems so obvious now!

The current output of simplifying both programs is

fn x3 real:
x27 <~ normal(x3, 1/1)
if x27 ^ 2 * (-1/50) + x3 ^ 2 * (+1/50) <= +0/1:
  weight
    (exp(x27 ^ 2 * (-1/50)) * exp(x3 ^ 2 * (+1/50)), return x27) <|>
  weight
    (real2prob
       (+1/1
        + prob2real(exp(x27 ^ 2 * (-1/50)))
          * prob2real(exp(x3 ^ 2 * (+1/50)))
          * (-1/1)),
     return x3)
else: reject. measure(real)

JacquesCarette commented 7 years ago

Fascinating. And this change does not cause lots of test failures?

yuriy0 commented 7 years ago

None at all. We will still push down weights, but only when we need to. And the change which fixed this (although I've since added more logic to not push down weights when it isn't strictly needed) was in Domain, which strictly doesn't care about weights at all. It doesn't even 'know' what should be done with those weights which are no longer pushed down; that happens here.