google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Not sure if Flax uses the GPU #3979

Closed · hndrbrm closed this 2 weeks ago

hndrbrm commented 3 weeks ago

I am new to Flax and JAX. I have tried the first example from the guide. I can confirm that JAX really does use the GPU; nvidia-smi shows 95% utilization. With Flax, however, I see only 1%. Maybe there's a step I missed?

I used the basic guide from the documentation:

Code:

```python
import flax
import jax
from flax import linen as nn
from jax import random, numpy as jnp

# We create one dense layer instance (taking 'features' parameter as input)
model = nn.Dense(features=5)

key1, key2 = random.split(random.key(0))
x = random.normal(key1, (10,))            # Dummy input data
params = model.init(key2, x)              # Initialization call
jax.tree_util.tree_map(lambda x: x.shape, params)  # Checking output shapes

model.apply(params, x)

# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.key(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))

# Store the parameters in a FrozenDict pytree.
true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise, (n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

@jax.jit
def mse(params, x_batched, y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x, y):
    pred = model.apply(params, x)
    return jnp.inner(y - pred, y - pred) / 2.0
  # Vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

learning_rate = 0.3  # Gradient step size.

print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
  params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return params

for i in range(10000):
  # Perform one gradient update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)
```

Checking with nvidia-smi shows it's only using 1% of the GPU:

GPU:

```
Mon Jun 10 13:12:01 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3080        Off | 00000000:01:00.0 Off |                  N/A |
| 54%   42C    P2              93W / 320W |   7776MiB / 10240MiB |      1%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
```

Previously I used this example from the JAX documentation:

Code:

```python
import timeit

import jax.numpy as jnp
import jax.random as random
from jax import jit

def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

key = random.key(0)
x = random.normal(key, (1_000_000_00,))

print(timeit.timeit('''for i in range(100): selu(x).block_until_ready()''',
                    globals=globals(), number=10))

selu_jit = jit(selu)
print(timeit.timeit('''for i in range(100): selu_jit(x).block_until_ready()''',
                    globals=globals(), number=10))
```

With that example, nvidia-smi shows 95% usage:

GPU:

```
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3080        Off | 00000000:01:00.0 Off |                  N/A |
|  0%   49C    P2             273W / 320W |   7768MiB / 10240MiB |     95%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
```
IvyZX commented 2 weeks ago

Flax only interacts with the lower levels via JAX, so if JAX can access the GPU, Flax can too. You can print `jax.devices()` to verify that.
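
For example, a quick check like this (the device names in the comments are illustrative and depend on your install):

```python
import jax

print(jax.devices())           # e.g. [CudaDevice(id=0)] when the CUDA backend is active
print(jax.default_backend())   # 'gpu' if JAX defaults to the GPU, 'cpu' otherwise
```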

In your case though, I don't see any code that puts your params and input data on the GPU and lets them be sharded. You might want to:

  • check the `.sharding` of your JAX arrays and which JAX devices they are attached to
  • use `jax.device_put` to put your params onto GPU devices (see the sketch below)
  • configure `jax.jit` args like `in_shardings` to specify that computations should happen on the GPU
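
A minimal sketch of the first two points, applied to the linear-regression snippet above (this assumes a single visible GPU; `params`, `x_samples`, and `y_samples` are the variable names from that snippet):

```python
import jax

gpu = jax.devices('gpu')[0]                # first GPU device, assuming one is visible

# Explicitly place the parameter pytree and the training data on the GPU.
params = jax.device_put(params, gpu)
x_samples = jax.device_put(x_samples, gpu)
y_samples = jax.device_put(y_samples, gpu)

# Inspect where an array lives and how it is sharded.
print(x_samples.devices())                 # e.g. {CudaDevice(id=0)}
print(x_samples.sharding)                  # e.g. SingleDeviceSharding(device=CudaDevice(id=0))
```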

Also check out this JAX guide on distributed arrays: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html

hndrbrm commented 2 weeks ago

> Flax only interacts with the lower levels via JAX, so if JAX can access the GPU, Flax can too. You can print `jax.devices()` to verify that.
>
> In your case though, I don't see any code that puts your params and input data on the GPU and lets them be sharded. You might want to:
>
>   • check the `.sharding` of your JAX arrays and which JAX devices they are attached to
>   • use `jax.device_put` to put your params onto GPU devices
>   • configure `jax.jit` args like `in_shardings` to specify that computations should happen on the GPU
>
> Also check out this JAX guide on distributed arrays: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html

Thanks for the reply; there are a lot of keywords in your response that I still don't know.

JAX definitely can access the GPU; I posted the screenshot above, and nvidia-smi shows around 95% GPU usage. Sharding and device_put are new to me, so I am going to read about them.