-
The parallelism documentation is mostly focused on multiple TPU/GPU devices. It would be great if the JAX team added concrete explanations of how JAX uses parallelism when the runtime system is composed of …
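For reference, a minimal sketch of the multi-device pattern the current documentation does cover; the function and batch shape below are illustrative assumptions, not from the docs:
```python
# Minimal multi-device illustration (the `scale` function and batch shape
# are made up for this sketch; jax.pmap is the documented multi-device
# primitive being referred to).
import jax
import jax.numpy as jnp

print(jax.devices())            # devices the runtime has discovered (TPU/GPU/CPU)

def scale(x):
    return 2.0 * x

n = jax.local_device_count()
batch = jnp.arange(n * 4.0).reshape(n, 4)   # one shard per local device
out = jax.pmap(scale)(batch)                # runs `scale` on every device
print(out.shape)                            # (n, 4)
```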
-
## ❓ Questions and Help
I have a request to keep the PyTorch input model in NCHW format by default and convert it to HWOI format during the training process, which is conducive to hardware processin…
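For illustration, a minimal sketch (my own, not part of the request) of the kind of weight-layout conversion meant here, assuming PyTorch's default OIHW convolution weights are permuted to HWOI; the shapes are made up:
```python
# Hypothetical layout conversion: PyTorch conv weights are OIHW by default;
# an HWOI view can be obtained with a permute. Shapes and variable names
# are illustrative only.
import torch

weight_oihw = torch.randn(16, 3, 3, 3)                      # (out, in, H, W)
weight_hwoi = weight_oihw.permute(2, 3, 0, 1).contiguous()  # (H, W, out, in)
print(weight_hwoi.shape)                                    # torch.Size([3, 3, 16, 3])
```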
-
## ❓ Questions and Help
Hi:
Following the user guide https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md#build-from-source,
I have built PyTorch successfully, but when I build xla…
-
What is the recommended way of including the XLA compiler as a library in a CMake build system?
There seems to be no Bazel target to install/package the libraries and include files.
I am looking fo…
-
Hi - trying to run either the 360 or raw scripts (with the paths suitably edited) leaves me in an endless loop, as shown below. I assume the JAX warnings are not errors (I get the same running with CPU or GP…
-
### What is the feature?
Support mmlab training on the AWS Trainium device
### Any other context?
- AWS [recently announced general availability of Trainium instances](https://aws.amazon.com/…
-
On v0.1.25 on OSX, I get the following error when computing gradients from the following jit-compiled function.
```python
import numpy as onp
import jax.numpy as np
from jax import grad, jit
…
-
## 🐛 Bug
## To Reproduce
I installed the latest 11.2 drivers and updated the required libraries
```
22:51 $ conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six…
-
## 🚀 Feature
Currently PyTorch/XLA uses `xla::shape` all over the place. Common uses of `xla::shape` are to get the number of elements of a tensor, compare shape equality between two tensors, chec…
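For context, a rough Python-level sketch (mine, not the proposed API) of the queries listed above, written against ordinary torch tensors just to show the intent:
```python
# Rough illustration of the shape queries mentioned above, using plain torch
# tensors; in PyTorch/XLA these currently go through xla::shape on the C++ side.
import torch

a = torch.zeros(2, 3, 4)
b = torch.zeros(2, 3, 4)

num_elements = a.numel()           # number of elements of a tensor
same_shape = a.shape == b.shape    # compare shape equality between two tensors
print(num_elements, same_shape)    # 24 True
```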
-
## 🐛 Bug
I started to observe the following error when calling ./run_tests.sh locally. Has anyone else seen it? This is happening on the latest master of pytorch and pytorch/xla. The stack trace go…