google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

bias and kernel params are put on different gpu devices #4116

Open YunxiTang opened 2 months ago

YunxiTang commented 2 months ago

System information

Problem you have encountered:

When I try to initialize a Flax model on a specific GPU device (for example, GPU 1), the bias and kernel params end up on different GPU devices.

What you expected to happen:

The bias and kernel params should be placed on the same GPU device.

Steps to reproduce:

  import jax
  import jax.numpy as jnp
  from jax import tree_util
  from flax import linen as nn

  device = jax.devices("gpu")[1]

  class MyModel(nn.Module):
      @nn.compact
      def __call__(self, x):
          x = nn.Conv(64, (3, 3), 1, name='conv1')(x)
          x = nn.relu(x)
          return x

  rng = jax.random.PRNGKey(0)
  rng = jax.device_put(rng, device)
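  dummy_input = jax.device_put(jnp.ones((5, 64, 64, 32)), device)

  model = MyModel()
  model_params = model.init({'params': rng}, dummy_input)
  # Uncommenting the next line moves every param to `device` explicitly.
  # model_params = tree_util.tree_map(lambda x: jax.device_put(x, device), model_params)
  # Note: newer JAX releases expose `device` as a property, so use `x.device` there.
  print(tree_util.tree_map(lambda x: (x.device()), model_params))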

The output is

FrozenDict({
    params: {
        conv1: {
            bias: gpu(id=0),
            kernel: gpu(id=1),
        },
    },
})
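
For anyone who wants to check placement without depending on whether `device` is a method or a property in their JAX version, here is a minimal sketch built on `jax.Array.devices()`. The helper names `leaf_devices` and `is_colocated` and the `toy_params` pytree are hypothetical stand-ins (not part of Flax or of the code above), so the snippet runs on any backend, including a CPU-only machine.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def leaf_devices(tree):
    """Map every array leaf to the sorted names of the devices holding it."""
    return tree_util.tree_map(
        lambda x: sorted(str(d) for d in x.devices()), tree
    )


def is_colocated(tree):
    """Return True if all leaves live on exactly the same set of devices."""
    device_sets = {frozenset(x.devices()) for x in tree_util.tree_leaves(tree)}
    return len(device_sets) <= 1


# Hypothetical stand-in with the same layout as `model_params` above.
toy_params = {
    'params': {
        'conv1': {
            'bias': jnp.zeros((64,)),
            'kernel': jnp.ones((3, 3, 32, 64)),
        }
    }
}

print(leaf_devices(toy_params))
print(is_colocated(toy_params))
```

Running `is_colocated(model_params)` on the params from the reproduction should return `False` if the bias and kernel really sit on different GPUs.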

Thanks!

MasterSkepticista commented 1 month ago

Hi @YunxiTang, I am able to reproduce this issue.

In practice, I have seen Flax models initialized on the CPU and then migrated or replicated to devices later. Two examples:

  1. Migrating params post-initialization to GPU.
    # Optional: Init on `cpu`.
    model_params = jax.jit(model.init, backend="cpu")({'params': rng}, dummy_input)
    model_params = jax.device_put(model_params, device)
    print(jax.tree.map(lambda p: p.device, model_params))
    # {'params': {'conv1': {'bias': CudaDevice(id=1), 'kernel': CudaDevice(id=1)}}}
  2. Using jax.default_device scope.
    with jax.default_device(device):
        model_params = model.init({'params': rng}, dummy_input)
        print(tree_util.tree_map(lambda x: (x.device), model_params))
        # {'params': {'conv1': {'bias': CudaDevice(id=1), 'kernel': CudaDevice(id=1)}}}
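
Both approaches can also be tried without Flax or a multi-GPU machine. The following is only a sketch under stated assumptions: `toy_init` is a hypothetical stand-in for `model.init`, and `target` is simply the last visible device (the only one on a single-device host), so it shows the placement mechanics rather than the Flax behavior in question.

```python
import jax
import jax.numpy as jnp

# Assumption: use the last visible device as the target.
target = jax.devices()[-1]


def toy_init(key):
    """Hypothetical stand-in for model.init: a random kernel and a zero bias."""
    kernel = jax.random.normal(key, (3, 3, 32, 64))
    bias = jnp.zeros((64,))
    return {'params': {'conv1': {'kernel': kernel, 'bias': bias}}}


# Option 1: initialize wherever JAX chooses, then move the whole pytree.
moved = jax.device_put(toy_init(jax.random.PRNGKey(0)), target)

# Option 2: make `target` the default device while initializing.
with jax.default_device(target):
    scoped = toy_init(jax.random.PRNGKey(0))

# Every leaf of both pytrees should now be on `target` only.
for tree in (moved, scoped):
    for leaf in jax.tree_util.tree_leaves(tree):
        assert leaf.devices() == {target}
```

Option 1 does not care where initialization happened, while option 2 avoids the extra transfer by creating the arrays on the target device in the first place.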

I will let the Flax team comment on the default behavior in your case.