reubenharry closed this issue 3 months ago.
The mistake can be traced to `streaming_average_update`, but I don't know exactly why. For example, if we replace `streaming_average_update` by:
```python
def streaming_average_update(
    expectation, streaming_avg, weight=1.0, zero_prevention=0.0
):
    return streaming_avg
```
the discrepancy disappears (at least in this case).

I suspect that the bug is due to the multi-chain setup and `vmap`/`pmap`. Basically, `ravel_pytree` is flattening across chains for some reason.
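To illustrate what "flattening across chains" would mean, here is a minimal sketch (my own illustration, not code from the sampler): `ravel_pytree` has no notion of a chain axis, so a leading `(num_chains, dim)` axis on each leaf gets folded into a single 1-D vector.

```python
# Illustrative only: ravel_pytree flattens every leaf entirely, so a leading
# chain axis is absorbed into one long vector.
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Two "chains", each holding a 3-dimensional position.
positions = {"x": jnp.arange(6.0).reshape(2, 3)}

flat, unravel = ravel_pytree(positions)
print(flat.shape)  # (6,): the chain axis is gone in the flat representation
```

Anything downstream that operates on `flat` (e.g. a preconditioning matrix) would then mix coordinates across chains.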
Nonetheless, could you try the below:
```python
def streaming_average_update(
    current_value, previous_weight_and_average, weight=1.0, zero_prevention=0.0
):
    """Compute the streaming average of a function O(x) using a weight.

    Parameters
    ----------
    current_value
        the current value of the function that we want to take the average of
    previous_weight_and_average
        tuple of (previous_weight, previous_average) where previous_weight is
        the sum of weights and average is the current estimated average
    weight
        weight of the current state
    zero_prevention
        small value to prevent division by zero

    Returns
    -------
    new total weight and streaming average
    """
    previous_weight, previous_average = previous_weight_and_average
    current_weight = previous_weight + weight
    current_average = jax.tree.map(
        lambda x, avg: (previous_weight * avg + weight * x)
        / (current_weight + zero_prevention),
        current_value,
        previous_average,
    )
    return current_weight, current_average
```
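As a sanity check on the arithmetic itself (my own quick test, not part of the issue): feeding values one at a time with unit weights should reproduce the ordinary mean.

```python
# Sanity check: with weight=1.0 per step, the streaming average of
# [1, 2, 3, 4] should end at total weight 4.0 and average 2.5.
import jax
import jax.numpy as jnp

def streaming_average_update(
    current_value, previous_weight_and_average, weight=1.0, zero_prevention=0.0
):
    previous_weight, previous_average = previous_weight_and_average
    current_weight = previous_weight + weight
    current_average = jax.tree.map(
        lambda x, avg: (previous_weight * avg + weight * x)
        / (current_weight + zero_prevention),
        current_value,
        previous_average,
    )
    return current_weight, current_average

values = jnp.array([1.0, 2.0, 3.0, 4.0])
state = (0.0, jnp.zeros(()))
for v in values:
    state = streaming_average_update(v, state)
print(state)  # total weight 4.0, average 2.5
```

So the update rule itself looks numerically fine on a single chain; the question is what happens to it under the batched transforms.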
This doesn't affect the results. I don't think the pytree is the issue, since I also see the problem with:
```python
def streaming_average_update(
    expectation, streaming_avg, weight=1.0, zero_prevention=0.0
):
    total, average = streaming_avg
    average = (total * average + weight * expectation) / (
        total + weight + zero_prevention
    )
    total += weight
    streaming_avg = (total, average)
    return streaming_avg
```
Actually, even with:

```python
def streaming_average_update(
    expectation, streaming_avg, weight=1.0, zero_prevention=0.0
):
    return streaming_avg
```
I get a (small) discrepancy:
```
Result with <function pmap at 0x11f3b40e0> is [[ 1.2639244  -0.19290113]]
Result with <function vmap at 0x1224eff60> is [[ 1.2648091  -0.19346505]]
```
For a batch of 2 instead of 1, I also see a discrepancy, although it is small. I wonder if PRNG key handling is involved:
```
Result with <function pmap at 0x127feff60> is [[ 0.13665608 -1.9322048 ]
 [ 1.276183   -0.18907894]]
Result with <function vmap at 0x12ff97f60> is [[ 0.13712421 -1.9322233 ]
 [ 1.2760496  -0.18907525]]
```
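For contrast, a stripped-down comparison on a pure deterministic function (my own sketch, not `bug.py`; assumes at least one device is available for `pmap`) gives identical results, which suggests the discrepancy comes from something in the sampler itself, e.g. key handling or reduction order:

```python
# Minimal sketch (not bug.py): with a pure deterministic function and no PRNG,
# pmap and vmap over a batch of 1 agree.
import jax
import jax.numpy as jnp

def step(x):
    return jnp.sum(x * x)

batch = jnp.arange(4.0).reshape(1, 4)  # batch of 1 fits a single device

out_vmap = jax.vmap(step)(batch)
out_pmap = jax.pmap(step)(batch)
print(out_vmap, out_pmap)  # both 14.0
```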
The other odd thing worth mentioning is that the errors often get smaller with longer runs and higher dimensions, e.g. a 10D Gaussian with 10000 steps:
```
Result with <function pmap at 0x165fc8180> is [[ 0.8321227  -2.5019531  -0.7922016  -1.3899261   1.2677195  -0.35655203
   0.48542497 -0.03980938  0.14613445 -1.3119547 ]
 [-0.03221251  0.5771989  -0.23249874 -0.62638044  0.9548867  -1.3366085
  -0.1022618   0.6617969   0.9869026  -1.4713044 ]]
Result with <function vmap at 0x11d3abe20> is [[ 0.8321227  -2.501953   -0.79220164 -1.389926    1.2677199  -0.35655132
   0.4854263  -0.03980982  0.14613375 -1.3119547 ]
 [-0.03221169  0.5772002  -0.2324989  -0.6263806   0.9548884  -1.3366088
  -0.1022612   0.6617972   0.9869019  -1.4713038 ]]
```
What appears to be a bug with jax (or at least a subtle case): changing `map` between `pmap` and `vmap` in `bug.py` changes the printed result. I also saw other very strange behaviors with this code that I have not yet reproduced in this example, where adding a print statement resulted in a similar change of value.
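One generic mechanism that can produce exactly this kind of tiny, run-dependent discrepancy (a general floating-point fact, offered as a possibility rather than a diagnosis of `bug.py`): float32 addition is not associative, so if `pmap` and `vmap` lower a reduction to different accumulation orders, the results differ in the low bits.

```python
# General float32 fact, not specific to jax: changing the accumulation order
# of a sum perturbs the result by a few ULPs.
import numpy as np

rng = np.random.default_rng(0)
xs = rng.standard_normal(10_000).astype(np.float32)

seq = np.float32(0.0)
for x in xs:            # naive left-to-right accumulation
    seq = np.float32(seq + x)
pairwise = xs.sum()     # NumPy's pairwise (tree) reduction

print(seq, pairwise)    # close, but typically not bit-identical
```

This would not explain the print-statement sensitivity by itself, though recompilation triggered by tracing changes could plausibly alter instruction scheduling in the same way.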