proteneer / timemachine

Differentiate all the things!

Spurious calls to execute_selective in Jax wrapper #706

Closed: maxentile closed this issue 2 years ago

maxentile commented 2 years ago

Due to inefficient choices I made way back in https://github.com/proteneer/timemachine/pull/429 (see https://github.com/proteneer/timemachine/blob/836aeac2aa1a3eac244682cb603cdc7f38fd66a3/timemachine/fe/functional.py#L15-L43):

  1. the Jax wrappers request u, dudx, dudp and dudl from impl.execute_selective whenever at least one derivative is needed, even when only some of these quantities were asked for
  2. when multiple derivatives are needed, these are not collapsed into a single call to impl.execute_selective
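To make items 1 and 2 concrete, here is a toy model of the request pattern. It is a sketch, not the code in functional.py: LoggingImpl, naive_request and collapsed_request are hypothetical names, and the real impl.execute_selective has a different signature (it does not take quantity names).

```python
QUANTITIES = ("u", "dudx", "dudp", "dudl")


class LoggingImpl:
    """Hypothetical stand-in for a potential's impl: records what each
    execute_selective call asked for and returns dummy zeros."""

    def __init__(self):
        self.calls = []

    def execute_selective(self, *requested):
        self.calls.append(requested)
        return {q: 0.0 for q in requested}


def naive_request(impl, needed):
    """Old pattern (items 1 and 2): once any derivative is needed, fetch
    every quantity, one call each, regardless of what `needed` contains."""
    if needed == {"u"}:
        return impl.execute_selective("u")
    results = {}
    for q in QUANTITIES:
        results.update(impl.execute_selective(q))
    return results


def collapsed_request(impl, needed):
    """Desired pattern: a single call asking only for what is needed."""
    return impl.execute_selective(*(q for q in QUANTITIES if q in needed))


# grad with respect to coords only needs dudx.
naive, collapsed = LoggingImpl(), LoggingImpl()
naive_request(naive, {"dudx"})
collapsed_request(collapsed, {"dudx"})
print(naive.calls)      # [('u',), ('dudx',), ('dudp',), ('dudl',)]
print(collapsed.calls)  # [('dudx',)]
```

The naive log mirrors the grad(U_fast, argnums=0) trace in the notebook: four calls where one call for dudx would do.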

See this notebook for details: https://gist.github.com/maxentile/0759c025f9d141060741c7520c54d960

>>> # fine: 1 call to execute_selective, as expected
>>> _ = U_fast(coords, sys_params, box, lam)
calling execute_selective(..., u)
>>> # problem: spurious calls to request u, dudp, dudl
>>> _ = grad(U_fast, argnums=0)(coords, sys_params, box, lam)
calling execute_selective(..., u)
calling execute_selective(..., dudx)
calling execute_selective(..., dudp)
calling execute_selective(..., dudl)
>>> # problem: not collapsed into execute_selective(..., u, dudx, dudp, dudl)
>>> _ = value_and_grad(U_fast, argnums=(0, 1, 3))(coords, sys_params, box, lam)
calling execute_selective(..., u)
calling execute_selective(..., dudx)
calling execute_selective(..., dudp)
calling execute_selective(..., dudl)

Note: this inefficiency affects both the Python-loop version (construct_differentiable_interface) and the SummedPotential version (construct_differentiable_interface_fast). In the Python-loop version, the total number of calls is additionally multiplied by len(unbound_potentials).
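The multiplication by len(unbound_potentials) can be seen with a similar counting sketch. CountingImpl, loop_version_derivs and summed_version_derivs are hypothetical stand-ins that only model the request pattern described in this issue (one call per quantity once any derivative is needed); with 5 potentials the loop version makes 4 * 5 = 20 calls where the summed version makes 4.

```python
QUANTITIES = ("u", "dudx", "dudp", "dudl")


class CountingImpl:
    """Hypothetical per-potential impl that only counts execute_selective calls."""

    def __init__(self):
        self.n_calls = 0

    def execute_selective(self, *requested):
        self.n_calls += 1
        return {q: 1.0 for q in requested}  # dummy values


def loop_version_derivs(impls):
    """Python-loop pattern (sketch): each potential is handled separately and
    makes one call per quantity once any derivative is needed."""
    totals = dict.fromkeys(QUANTITIES, 0.0)
    for impl in impls:
        for q in QUANTITIES:
            totals[q] += impl.execute_selective(q)[q]
    return totals


def summed_version_derivs(summed_impl):
    """SummedPotential pattern (sketch): one impl, still one call per quantity."""
    return {q: summed_impl.execute_selective(q)[q] for q in QUANTITIES}


impls = [CountingImpl() for _ in range(5)]
loop_version_derivs(impls)
summed = CountingImpl()
summed_version_derivs(summed)
print(sum(i.n_calls for i in impls), summed.n_calls)  # 20 4
```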

maxentile commented 2 years ago

WIP fix in https://github.com/proteneer/timemachine/pull/707 avoids the spurious calls and collapses multiple requests into a single call to impl.execute_selective (demo: https://gist.github.com/maxentile/1463655fea7900317e7c0bc62c154a46). ~(However, there's a remaining issue with param packing / unpacking that affects the SummedPotential version when du_dp is requested...)~
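One way to get "one combined call" behaviour is jax.custom_vjp: the primal rule requests only u, the forward rule used under differentiation requests u and all derivatives in a single call, and the backward rule just rescales the stored derivatives by the incoming cotangent. This sketch assumes a hypothetical execute_selective stand-in (a harmonic restraint written in jax.numpy, with a made-up argument and return order) and a hypothetical U_collapsed wrapper; it is not the code in PR #707.

```python
import jax
import jax.numpy as jnp


def execute_selective(x, params, box, lam, compute_u, compute_du_dx, compute_du_dp, compute_du_dl):
    """Hypothetical stand-in for impl.execute_selective.

    Harmonic restraint to the origin: u = 0.5 * lam * sum_i params[i] * |x_i|^2
    (box is accepted but unused). Unrequested quantities are returned as None.
    """
    sq = jnp.sum(x**2, axis=-1)  # |x_i|^2 per particle, shape (N,)
    u = 0.5 * lam * jnp.sum(params * sq) if compute_u else None
    du_dx = lam * params[:, None] * x if compute_du_dx else None
    du_dp = 0.5 * lam * sq if compute_du_dp else None
    du_dl = 0.5 * jnp.sum(params * sq) if compute_du_dl else None
    return du_dx, du_dp, u, du_dl


@jax.custom_vjp
def U_collapsed(x, params, box, lam):
    # Plain evaluation (no derivatives needed): request u only.
    _, _, u, _ = execute_selective(x, params, box, lam, True, False, False, False)
    return u


def U_collapsed_fwd(x, params, box, lam):
    # Under differentiation: one combined request for u and all derivatives.
    du_dx, du_dp, u, du_dl = execute_selective(x, params, box, lam, True, True, True, True)
    return u, (du_dx, du_dp, du_dl, box)


def U_collapsed_bwd(residuals, g):
    du_dx, du_dp, du_dl, box = residuals
    # Scale the stored derivatives by the cotangent g; box gets a zero cotangent.
    return g * du_dx, g * du_dp, jnp.zeros_like(box), g * du_dl


U_collapsed.defvjp(U_collapsed_fwd, U_collapsed_bwd)

# Usage: value and gradients w.r.t. coords, params and lam from one forward request.
x = jnp.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
params = jnp.array([1.0, 3.0])
box = 10.0 * jnp.eye(3)
lam = jnp.float32(0.5)
u, (dudx, dudp, dudl) = jax.value_and_grad(U_collapsed, argnums=(0, 1, 3))(x, params, box, lam)
```

Design note: by default, custom_vjp does not tell the forward rule which arguments are being differentiated, so this sketch removes the spurious calls but still computes every derivative (e.g. dudp and dudl under grad with argnums=0) inside that one call.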