necludov / jam

Implementation of Action Matching
https://arxiv.org/abs/2210.06662
MIT License

about the time sampler #2

Closed: boxaio closed this issue 10 months ago.

boxaio commented 10 months ago

In the file `dynamics.utils.py`, you define a time sampler that draws time points between t=0 and t=1:

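```python
import math
import jax
import jax.numpy as jnp

def sample_uniformly(bs, state):
    # t_0, t_1 (interval endpoints) are not defined in this snippet; assumed to come from an enclosing scope
    u = (state.u0 + math.sqrt(2) * jnp.arange(bs * jax.device_count())) % 1
    new_state = state.replace(u0=u[-1:])
    t = (t_1 - t_0) * u[jax.process_index() * bs:(jax.process_index() + 1) * bs] + t_0
    return t, new_state
```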
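To make the quoted sampler easier to inspect, here is a minimal pure-Python sketch of the same arithmetic for a single process on a single device (assuming `jax.device_count() == 1` and `jax.process_index() == 0`, so the slice covers the whole array). `SamplerState` and `sample_uniformly_single_process` are hypothetical names introduced for illustration; the repo's actual state object is not shown here.

```python
import math
from dataclasses import dataclass, replace

# Hypothetical stand-in for the sampler state: only the field the quoted
# function touches (u0, the current offset in [0, 1)).
@dataclass(frozen=True)
class SamplerState:
    u0: float

def sample_uniformly_single_process(bs, state, t_0=0.0, t_1=1.0):
    """Sketch of the quoted sampler, assuming one process and one device."""
    # Additive recurrence: u_k = (u0 + k * sqrt(2)) mod 1
    u = [(state.u0 + math.sqrt(2) * k) % 1 for k in range(bs)]
    # Carry the last point forward as the next call's starting offset
    new_state = replace(state, u0=u[-1])
    # Affine map from [0, 1) to [t_0, t_1)
    t = [(t_1 - t_0) * x + t_0 for x in u]
    return t, new_state

state = SamplerState(u0=0.0)
t, state = sample_uniformly_single_process(4, state)
print([round(x, 4) for x in t])  # [0.0, 0.4142, 0.8284, 0.2426]
```

Because `u[0]` equals `state.u0`, under this reading of the code the first point of each new batch repeats the last point of the previous batch.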

This function returns a distribution of time points that is not uniform. What was the reasoning behind choosing this sampling function?
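One way to look at the behavior in question: the update `u = (u0 + k*sqrt(2)) % 1` is an additive recurrence (a Kronecker or Weyl sequence), a standard low-discrepancy construction. Its points are deterministic given `u0` and cover [0, 1) more evenly than independent uniform draws. The sketch below is a hypothetical check, not code from the repo; it compares the gaps between neighboring points for the recurrence and for seeded i.i.d. uniform samples.

```python
import math
import random

def circular_gaps(points):
    """Return the gaps between neighboring points on the circle [0, 1)."""
    s = sorted(points)
    gaps = [b - a for a, b in zip(s, s[1:])]
    gaps.append(1.0 - s[-1] + s[0])  # wrap-around gap from the last point back to the first
    return gaps

n = 64
u0 = 0.3  # arbitrary starting offset; rotating all points leaves the gap lengths unchanged
recurrence = [(u0 + math.sqrt(2) * k) % 1 for k in range(n)]

rng = random.Random(0)  # seeded i.i.d. uniform draws for comparison
iid = [rng.random() for _ in range(n)]

for name, pts in [("recurrence", recurrence), ("iid", iid)]:
    g = circular_gaps(pts)
    print(f"{name:>10}: min gap {min(g):.4f}, max gap {max(g):.4f}, ideal {1 / n:.4f}")
```

For an irrational step such as sqrt(2), the three-gap theorem says the gaps take at most three distinct lengths; for n = 64 the largest is about 0.0294, below 2/n.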