google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0
8.07k stars 813 forks

Trax Reserves All Available GPU Memory #1722

Open rafidka opened 2 years ago

rafidka commented 2 years ago

Description

Trax seems to reserve all available GPU memory. After merely defining and initializing an Embedding layer of size 1000x64, nearly all of my available GPU memory is claimed by the Python process running Trax.

nvidia-smi
Tue Dec  7 14:48:14 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 495.44       CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  N/A |
| N/A   54C    P8    14W /  N/A |   5605MiB /  5934MiB |      7%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      2650      G   /usr/libexec/Xorg                 330MiB |
|    0   N/A  N/A      2812      G   /usr/bin/gnome-shell              207MiB |
|    0   N/A  N/A     62099      G   ...AAAAAAAAA= --shared-files       15MiB |
|    0   N/A  N/A     73704      G   ...AAAAAAAAA= --shared-files       60MiB |
|    0   N/A  N/A     78782      G   /usr/lib64/firefox/firefox        110MiB |
|    0   N/A  N/A     79092      G   /usr/lib64/firefox/firefox          1MiB |
|    0   N/A  N/A     81544      C   ...nda3/envs/trax/bin/python     4859MiB |
|    0   N/A  N/A     81902      G   /usr/lib64/firefox/firefox          1MiB |
+-----------------------------------------------------------------------------+

Environment information

OS: Fedora 35, Linux Kernel 5.15.6-200.fc35.x86_64

$ pip freeze | grep trax

trax==1.4.1

$ pip freeze | grep tensor

tensorboard==2.7.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.7.0
tensorflow-datasets==4.4.0
tensorflow-estimator==2.7.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.22.0
tensorflow-metadata==1.5.0
tensorflow-text==2.7.3

$ pip freeze | grep jax

jax==0.2.25
jaxlib==0.1.73+cuda11.cudnn82

$ python -V

Python 3.9.7

Steps to reproduce:

Follow the JAX installation instructions for CUDA, which amount to:

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

Install Trax with:

pip install -U trax

Start Python and execute the following statements:

import numpy as np
import trax

emb = trax.layers.Embedding(1000, 64)
emb.init(None)

Run nvidia-smi and check the GPU memory consumed by your Python process.

Error logs:

N/A

jsearcy1 commented 2 years ago

You might want to take a look at this: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
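In short, JAX preallocates a large fraction of GPU memory up front by default; the page above documents XLA client flags that change this. A minimal sketch of the two most relevant options (flag names are from the JAX docs; they must be set before JAX initializes its GPU backend, i.e. before the first trax/jax import in the process):

```python
import os

# Option 1: disable upfront preallocation entirely; JAX then allocates
# GPU memory on demand as buffers are created.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Option 2 (alternative to the above): keep preallocation, but cap it
# at a fraction of total GPU memory, e.g. 50%.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

# Only import trax/jax *after* setting the flags, e.g.:
#   import trax
#   emb = trax.layers.Embedding(1000, 64)
#   emb.init(None)
```

Note that disabling preallocation can cause memory fragmentation and slower allocation, which is why JAX does not make it the default.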

rafidka commented 2 years ago

Thanks, this explains what is happening. I think this behavior should be mentioned in the Trax documentation as well.