Trax seems to reserve all available GPU memory. Just defining and initializing an Embedding layer of size 1000x64 leaves almost all of my GPU memory taken by the Python process running Trax (4859 MiB of 5934 MiB in the nvidia-smi output below).
$ nvidia-smi
Tue Dec 7 14:48:14 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44 Driver Version: 495.44 CUDA Version: 11.5 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... Off | 00000000:01:00.0 On | N/A |
| N/A 54C P8 14W / N/A | 5605MiB / 5934MiB | 7% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 2650 G /usr/libexec/Xorg 330MiB |
| 0 N/A N/A 2812 G /usr/bin/gnome-shell 207MiB |
| 0 N/A N/A 62099 G ...AAAAAAAAA= --shared-files 15MiB |
| 0 N/A N/A 73704 G ...AAAAAAAAA= --shared-files 60MiB |
| 0 N/A N/A 78782 G /usr/lib64/firefox/firefox 110MiB |
| 0 N/A N/A 79092 G /usr/lib64/firefox/firefox 1MiB |
| 0 N/A N/A 81544 C ...nda3/envs/trax/bin/python 4859MiB |
| 0 N/A N/A 81902 G /usr/lib64/firefox/firefox 1MiB |
+-----------------------------------------------------------------------------+
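The per-process figure in the table above can also be read without parsing the formatted layout. The following is a minimal sketch that assumes nvidia-smi is on the PATH and uses its CSV query options (`--query-compute-apps`, `--query-gpu`). The helper names are hypothetical, and the sketch assumes a single GPU, as in the output above. With those numbers, the Trax process (PID 81544) holds 4859 MiB of 5934 MiB, roughly 82% of the card.

```python
import subprocess
from typing import Dict, List


def query_nvidia_smi(args: List[str]) -> str:
    """Run nvidia-smi with the given arguments and return its stdout."""
    result = subprocess.run(
        ["nvidia-smi", *args], check=True, capture_output=True, text=True
    )
    return result.stdout


def parse_compute_apps(csv_text: str) -> Dict[int, int]:
    """Map PID -> used GPU memory (MiB) from 'pid, used_memory' CSV rows."""
    usage: Dict[int, int] = {}
    for line in csv_text.strip().splitlines():
        pid, used = (field.strip() for field in line.split(","))
        usage[int(pid)] = int(used)
    return usage


def process_memory_share(pid: int) -> float:
    """Return the fraction of total GPU memory held by `pid` (single GPU assumed)."""
    apps = parse_compute_apps(
        query_nvidia_smi(["--query-compute-apps=pid,used_memory",
                          "--format=csv,noheader,nounits"])
    )
    total_text = query_nvidia_smi(["--query-gpu=memory.total",
                                   "--format=csv,noheader,nounits"])
    total_mib = int(total_text.strip().splitlines()[0])  # first GPU only
    return apps.get(pid, 0) / total_mib
```

Only compute ("C") processes are listed by `--query-compute-apps`, so Xorg and the other graphics clients in the table are excluded. While the Trax process is still running, `process_memory_share(81544)` would be expected to return about 0.82.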
Environment information
OS: Fedora 35, Linux Kernel 5.15.6-200.fc35.x86_64
$ pip freeze | grep trax
$ pip freeze | grep tensor
$ pip freeze | grep jax
$ python -V
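The same version information can be collected from inside the Trax environment. The following is a small sketch based on `importlib.metadata` (Python 3.8+). The package list is an assumption that stands in for whatever the `grep` patterns above match, and packages that are not installed are reported as `None`.

```python
import platform
from importlib import metadata
from typing import Dict, Iterable, Optional

# Assumed subset of what `grep trax|tensor|jax` would match.
PACKAGES = ("trax", "jax", "jaxlib", "tensorflow",
            "tensorflow-datasets", "tensorflow-text")


def collect_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Return {name: version or None} plus the running Python version."""
    versions: Dict[str, Optional[str]] = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    for name, version in collect_versions().items():
        print(f"{name}: {version or 'not installed'}")
```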
Steps to reproduce:
1. Follow the procedure to install jax with CUDA, which is simply:
2. Install Trax with:
3. Start Python and execute the following statements:
4. Execute nvidia-smi and check the GPU memory consumed by your Python process.

Error logs:
N/A
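The Python statements from step 3 are not preserved here. Based on the description of an Embedding layer of size 1000x64 that is defined and initialized, the reproduction might look like the minimal sketch below. It assumes the installed Trax version provides `trax.layers.Embedding(vocab_size=..., d_feature=...)` and `trax.shapes.signature`.

The embedding table itself is small. At float32 it holds 1000 × 64 × 4 bytes = 256,000 bytes, about 0.24 MiB, so the weights alone cannot account for 4859 MiB.

One plausible cause is that JAX, which Trax runs on, preallocates a large share of GPU memory by default. To test that assumption, the sketch also includes a hypothetical helper that sets JAX's documented allocator variables `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION`. These variables only take effect if they are set before `jax` or `trax` is first imported.

```python
import os
from typing import Optional


def configure_jax_memory(preallocate: bool = True,
                         mem_fraction: Optional[float] = None) -> None:
    """Hypothetical helper: set JAX's GPU allocator environment variables.

    Must run before jax/trax is imported; an already-initialized
    backend ignores later changes.
    """
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        # Fraction of total GPU memory to preallocate, e.g. "0.50".
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"


def reproduce(vocab_size: int = 1000, d_feature: int = 64):
    """Define and initialize a vocab_size x d_feature Embedding layer."""
    # Imported lazily so configure_jax_memory() can run first.
    import jax
    import numpy as np
    from trax import layers as tl
    from trax import shapes

    print("JAX devices:", jax.devices())  # confirms the CUDA backend is in use
    embedding = tl.Embedding(vocab_size=vocab_size, d_feature=d_feature)
    tokens = np.array([[1, 2, 3]], dtype=np.int32)  # a batch of token ids
    embedding.init(shapes.signature(tokens))
    # Weights: vocab_size * d_feature values (~0.24 MiB if float32).
    return embedding(tokens)
```

To compare the two behaviors, call `reproduce()` once as is and once in a fresh interpreter after `configure_jax_memory(preallocate=False)`, then repeat step 4 each time. If the reservation shrinks to roughly what the layer needs, the memory is being held by the allocator's default policy rather than by the Embedding layer.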