rdyro / torch2jax

Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff.
https://rdyro.github.io/torch2jax/
MIT License
37 stars, 1 fork

custom_call() args need to be updated to remove out_types #7

Closed: danielpmorton closed this issue 1 year ago

danielpmorton commented 1 year ago

JAX recently changed some of the parameters of the custom_call() function, so calls to it need to be updated. For reference: https://github.com/google/jax/commit/24f9011d49f974fa24613206c5af774089ecf346

This may be an easy change, for instance renaming out_types to result_types. But I'm not a JAX expert, so it may be worth checking the linked commit to confirm that this works.
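One way to make a call site robust to this kind of keyword rename is to check which keyword the installed function accepts instead of hard-coding it. This is a minimal sketch, assuming only what the comment above says: the old keyword is `out_types` and the new one is `result_types`. The helpers `result_types_kwarg` and `call_with_result_types` and the stand-ins `old_custom_call` and `new_custom_call` are hypothetical illustrations, not JAX's or torch2jax's API.

```python
import inspect
from typing import Any, Callable


def result_types_kwarg(fn: Callable[..., Any]) -> str:
    """Return the keyword name that `fn` uses for its result types."""
    params = inspect.signature(fn).parameters
    if "result_types" in params:
        return "result_types"  # newer spelling
    if "out_types" in params:
        return "out_types"  # older spelling
    raise TypeError(f"{fn.__name__} accepts neither result_types nor out_types")


def call_with_result_types(fn: Callable[..., Any], types: Any, **kwargs: Any) -> Any:
    # Forward `types` under whichever keyword `fn` actually accepts.
    kwargs[result_types_kwarg(fn)] = types
    return fn(**kwargs)


# Hypothetical stand-ins for the old and new signatures (NOT real JAX functions).
def old_custom_call(call_target_name, *, out_types, operands):
    return ("old", call_target_name, out_types, operands)


def new_custom_call(call_target_name, *, result_types, operands):
    return ("new", call_target_name, result_types, operands)


if __name__ == "__main__":
    print(call_with_result_types(old_custom_call, ["f32"], call_target_name="op", operands=[1]))
    print(call_with_result_types(new_custom_call, ["f32"], call_target_name="op", operands=[1]))
```

Inspecting the signature ties the call site to the function it actually receives rather than to a particular version number.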

danielpmorton commented 1 year ago

Ah, Steph mentions that I should keep JAX at 0.4.8.

I can close this issue or leave it open. What do you think?

rdyro commented 1 year ago

Ah, perfect, a really good catch. I'll update the package soon.

I believe the current version should be compatible with JAX <= 0.4.13.

Supporting both JAX <= 0.4.13 and > 0.4.13, with the installed version detected automatically, requires a bit of versioning work, so I'll get to it soon.
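A version check along these lines could be sketched as follows. It assumes the 0.4.13 cutoff stated above (older releases use `out_types`, newer ones `result_types`) and that release strings start with numeric `X.Y.Z` components. `parse_version`, `custom_call_kwarg_for`, and `LAST_OUT_TYPES_VERSION` are hypothetical helpers, not part of torch2jax or JAX. In real code the string would come from `jax.__version__`.

```python
import re

# Assumed cutoff, taken from the comment above: JAX <= 0.4.13 uses `out_types`.
LAST_OUT_TYPES_VERSION = (0, 4, 13)


def parse_version(version: str) -> tuple:
    """Convert e.g. "0.4.14.dev20230801" into (0, 4, 14), keeping leading numeric parts."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component such as "dev2023..."
        parts.append(int(match.group()))
    return tuple(parts)


def custom_call_kwarg_for(version: str) -> str:
    """Pick the result-types keyword name for a given JAX version string."""
    if parse_version(version) <= LAST_OUT_TYPES_VERSION:
        return "out_types"
    return "result_types"


if __name__ == "__main__":
    for v in ("0.4.8", "0.4.13", "0.4.14"):
        print(v, custom_call_kwarg_for(v))
```

Comparing integer tuples rather than raw strings matters here: as strings, "0.4.8" sorts after "0.4.13", which would pick the wrong keyword.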

I'll leave this issue open for now.

rdyro commented 1 year ago

Resolved without breaking compatibility in the latest commit, bb328646a83107c184b10ab1717693e9ef3000b6.