Open YunxiTang opened 2 months ago
Hi @YunxiTang, I am able to reproduce this issue.
In practice, I have seen Flax models initialized on CPU and migrated/replicated to accelerator devices later. Two examples:
```python
import jax

# Optional: init on CPU, then move the params to the target device.
model_params = jax.jit(model.init, backend="cpu")({'params': rng}, dummy_input)
model_params = jax.device_put(model_params, device)

jax.tree.map(lambda p: p.device, model_params)
# {'params': {'conv1': {'bias': CudaDevice(id=1), 'kernel': CudaDevice(id=1)}}}
```
Alternatively, initialize inside a `jax.default_device` scope:

```python
with jax.default_device(device):
    model_params = model.init({'params': rng}, dummy_input)

print(jax.tree.map(lambda x: x.device, model_params))
# {'params': {'conv1': {'bias': CudaDevice(id=1), 'kernel': CudaDevice(id=1)}}}
```
I will let the Flax team comment on the default behavior in your case.
System information
Problem you have encountered:
When I try to initialize a Flax model on a specific GPU device (for example, `gpu 1`), the bias and kernel params end up on different GPU devices.

What you expected to happen:
The bias and kernel params should be placed on the same GPU device.
Steps to reproduce:
The output is:
Thanks!