hndrbrm closed this issue 2 weeks ago.
Flax only interacts with lower levels via JAX, so if JAX can access the GPU, Flax can too. You can print `jax.devices()` to verify that.

In your case, though, I don't see any code that puts your params and input data on the GPU and lets them be sharded. You might want to:

- check the `.sharding` of your JAX arrays and which JAX devices they are attached to
- use `jax.device_put` to put your params onto GPU devices
- configure `jax.jit` args like `in_shardings` to specify that computations happen on the GPU.

Also check out this JAX guide on distributed arrays: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
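The suggestions in the reply can be sketched roughly as follows. This is a minimal illustration, not code from the thread: it assumes a recent JAX that ships the `jax.sharding` module and accepts `in_shardings`/`out_shardings` on `jax.jit`, and the array `x`, the function `scale`, and the mesh axis name `"data"` are hypothetical placeholders.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1. Devices visible to JAX (and therefore to Flax). With a working CUDA
#    install this lists the GPU(s); otherwise only CPU devices appear.
devices = jax.devices()
print(devices)

# 2. Inspect where an array lives via its .sharding and .devices().
#    Row count is a multiple of the device count so it can be split evenly.
n_rows = 4 * len(devices)
x = jnp.arange(n_rows * 3, dtype=jnp.float32).reshape(n_rows, 3)
print(x.sharding)
print(x.devices())

# 3. Place an array explicitly with jax.device_put.
#    devices[0] is the first device of the default backend (the GPU if present).
x_single = jax.device_put(x, devices[0])

# 4. Build a 1-D mesh over all devices and split rows along the
#    hypothetical "data" axis, then commit the array to that sharding.
mesh = Mesh(np.array(devices), axis_names=("data",))
row_sharding = NamedSharding(mesh, P("data"))
x_sharded = jax.device_put(x, row_sharding)


# 5. Tell jax.jit how the inputs and outputs are laid out across devices.
def scale(a):
    return 2.0 * a


scale_jit = jax.jit(scale, in_shardings=(row_sharding,), out_shardings=row_sharding)
y = scale_jit(x_sharded)
print(y.sharding)
```

On a single-GPU machine the mesh has one device, so the "sharding" simply pins everything to that GPU; the same code spreads rows across devices when more are available.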
Thanks for the reply. There are a lot of keywords in your response that I don't know yet.
JAX can definitely access the GPU: in the screenshot I posted, `nvidia-smi` shows around 95% GPU usage. Sharding and `device_put` are new to me, so I'm going to read up on them.
I am new to Flax and JAX and tried them for the first time by following the guides. I confirmed that JAX really does use the GPU: `nvidia-smi` shows 95% utilization. With Flax, however, I see only 1%. Maybe there's a step I missed?
I used the basic guide from the documentation:
```python
import flax
import jax
from flax import linen as nn
from jax import random, numpy as jnp

# We create one dense layer instance (taking 'features' parameter as input)
model = nn.Dense(features=5)

key1, key2 = random.split(random.key(0))
x = random.normal(key1, (10,))  # Dummy input data
params = model.init(key2, x)  # Initialization call
jax.tree_util.tree_map(lambda x: x.shape, params)  # Checking output shapes

model.apply(params, x)

# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.key(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a FrozenDict pytree.
true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise, (n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)


@jax.jit
def mse(params, x_batched, y_batched):
    # Define the squared loss for a single pair (x,y)
    def squared_error(x, y):
        pred = model.apply(params, x)
        return jnp.inner(y - pred, y - pred) / 2.0
    # Vectorize the previous to compute the average of the loss on all samples.
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)


learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)


@jax.jit
def update_params(params, learning_rate, grads):
    params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads)
    return params


for i in range(10000):
    # Perform one gradient update.
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = update_params(params, learning_rate, grads)
    if i % 10 == 0:
        print(f'Loss step {i}: ', loss_val)
```

Checking with `nvidia-smi` shows it's only using 1% of the GPU:
```
Mon Jun 10 13:12:01 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3080        Off | 00000000:01:00.0 Off |                  N/A |
| 54%   42C    P2              93W / 320W |   7776MiB / 10240MiB |      1%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
```

Previously I used this example from the JAX documentation:
```python
import timeit

import jax.numpy as jnp
import jax.random as random
from jax import jit


def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


key = random.key(0)
x = random.normal(key, (100_000_000,))  # 100 million elements

# Un-jitted version: 10 repetitions of a 100-call loop.
print(timeit.timeit('''for i in range(100): selu(x).block_until_ready()''',
                    globals=globals(), number=10))

selu_jit = jit(selu)
# Jitted version of the same benchmark.
print(timeit.timeit('''for i in range(100): selu_jit(x).block_until_ready()''',
                    globals=globals(), number=10))
```

With this example, `nvidia-smi` shows 95% usage:
```
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3080        Off | 00000000:01:00.0 Off |                  N/A |
|  0%   49C    P2             273W / 320W |   7768MiB / 10240MiB |     95%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
```
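Following the reply's suggestion to check where the Flax example's data actually lives: Flax parameters are ordinary pytrees of JAX arrays (the guide's `true_params` is just a nested dict holding `kernel` and `bias`), so plain JAX calls can inspect and set their placement. This is a hedged sketch, not a fix confirmed in the thread: to stay self-contained it builds a hypothetical params dict with the same layout instead of calling `model.init`, and `predict` is a hypothetical stand-in computing the same affine map as `nn.Dense(features=5)`.

```python
import jax
import jax.numpy as jnp
from jax import random

# Hypothetical params pytree with the same layout as the guide's `true_params`:
# {'params': {'kernel': (x_dim, y_dim), 'bias': (y_dim,)}}.
x_dim, y_dim, n_samples = 10, 5, 20
k1, k2, k3 = random.split(random.key(0), 3)
params = {"params": {"kernel": random.normal(k1, (x_dim, y_dim)),
                     "bias": random.normal(k2, (y_dim,))}}
x_samples = random.normal(k3, (n_samples, x_dim))

# Which device(s) does each parameter leaf live on?
print(jax.tree_util.tree_map(lambda p: p.devices(), params))
print(x_samples.sharding)

# Explicitly commit params and data to one device with jax.device_put
# (device_put accepts whole pytrees). devices()[0] is the default
# backend's first device, i.e. the GPU when one is available.
device = jax.devices()[0]
params = jax.device_put(params, device)
x_samples = jax.device_put(x_samples, device)


@jax.jit
def predict(params, x):
    # Hypothetical stand-in for model.apply: x @ kernel + bias, like nn.Dense.
    p = params["params"]
    return jnp.dot(x, p["kernel"]) + p["bias"]


preds = predict(params, x_samples)
print(preds.devices())  # the jitted computation runs where its inputs live
```

The same `tree_map(lambda p: p.devices(), ...)` check can be run directly on the `params` returned by `model.init` in the Flax code above.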