rdyro/torch2jax
Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff.
https://rdyro.github.io/torch2jax/
MIT License
37 stars, 1 fork
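Issues #7 and #17 below indicate that torch2jax relies on XLA custom calls, so the following is not its implementation. It is a minimal sketch of the idea in the tagline that uses only public JAX APIs (`jax.pure_callback`, `jax.custom_vjp`). The sketch shows how a non-JAX host function can be made callable under `jax.jit` and given a hand-written reverse-mode gradient. A NumPy function stands in for PyTorch code, and the names `host_square`, `host_square_grad`, and `square` are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical stand-in for "foreign" (non-JAX) code: a host-side NumPy function.
# In torch2jax this role is played by a PyTorch function; NumPy is used here
# only so the sketch has no PyTorch dependency.
def host_square(x):
    return np.asarray(x) ** 2

def host_square_grad(x, g):
    # Hand-written reverse-mode rule: cotangent of x**2 is 2 * x * g.
    return 2.0 * np.asarray(x) * np.asarray(g)

@jax.custom_vjp
def square(x):
    # pure_callback lets traced/jitted code call back into Python on the host;
    # the result shape and dtype must be declared up front.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_square, out_spec, x)

def square_fwd(x):
    # Forward pass: compute the output and save the input as the residual.
    return square(x), x

def square_bwd(x, g):
    # Backward pass: also runs on the host via a callback.
    spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return (jax.pure_callback(host_square_grad, spec, x, g),)

square.defvjp(square_fwd, square_bwd)

# Usage: the wrapped function works under jit and reverse-mode grad.
x = jnp.arange(3.0)
f = jax.jit(lambda v: jnp.sum(square(v)))
print(f(x))            # expected: 5.0
print(jax.grad(f)(x))  # expected: [0. 2. 4.]
```

Because the host function is opaque to JAX, its gradient cannot be derived by tracing. `custom_vjp` supplies it explicitly, which is the same problem that automatic gradient definition for wrapped PyTorch code has to solve.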
Issues (sorted newest first)
#19 Donate Args Causing Repeated Spawning of Warning · adam-hartshorne · opened 1 day ago · 0 comments
#18 Multi-GPU Question · adam-hartshorne · opened 4 days ago · 2 comments
#17 JAX v0.4.35 - xla_client.register_custom_call_target has been deprecated · adam-hartshorne · closed 1 week ago · 1 comment
#16 New JAX Functionality For Calling "Foreign" Code · adam-hartshorne · closed 1 month ago · 5 comments
#15 CUDA bug and JAX breaking change · zwei-beiner · closed 2 months ago · 4 comments
#14 gcc version? · rkruegs123 · closed 5 months ago · 2 comments
#13 Minor API change · zwei-beiner · closed 6 months ago · 4 comments
#12 Multi-gpu support · zwei-beiner · closed 6 months ago · 23 comments
#11 DLPack API Changes · adam-hartshorne · closed 5 months ago · 2 comments
#10 Use of new JAX Performance Flags? · adam-hartshorne · closed 1 year ago · 2 comments
#9 Handling situation of having two identically named packages · adam-hartshorne · closed 11 months ago · 3 comments
#8 Any thoughts on this torch2jax alternative? · adam-hartshorne · closed 1 year ago · 1 comment
#7 custom_call() args need to be updated to remove out_types · danielpmorton · closed 1 year ago · 3 comments
#6 JAX / XLA adding the importing external dlpack-aware Python arrays. · adam-hartshorne · opened 1 year ago · 1 comment
#5 Ignore · adam-hartshorne · closed 1 year ago · 0 comments
#4 New bug when optimising function · adam-hartshorne · closed 1 year ago · 4 comments
#3 Issue when attempting to optimise using torch2jax function · adam-hartshorne · closed 1 year ago · 8 comments
#2 Quick question about Overview Code Examples · adam-hartshorne · closed 1 year ago · 3 comments
#1 Great Work - Question about Gradients. · adam-hartshorne · closed 1 year ago · 6 comments