google-deepmind / jraph

A Graph Neural Network Library in Jax
https://jraph.readthedocs.io/en/latest/
Apache License 2.0

pad_with_graphs written with numpy #52

Open GrantMcConachie opened 4 months ago

GrantMcConachie commented 4 months ago

Hello! I was wondering if there is any particular reason that the pad_with_graphs function uses the numpy library rather than the jax.numpy library. It looks like every numpy function in there can just be replaced with jax.numpy without any issues, but I could be missing something.
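One plausible reason, offered as an assumption rather than the maintainers' stated rationale: padding code computes data-dependent sizes (target size minus the sum of n_node or n_edge) and then uses those values as array shapes. With NumPy the sums are concrete host values. If the same logic were written with jnp and called inside jax.jit, the sums would be tracers, and turning a tracer into an int or a shape raises a concretization error. The sketch below is a simplified, hypothetical version of that bookkeeping in plain NumPy. `padding_sizes` and `pad_node_features` are illustrative helpers, not jraph's actual code.

```python
import numpy as np


def padding_sizes(n_node, n_edge, max_nodes, max_edges, max_graphs):
    """Hypothetical helper: padding needed to bring a batch up to fixed sizes.

    n_node / n_edge are the per-graph count arrays of a batched graph.
    The results are plain Python ints computed on the host, so they can
    be used directly as array shapes.
    """
    pad_nodes = int(max_nodes - np.sum(n_node))
    pad_edges = int(max_edges - np.sum(n_edge))
    pad_graphs = int(max_graphs - n_node.shape[0])
    if pad_nodes <= 0 or pad_edges < 0 or pad_graphs <= 0:
        raise ValueError("target sizes too small for the batch plus a padding graph")
    return pad_nodes, pad_edges, pad_graphs


def pad_node_features(nodes, pad_nodes):
    # Append pad_nodes rows of zeros; the output shape depends on a
    # value computed from the data, which is fine for eager NumPy.
    padding = np.zeros((pad_nodes,) + nodes.shape[1:], dtype=nodes.dtype)
    return np.concatenate([nodes, padding], axis=0)


n_node = np.array([2, 3], dtype=np.int32)
n_edge = np.array([1, 4], dtype=np.int32)
nodes = np.ones((5, 8), dtype=np.float32)

pad_nodes, pad_edges, pad_graphs = padding_sizes(
    n_node, n_edge, max_nodes=8, max_edges=6, max_graphs=3)
padded = pad_node_features(nodes, pad_nodes)
print(pad_nodes, pad_edges, pad_graphs, padded.shape)  # 3 1 1 (8, 8)
```

Swapping np for jnp here would still work eagerly outside jit, but every call would dispatch to the accelerator for what is purely host-side bookkeeping.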

tisabe commented 3 months ago

I have wondered about this too. Another place where numpy is (or could be) used is batch vs. batch_np. IIRC the numpy version was much faster in some situations, which hints that some unwanted jit compilation happens when jnp functions are used. The same might happen if the numpy functions in pad_with_graphs were replaced with jnp functions. It seems to me that jax.numpy.sum always triggers some jit compilation, which is not what you want when array sizes change from call to call. It would be nice to get some clarification on this, though.
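One way to probe the recompilation hypothesis is to count how often JAX traces a jitted function. The sketch below wraps jnp.sum in an explicit jax.jit and uses a Python counter, a side effect that only runs while JAX traces, to show that each new input shape triggers a fresh trace and compilation, while a repeated shape hits the cache. It does not establish exactly what jnp.sum does internally in any particular JAX version. It only illustrates why code that sees arrays of changing sizes, as batching and padding helpers do, can pay a compile cost per new shape, whereas NumPy never compiles. The function name `total` is illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np

trace_count = 0  # incremented only while JAX traces the function


@jax.jit
def total(x):
    global trace_count
    # Python side effects run at trace time, not on cached calls,
    # so this counts how many distinct compilations were needed.
    trace_count += 1
    return jnp.sum(x)


for n in [3, 3, 4, 5, 5]:
    total(jnp.ones(n))  # input shapes (3,), (3,), (4,), (5,), (5,)

print(trace_count)  # 3: one trace per distinct shape

# NumPy reductions run eagerly on the host with no tracing or compilation.
host_totals = [np.sum(np.ones(n)) for n in [3, 3, 4, 5, 5]]
```

This is also why padding to a small set of fixed sizes helps: it bounds the number of distinct shapes, and therefore compilations, that jitted downstream code sees.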