Closed: LennertEvens closed this issue 6 months ago.
Hello, why would you trace a JAX module with PyTorch?
The goal is to use the traced model for inference in C/C++ applications. The significant speedup during training is a huge advantage of SBX over SB3.
Then you need to go through ONNX with JAX. Apparently, you first need to convert the model to TensorFlow (TF): https://github.com/google/jax/issues/7629#issuecomment-898939109
Otherwise, you need to manually re-create the policy architecture in PyTorch and load the exported weights into it.
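A minimal sketch of the weight-loading step, assuming the SBX actor is a plain MLP of Flax `nn.Dense` layers whose parameters are auto-named `Dense_0`, `Dense_1`, ... (the real nesting, for example an outer `"params"` key, may differ), and assuming a hypothetical PyTorch re-creation `nn.Sequential(Linear, ReLU, Linear, ReLU, Linear, Tanh)`. The key detail is that a Flax `Dense` kernel has shape `(in_features, out_features)` while a PyTorch `nn.Linear` weight has shape `(out_features, in_features)`, so each kernel must be transposed. The function names below are hypothetical, and NumPy stands in for both frameworks so the layout check can run without JAX or PyTorch installed.

```python
import numpy as np


def flax_mlp_to_torch_state_dict(flax_params):
    """Hypothetical helper: map Flax-style Dense params to an nn.Sequential state_dict layout.

    Assumes flax_params == {"Dense_0": {"kernel": ..., "bias": ...}, "Dense_1": ..., ...}
    and that the PyTorch model interleaves Linear and activation layers, so the
    Linear layers sit at Sequential indices 0, 2, 4, ...
    """
    state_dict = {}
    for i in range(len(flax_params)):
        layer = flax_params[f"Dense_{i}"]
        kernel = np.asarray(layer["kernel"])  # Flax layout: (in_features, out_features)
        # nn.Linear stores weight as (out_features, in_features), hence the transpose.
        state_dict[f"{2 * i}.weight"] = kernel.T.copy()
        state_dict[f"{2 * i}.bias"] = np.asarray(layer["bias"]).copy()
    return state_dict


def flax_style_forward(flax_params, obs):
    # Mimics Flax Dense: y = x @ kernel + bias; ReLU on hidden layers, tanh on the output.
    h = obs
    n_layers = len(flax_params)
    for i in range(n_layers):
        layer = flax_params[f"Dense_{i}"]
        h = h @ layer["kernel"] + layer["bias"]
        h = np.maximum(h, 0.0) if i < n_layers - 1 else np.tanh(h)
    return h


def torch_style_forward(state_dict, obs):
    # Mimics nn.Linear: y = x @ weight.T + bias, with the same activations.
    h = obs
    n_layers = len(state_dict) // 2
    for i in range(n_layers):
        h = h @ state_dict[f"{2 * i}.weight"].T + state_dict[f"{2 * i}.bias"]
        h = np.maximum(h, 0.0) if i < n_layers - 1 else np.tanh(h)
    return h


# Random stand-in parameters; a real run would use the trained SBX actor's params instead.
rng = np.random.default_rng(0)
obs_dim, hidden, act_dim = 3, 8, 2
sizes = [obs_dim, hidden, hidden, act_dim]
flax_params = {
    f"Dense_{i}": {
        "kernel": rng.standard_normal((sizes[i], sizes[i + 1])).astype(np.float32),
        "bias": rng.standard_normal(sizes[i + 1]).astype(np.float32),
    }
    for i in range(len(sizes) - 1)
}
state_dict = flax_mlp_to_torch_state_dict(flax_params)
obs = rng.standard_normal((4, obs_dim)).astype(np.float32)
print(np.allclose(flax_style_forward(flax_params, obs), torch_style_forward(state_dict, obs), atol=1e-6))
```

With PyTorch available, the converted arrays could then be loaded with `actor.load_state_dict({k: torch.as_tensor(v) for k, v in state_dict.items()})` and the module traced with `torch.jit.trace(actor, example_obs)` for use from C++ via LibTorch. Any extra preprocessing the SBX policy applies (observation normalization, action rescaling) would also have to be re-created by hand.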
🐛 Bug
When using PyTorch JIT to trace and save a model trained with SBX, an exception occurs.
To Reproduce
The following code works fine for a model trained with TD3 in SB3. However, a TypeError occurs when trying to save a model trained with SBX.
System Info
Checklist