choderalab / modelforge

Infrastructure to implement and train NNPs
https://modelforge.readthedocs.io/en/latest/
MIT License

Wrap PyTorch model in JAX/Flax #101

Closed · wiederm closed this 2 months ago

wiederm commented 2 months ago

Description

We want to use the trained models in different production environments and backends. A first step in this direction is to wrap the trained PyTorch model in a JAX function, using dlpack for zero-copy mapping between JAX and PyTorch tensors on the device in use. This makes a PyTorch model usable with JAX's autograd: forward calls run in PyTorch, and backward calls are driven from JAX.
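For illustration, here is a minimal sketch of the zero-copy exchange via dlpack; the helper names are placeholders rather than modelforge API, and the exact dlpack entry points may differ slightly between JAX/PyTorch versions:

```python
# Sketch: zero-copy tensor exchange between JAX and PyTorch via dlpack.
# Helper names (`jax_to_torch`, `torch_to_jax`) are placeholders.
import jax
import torch
import torch.utils.dlpack


def jax_to_torch(x):
    """Zero-copy JAX array -> PyTorch tensor on the same device."""
    return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))


def torch_to_jax(t):
    """Zero-copy PyTorch tensor -> JAX array on the same device."""
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t.contiguous()))
```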

In general, we have two scenarios for the PyTorch-to-JAX conversion:

  1. Wrap a trained model in a Flax model. In this scenario, the modelforge pairlist calculation is used, and the input signature for the wrapped model is ($\vec{R}$, $Z$, $Q$).
  2. Convert a trained model to a Flax model. This scenario will use the chiron pairlist, and the input signature for the converted model is ($\vec{r}_{ij}$, $d_{ij}$, $Z$, $Q$).

This PR will implement scenario (1). This is analogous to using a pre-trained model provided by a third party (e.g., the MACE model), which takes derivatives with respect to coordinates and therefore uses coordinates as input, i.e., the input signature ($\vec{R}$, $Z$, $Q$).
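A minimal sketch of how scenario (1) could look, building on the dlpack helpers above and `jax.custom_vjp`; `torch_model` and `wrap_in_jax` are placeholders, not the PR's actual code. To keep the sketch short, $Z$ and $Q$ are closed over rather than passed per call, since gradients are only needed with respect to $\vec{R}$:

```python
# Sketch (assumed names, not modelforge API): wrap a trained PyTorch model
# so JAX can differentiate the energy with respect to the coordinates R.
import jax
import torch


def wrap_in_jax(torch_model, atomic_numbers, total_charge):
    """Return a JAX-differentiable energy function E(R) backed by `torch_model`."""
    # Z and Q carry no gradients, so convert them once and close over them.
    Z = jax_to_torch(atomic_numbers)
    Q = jax_to_torch(total_charge)

    @jax.custom_vjp
    def energy(positions):
        with torch.no_grad():
            E = torch_model(jax_to_torch(positions), Z, Q)  # forward runs in PyTorch
        return torch_to_jax(E)

    def energy_fwd(positions):
        R = jax_to_torch(positions).requires_grad_(True)
        E = torch_model(R, Z, Q)
        # Residuals keep the PyTorch graph alive for the backward pass.
        return torch_to_jax(E.detach()), (R, E)

    def energy_bwd(residuals, cotangent):
        R, E = residuals
        # PyTorch evaluates dE/dR; the result is handed back to JAX's autograd.
        (dE_dR,) = torch.autograd.grad(E, R, grad_outputs=jax_to_torch(cotangent))
        return (torch_to_jax(dE_dR),)

    energy.defvjp(energy_fwd, energy_bwd)
    return energy
```

With such a wrapper, forces could be obtained on the JAX side as `-jax.grad(energy_fn)(R)` while the model itself keeps running in PyTorch. One caveat of this pattern: the wrapped function cannot be traced by `jax.jit`, since the primal call leaves JAX and executes eagerly in PyTorch.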

Todos

Status

wiederm commented 2 months ago

Thanks for the review @chrisiacovella! Since we have restructured the models into a core class that we want to export, I will need to refactor this PR a bit. I will ping you as soon as it is ready.