yhtang closed this issue 2 years ago.
PyTorch:
ff.use('torch', device='cuda')
JAX:
ff.use('jax', enable_x64=True)