sck-at-ucy opened 4 months ago
The code should definitely still run on the GPU. It's possible you are being bottlenecked by communication latency, so it looks like the GPU is not used. One thing to check is that you are using the GPU in the same code even without doing any MPI communication (e.g. no all reduce or gather). In that case it should run fast on the GPU. If that's not the case, then I would double-check even the single-process setup.
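For example, something standalone along these lines (just a sketch - the model and sizes are arbitrary stand-ins, not your code) should show high GPU utilization when run as a single process, since it makes no mx.distributed calls at all:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

mx.set_default_device(mx.gpu)

# Arbitrary stand-in model and random data; no collective communication anywhere.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
x = mx.random.normal((256, 1024))
y = mx.random.normal((256, 1024))

def loss_fn(model, x, y):
    return ((model(x) - y) ** 2).mean()

loss_and_grad = nn.value_and_grad(model, loss_fn)
optimizer = optim.Adam(learning_rate=1e-4)

for _ in range(100):
    loss, grads = loss_and_grad(model, x, y)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), loss)  # force the computation so it actually runs on the GPU

If that pegs the GPU but the distributed version does not, communication is almost certainly the bottleneck.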
If I run the same code as-is locally in PyCharm, and not on the command line through MPI, it shows a steady 96% GPU utilization. When I run it through MPI, GPU utilization on both machines hovers around 15%, +/- 3%.
Yeah exactly, so it looks like either communication latency or bandwidth is a bottleneck for you. Did you see the section in the docs on tuning all reduce?
In general there are a few things to keep in mind here.
Thank you, yes, that makes sense. The example I tried was very crude (I guess I was impatient to try it 🤓), with no batching, so I think all reduce was called too frequently. I will refine it to minimize the frequency of communication.
Great, keep us posted on how it goes!
Will do.
After implementing training in batches I get good speeds, but I have run into a problem I do not understand. The issue surfaces when I try to print the loss. If I run the code in a single process without MPI it works fine. However, if I try to do that in distributed mode, the code hangs and never completes. I've tried various variations, but none worked. I also think it might be useful to implement an mx.distributed.finalize() or world.finalize() for clean termination.
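(Just to sketch the kind of thing I mean - this is not an existing MLX API, only a crude workaround built on the current all_sum:)

import mlx.core as mx

def barrier():
    # Ad hoc "barrier" (untested sketch): every rank has to take part in the
    # all_sum, and mx.eval forces the communication to actually happen, so no
    # rank should get past this point until all ranks have reached it.
    mx.eval(mx.distributed.all_sum(mx.array(1.0)))

Here is the code I am running: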
# We can also force evaluate all parameters to initialize the model
mx.eval(model.parameters())

optimizer = optim.Adam(learning_rate=0.0001)

# A simple loss function.
def l2_loss(model, x, y):
    y_hat = model(x)
    return mx.array(y_hat - y).square().mean()

def all_reduce_grads(grads, N):
    if N == 1:
        return grads
    return tree_map(
        lambda x: mx.distributed.all_sum(x) / N,
        grads)

def step(model, x, y):
    loss_and_grad_fn = nn.value_and_grad(model, l2_loss)
    loss, grads = loss_and_grad_fn(model, x, y)
    return loss, grads

# Training loop with mini-batch processing
batch_size = 10
num_batches = num_samples_per_process // batch_size

for epoch in range(20):  # Number of epochs
    epoch_loss = mx.array([0.0])
    for i in range(num_batches):
        x_batch = X[i*batch_size:(i+1)*batch_size]
        y_batch = Y[i*batch_size:(i+1)*batch_size]
        #print(f"Batch {i}")

        # Accumulate gradients over the mini-batch
        batch_loss = 0
        batch_grads = None
        for j in range(batch_size):
            #print(f"Sample {j}")
            x = x_batch[j]
            y = y_batch[j]
            loss, grads = step(model, x, y)
            batch_loss += loss
            if batch_grads is None:
                batch_grads = grads
            else:
                batch_grads = tree_map(lambda g1, g2: g1 + g2, batch_grads, grads)

        # Average the gradients over the batch
        batch_grads = tree_map(lambda g: g / batch_size, batch_grads)

        # All-reduce to average gradients across all processes
        batch_grads = all_reduce_grads(batch_grads, num_processes)

        # Update the model with the averaged gradients
        optimizer.update(model, batch_grads)

        epoch_loss += batch_loss

    print(f"Rank: {rank} about to print")
    # Trying to print causes the code to hang
    #print(f"Epoch {epoch + 1}, Loss: {mx.distributed.all_sum(epoch_loss) / num_batches}")

print(f"Rank: {rank} finished")
To me this looks like a synchronization issue, because the printing also works when I run through MPI with a single process, on either the local or the remote machine.
That kind of hang is most likely from a deadlock. Without seeing the rest of your code it's hard to know, but usually this sort of thing happens when the processes are out of sync (e.g. because one of them is processing more examples).
If you can provide a full repro I can take a look and see if there is a deeper issue.
You can comment out (ignore) the barrier() line; I was attempting to create an ad hoc barrier implementation, but it is not functional.
The model is not intended to do anything useful, so don't expect anything interesting there except perhaps bugs :)
This runs on a single process either through PyCharm (no MPI) or through MPI with "-np 1".
This is the problem:
if rank == 0:
    print(f"Epoch {epoch + 1}, Loss: {mx.distributed.all_sum(epoch_loss) / num_batches}")
Every process has to participate in the all_sum, otherwise there is a hang (one process waits for another which never called all_sum).
Rearrange it like this:
loss = mx.distributed.all_sum(epoch_loss) / num_batches
if rank == 0:
    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
This is still not working properly for me. If I replace it with
loss = mx.distributed.all_sum(epoch_loss) / num_batches
if rank == 0:
    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
then Rank 1 proceeds through the iterations while Rank 0 is stuck at the print, which it never completes.
Distributed available: True
Hostname: Asilomar: 1
Number of processes: 2
Number of Samples Per Process: 500
Distributed available: True
Hostname: PaloAlto: 0
Number of processes: 2
Number of Samples Per Process: 500
Rank: 1 about to print
Rank: 0 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 finished
If I replace it with
loss = mx.distributed.all_sum(epoch_loss) / num_batches
print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
then both stay stuck:
Distributed available: True
Hostname: Asilomar: 1
Number of processes: 2
Number of Samples Per Process: 500
Distributed available: True
Hostname: PaloAlto: 0
Number of processes: 2
Number of Samples Per Process: 500
Rank: 1 about to print
Rank: 0 about to print
Ah interesting, that's a bit of a gotcha; sorry, I think I told you the wrong thing. Try the following (notice I moved the item() out of the if, so that every rank evaluates the all_sum):
loss = mx.distributed.all_sum(epoch_loss) / num_batches
loss = loss.item()
if rank == 0:
    print(f"Epoch {epoch + 1}, Loss: {loss}")
Although you say this hangs as well? That's unexpected...
loss = mx.distributed.all_sum(epoch_loss) / num_batches
print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
Yes, actually replacing the print statement with an mx.eval(loss) produces the same effect.
print(f"Rank: {rank} about to print")
# Trying to print causes the code to hang
loss = mx.distributed.all_sum(epoch_loss) / num_batches
mx.eval(loss)
Hmm, I'm not getting a hang there if I use the code like this:
loss = mx.distributed.all_sum(epoch_loss) / num_batches
loss = loss.item()
print(f"Epoch {epoch + 1}, Loss: {loss}")
Where are you running this? Is it on a large machine? Is it across the network or just locally? Do you get a hang if you run it locally? E.g. just running:
mpirun -n 2 python MLX_mpi_example3.py
I am running on two Mac Studios connected via Thunderbolt. I get the same behavior if I attempt to run 2 processes on just one of the machines, and also with both the Homebrew MPI and the Anaconda MPI.
Just to be sure we are running the same thing, could you send exactly what you are running (with the print fix) and the command you used to run it locally?
Also what version of MLX are you using?
for epoch in range(20):  # Number of epochs
    epoch_loss = mx.array([0.0])
    for i in range(num_batches):
        x_batch = X[i*batch_size:(i+1)*batch_size]
        y_batch = Y[i*batch_size:(i+1)*batch_size]
        #print(f"Batch {i}")

        # Accumulate gradients over the mini-batch
        batch_loss = mx.array([0.0])
        batch_grads = None
        for j in range(batch_size):
            #print(f"Sample {j}")
            x = x_batch[j]
            y = y_batch[j]
            loss, grads = step(model, x, y)
            batch_loss += loss
            if batch_grads is None:
                batch_grads = grads
            else:
                batch_grads = tree_map(lambda g1, g2: g1 + g2, batch_grads, grads)

        # Average the gradients over the batch
        batch_grads = tree_map(lambda g: g / batch_size, batch_grads)

        # All-reduce to average gradients across all processes
        batch_grads = all_reduce_grads(batch_grads, num_processes)

        # Update the model with the averaged gradients
        optimizer.update(model, batch_grads)

        epoch_loss += batch_loss

    print(f"Rank: {rank} about to print")
    # Trying to print causes the code to hang
    loss = mx.distributed.all_sum(epoch_loss) / num_batches
    loss = loss.item()
    print(f"Epoch {epoch + 1}, Loss: {loss}")

print(f"Rank: {rank} finished")
pip show mlx
Name: mlx
Version: 0.15.0.dev20240614+2d6cd47
/opt/homebrew/bin/mpirun -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ --host PaloAlto:2 /Users/m2/anaconda3/envs/pythonProject_StreamLit/bin/python /Users/m2/PycharmProjects/pythonProject_StreamLit/MLX_mpi_example3.py
Distributed available: True
Hostname: PaloAlto: 0
Number of processes: 2
Number of Samples Per Process: 500
Distributed available: True
Hostname: PaloAlto: 1
Number of processes: 2
Number of Samples Per Process: 500
Rank: 0 about to print
Rank: 1 about to print
Epoch 1, Loss: 3.787351131439209
Epoch 1, Loss: 3.787351131439209
Rank: 1 about to print
Rank: 0 about to print
Epoch 2, Loss: 16.894447326660156
Epoch 2, Loss: 16.894447326660156
Rank: 1 about to print
Rank: 0 about to print
Epoch 3, Loss: 3.4061923027038574
Epoch 3, Loss: 3.4061923027038574
Rank: 0 about to print
Rank: 1 about to print
Epoch 4, Loss: 2.3777270317077637
Epoch 4, Loss: 2.3777270317077637
Rank: 1 about to print
Rank: 0 about to print
Epoch 5, Loss: 3.5566701889038086
Epoch 5, Loss: 3.5566701889038086
Rank: 1 about to print
Rank: 0 about to print
Oh interesting. So it only hangs if you run over the network? It might be good to put an eval after you reduce the grads so we can see if it is actually hanging at the loss reduction or somewhere else:
# All-reduce to average gradients across all processes
batch_grads = all_reduce_grads(batch_grads, num_processes)
mx.eval(batch_grads) # <--- add that
Indeed, it hangs here too if I add it and run over the network (but it still does not hang on a single machine with 2 processes):
# All-reduce to average gradients across all processes
batch_grads = all_reduce_grads(batch_grads, num_processes)
mx.eval(batch_grads) # <--- add that
Have you tried something really simple just to debug the connection? Something like the following:
import mlx.core as mx
world = mx.distributed.init()
x = mx.distributed.all_sum(mx.ones(10))
print(world.rank(), x)
And it runs over the network with 2 processes but on the remote machine.
Yes, I did try the simple case, but let me repeat it now just to make sure nothing has changed since then.
import mlx.core as mx
import os
import socket
hostname = socket.gethostname()
world = mx.distributed.init()
x = mx.distributed.all_sum(mx.ones(10))
print(f"Distributed available: {mx.distributed.is_available()}")
print(f"Hostname: {hostname}: {world.rank()}, {x}")
/opt/homebrew/bin/mpirun -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ --host PaloAlto:1,Asilomar:1 /Users/m2/anaconda3/envs/pythonProject_StreamLit/bin/python /Users/m2/PycharmProjects/pythonProject_StreamLit/MX_mpi.py
Distributed available: True
Distributed available: True
Hostname: PaloAlto: 0, array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
Hostname: Asilomar: 1, array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
Great!
So it works fine for me even over a network. It is just quite slow. Could you try the same script but with smaller sizes (like decrease the input and output dimension by a factor of 100 or 1000) and see if it works?
OK, I confirm that with these sizes it works, but it is slow:
total_num_samples = 1000
in_dims = 2000
out_dims = 2000
Something interesting to add: the Mac Studios are connected via Thunderbolt and are also on the same LAN via Wi-Fi. The speed increases in the order of the options listed below.
Not sure why; perhaps it avoids some resolution that otherwise has to be done.
In other words this is pretty fast now:
/opt/homebrew/bin/mpirun -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ --host 10.0.0.2:1,10.0.0.1:1 /Users/m2/anaconda3/envs/pythonProject_StreamLit/bin/python /Users/m2/PycharmProjects/pythonProject_StreamLit/MLX_mpi_example3.py
I should add that with Option 3 (specifying the Thunderbolt Bridge IPs) it does not matter if Wi-Fi is on or off.
Some rough times (done on the command line with time for a quick and dirty comparison):
Option 1: 274.46s (machines connected by both Thunderbolt and Wi-Fi, hostnames specified on the command line)
Option 2: 40.36s (machines connected by Thunderbolt only, hostnames specified on the command line)
Option 3: 11.29s (Thunderbolt Bridge IPs specified instead of hostnames on the command line)
So, Option 3 is roughly 25x faster than Option 1.
Those are very interesting results, thanks for sharing. I think we'll need to figure out what's going on there and document a bit the recommended setup.
CC @angeloskath @jagrit06
Let me know if you would like me to run any more tests on my hardware setup here; I would be happy to help. Over the next week I will try to make this a 4-Studio cluster connected via Thunderbolt to see how it scales.
@sck-at-ucy Good note on the difference between Option 1 and Option 2 - OpenMPI is happy to find any way to route the message between the machines, so sometimes when we call mpirun it is useful to pass the --mca btl_tcp_if_include "bridge0,en0,<interface from ifconfig or IP in CIDR notation>" option. There is also a corresponding --mca btl_tcp_if_exclude option where you can add the IP of the Wi-Fi interface.
If that sounds useful, maybe I could make a note and add it to the documentation.
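For example, adapting the command used earlier in this thread (using bridge0 here is an assumption - it is typically the Thunderbolt Bridge interface on macOS, but check ifconfig on your machines), something like:

/opt/homebrew/bin/mpirun -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ --mca btl_tcp_if_include bridge0 --host 10.0.0.2:1,10.0.0.1:1 /Users/m2/anaconda3/envs/pythonProject_StreamLit/bin/python /Users/m2/PycharmProjects/pythonProject_StreamLit/MLX_mpi_example3.py

should keep the MPI traffic on the Thunderbolt Bridge regardless of whether Wi-Fi is up.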
I have now progressed from debugging the MPI communication to running an example of distributed training of an MLP model on two machines. I have been monitoring the CPU and GPU utilization on the two machines, and it seems that when run through MPI the code is loaded on the CPUs, not the GPUs. Perhaps I misunderstood the capabilities of the distributed computing implementation? I thought that once the process was launched on each machine it would load on the GPUs (I have specified mx.set_default_device(mx.gpu)), but perhaps this is not the case when it is done through MPI?