GrantMcConachie opened this issue 4 months ago.
I was also wondering about this in the past. Another place where numpy is (or could be) used is `batch` / `batch_np`. IIRC the numpy version was much faster in some situations, which hints that some unwanted jit compilation happens when using jnp functions. The same might be true if the numpy functions in `pad_with_graphs` were replaced with jnp functions. To me it seems that jax.numpy.sum always involves some jit compilation. That is not what you want when array sizes change between calls, because each new shape means a new compilation. It would be nice to have some clarification on this, though.
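The following is a minimal sketch of the general mechanism the comment alludes to. It does not profile jraph itself. `jax.jit` caches one compiled program per input shape and dtype, so every new array size triggers a fresh trace and compile. A Python counter inside the jitted function makes this visible, because Python side effects run only while JAX traces. The function name `total` is illustrative only.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Python side effects inside a jitted function run only while JAX traces it,
# so this counter shows how many times a new program had to be built.
trace_count = 0

@jax.jit
def total(x):
    global trace_count
    trace_count += 1  # executed at trace time, not on every call
    return jnp.sum(x)

# Three calls, but only two distinct input shapes: (3,) and (4,).
for n in (3, 4, 3):
    total(np.ones(n, dtype=np.float32))

print(trace_count)  # 2: one trace per new shape; the repeated (3,) call hits the cache

# The plain NumPy equivalent never traces or compiles, whatever the shape.
print(np.sum(np.ones(5, dtype=np.float32)))
```

If the inputs to a host-side utility change size on nearly every call, the per-shape compilation cost can dominate. That would be consistent with the numpy version being faster in some situations.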
Hello! I was wondering if there is any particular reason that the `pad_with_graphs` function uses the numpy library rather than the jax.numpy library. It looks like every numpy function in there could just be replaced with jax.numpy without any issues, but I could be missing something.
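One plausible reason, though nobody in this thread confirms it, is that padding runs on the host with inputs whose sizes differ from batch to batch. Doing that work in plain NumPy avoids a compile per input size. Only the fixed-size padded result then reaches jitted code. The sketch below illustrates the idea with a hypothetical helper, `pad_rows`. It is not jraph's `pad_with_graphs` and does not reproduce its signature.

```python
import numpy as np

def pad_rows(features: np.ndarray, target_rows: int) -> np.ndarray:
    """Hypothetical helper (not jraph's API): zero-pad a [rows, ...] array
    to exactly `target_rows` rows using plain NumPy on the host."""
    rows = features.shape[0]
    if rows > target_rows:
        raise ValueError(f"cannot pad {rows} rows down to {target_rows}")
    padding = np.zeros((target_rows - rows,) + features.shape[1:],
                       dtype=features.dtype)
    return np.concatenate([features, padding], axis=0)

# Node-feature arrays with different node counts all come out with the same
# shape, so a jitted model consuming them would only need to compile once.
batches = [np.ones((n, 2), dtype=np.float32) for n in (3, 5, 7)]
padded = [pad_rows(b, target_rows=8) for b in batches]
print([p.shape for p in padded])  # [(8, 2), (8, 2), (8, 2)]
```

With this design, variable sizes stay on the NumPy side, where they cost nothing extra. The JAX side only ever sees one shape.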