blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0
806 stars 105 forks

Bug example #688

Closed reubenharry closed 3 months ago

reubenharry commented 4 months ago

There appears to be a bug with JAX (or at least a subtle case): changing the map between pmap and vmap in bug.py changes the printed result.

I have also seen other very strange behavior with this code that I have not yet reproduced in this example, where adding a print statement caused a similar change in the result.

reubenharry commented 4 months ago

The mistake can be traced to streaming_average_update, but I don't know exactly why. For example, if we replace streaming_average_update with:

def streaming_average_update(
    expectation, streaming_avg, weight=1.0, zero_prevention=0.0
):
    return streaming_avg

the discrepancy disappears (at least in this case).

junpenglao commented 4 months ago

I suspect that the bug is due to the combination of multiple chains and vmap/pmap. Basically, ravel_pytree is flattening across chains for some reason.

Nonetheless, could you try the below:

import jax


def streaming_average_update(
    current_value, previous_weight_and_average, weight=1.0, zero_prevention=0.0
):
    """Compute the streaming average of a function O(x) using a weight.

    Parameters
    ----------
    current_value
        the current value of the function that we want to take the average of
    previous_weight_and_average
        tuple of (previous_weight, previous_average) where previous_weight is the
        sum of weights and previous_average is the current estimated average
    weight
        weight of the current state
    zero_prevention
        small value to prevent division by zero

    Returns
    -------
    new total weight and streaming average
    """
    previous_weight, previous_average = previous_weight_and_average
    current_weight = previous_weight + weight
    current_average = jax.tree.map(
        lambda x, avg: (previous_weight * avg + weight * x)
        / (current_weight + zero_prevention),
        current_value,
        previous_average,
    )
    return current_weight, current_average
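For intuition, with unit weights and zero_prevention=0 this update reduces to the ordinary running mean. A scalar sketch in pure Python (no JAX; the helper name is illustrative, not from the library):

```python
def streaming_average_update_scalar(current_value, previous_weight_and_average,
                                    weight=1.0, zero_prevention=0.0):
    # Scalar version of the update above: same arithmetic, no pytrees.
    previous_weight, previous_average = previous_weight_and_average
    current_weight = previous_weight + weight
    current_average = (previous_weight * previous_average + weight * current_value) / (
        current_weight + zero_prevention
    )
    return current_weight, current_average

state = (0.0, 0.0)
for x in [1.0, 2.0, 3.0]:
    state = streaming_average_update_scalar(x, state)
# With unit weights the streaming average equals the batch mean of [1, 2, 3]
```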
reubenharry commented 4 months ago

This doesn't affect the results. I don't think the pytree is the issue, since I also see the problem with:

def streaming_average_update(
    expectation, streaming_avg, weight=1.0, zero_prevention=0.0
):
    total, average = streaming_avg
    average = (total * average + weight * expectation) / (
        total + weight + zero_prevention
    )
    total += weight
    streaming_avg = (total, (average))
    return streaming_avg
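For what it's worth, for scalar inputs this formulation and the jax.tree.map one above are algebraically identical, which is consistent with the pytree handling not being the culprit. A quick pure-Python check (function names are illustrative):

```python
def update_tree_style(x, state, weight=1.0, zero_prevention=0.0):
    # The version suggested above, specialized to a scalar.
    prev_w, prev_avg = state
    cur_w = prev_w + weight
    cur_avg = (prev_w * prev_avg + weight * x) / (cur_w + zero_prevention)
    return cur_w, cur_avg

def update_inplace_style(x, state, weight=1.0, zero_prevention=0.0):
    # The version in this comment.
    total, average = state
    average = (total * average + weight * x) / (total + weight + zero_prevention)
    total += weight
    return total, average

s1 = s2 = (0.0, 0.0)
for x in [0.5, -1.25, 2.0]:
    s1 = update_tree_style(x, s1)
    s2 = update_inplace_style(x, s2)
# Identical arithmetic, identical results: s1 == s2
```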
reubenharry commented 4 months ago

Actually even with:

def streaming_average_update(
    expectation, streaming_avg, weight=1.0, zero_prevention=0.0
):
    return streaming_avg

I get a (small) discrepancy:

Result with <function pmap at 0x11f3b40e0> is [[ 1.2639244  -0.19290113]]
Result with <function vmap at 0x1224eff60> is [[ 1.2648091  -0.19346505]]
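A discrepancy at roughly the 1e-4 level in float32, even with the update stubbed out, is the kind of thing floating-point reordering can produce: pmap and vmap can compile to differently ordered reductions, and floating-point addition is not associative. A minimal demonstration in pure Python (float64, but the same effect applies to float32):

```python
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c    # 0.6000000000000001
right = a + (b + c)   # 0.6
# Floating-point addition is not associative: left != right
```

Whether this fully explains the discrepancy here (as opposed to, say, a PRNG key issue) would need checking against the actual compiled programs.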
reubenharry commented 4 months ago

For a batch of 2 instead of 1, I also see a discrepancy, although it is small. I wonder if there is a PRNG key issue involved:

Result with <function pmap at 0x127feff60> is [[ 0.13665608 -1.9322048 ]
 [ 1.276183   -0.18907894]]

Result with <function vmap at 0x12ff97f60> is [[ 0.13712421 -1.9322233 ]
 [ 1.2760496  -0.18907525]]
reubenharry commented 4 months ago

The other odd thing worth mentioning is that the discrepancies often shrink with longer runs and higher dimensions, e.g. a 10D Gaussian with 10000 steps:

Result with <function pmap at 0x165fc8180> is [[ 0.8321227  -2.5019531  -0.7922016  -1.3899261   1.2677195  -0.35655203
   0.48542497 -0.03980938  0.14613445 -1.3119547 ]
 [-0.03221251  0.5771989  -0.23249874 -0.62638044  0.9548867  -1.3366085
  -0.1022618   0.6617969   0.9869026  -1.4713044 ]]

Result with <function vmap at 0x11d3abe20> is [[ 0.8321227  -2.501953   -0.79220164 -1.389926    1.2677199  -0.35655132
   0.4854263  -0.03980982  0.14613375 -1.3119547 ]
 [-0.03221169  0.5772002  -0.2324989  -0.6263806   0.9548884  -1.3366088
  -0.1022612   0.6617972   0.9869019  -1.4713038 ]]