ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Distributed computing not utilizing GPUs #1210

Open sck-at-ucy opened 4 months ago

sck-at-ucy commented 4 months ago

I have now progressed from debugging the MPI communication to running an example of distributed training of an MLP model on two machines. I have been monitoring CPU and GPU utilization on both machines, and it seems that when the code runs through MPI, the load lands on the CPUs, not the GPUs. Perhaps I misunderstood the capabilities of the distributed computing implementation? I thought that once the process was launched on each machine it would run on the GPU (I have specified mx.set_default_device(mx.gpu)), but perhaps this is not the case when it is launched through MPI?

awni commented 4 months ago

The code should definitely still run on the GPU. It's possible you are bottlenecked by communication latency, so it looks like the GPU is not being used. One thing to check is whether the same code uses the GPU when it does no MPI communication at all (e.g. no all-reduce or gather). In that case it should run fast on the GPU. If it doesn't, I would double-check the single-process setup first.

sck-at-ucy commented 4 months ago

If I run the same code as-is locally in PyCharm, not on the command line through MPI, it shows a steady 96% GPU utilization. When I run it through MPI, GPU utilization on both machines hovers around 15% (± 3%).

awni commented 4 months ago

Yeah, exactly: so it looks like either communication latency or bandwidth is a bottleneck for you. Did you see this section in the docs on tuning all-reduce?

In general there are a few things to keep in mind here:

sck-at-ucy commented 4 months ago

Thank you, yes, that makes sense. The example I tried was very crude (I guess I was impatient to try it 🤓) and had no batching, so I think all-reduce was called too frequently. I will refine it to minimize the frequency of communication.
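Why communication frequency matters can be seen with a back-of-the-envelope model. This is a hypothetical, MLX-independent sketch: it assumes every all-reduce costs a fixed latency plus a size-proportional transfer time, that compute time per sample is constant, and that compute and communication never overlap. The function name `compute_fraction` and all numbers are illustrative assumptions, not measurements from this thread.

```python
def compute_fraction(samples, compute_s_per_sample, latency_s,
                     bytes_per_reduce, bandwidth_bytes_per_s, samples_per_reduce):
    """Toy model: share of wall time spent computing rather than communicating.

    Assumes one all-reduce of `bytes_per_reduce` bytes every `samples_per_reduce`
    samples, each costing latency + size / bandwidth, and no overlap between
    compute and communication.
    """
    compute = samples * compute_s_per_sample
    num_reduces = -(-samples // samples_per_reduce)  # ceiling division
    comm = num_reduces * (latency_s + bytes_per_reduce / bandwidth_bytes_per_s)
    return compute / (compute + comm)


# Hypothetical numbers: 1000 samples, 1 ms of compute each, 0.5 ms latency,
# and a 16 MB gradient message over a 1 GB/s link.
for every in (1, 10, 100):
    frac = compute_fraction(1000, 1e-3, 5e-4, 16e6, 1e9, every)
    print(f"all-reduce every {every:>3} samples -> computing {frac:.3f} of the time")
```

With these made-up numbers the computing share rises from about 6% (reduce after every sample) to about 38% (every 10 samples) and 86% (every 100), which is the kind of gap between a busy and a mostly idle GPU described above.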

awni commented 4 months ago

Great, keep us posted on how it goes!

sck-at-ucy commented 4 months ago

Will do.

sck-at-ucy commented 4 months ago

After implementing training in batches I get good speeds, but I have run into a problem I do not understand. It surfaces when I try to print the loss. If I run the code in a single process without MPI, it works fine. However, if I try to do that in distributed mode, the code hangs and never completes. I've tried several variations but none worked. I also think it might be useful to implement an mx.distributed.finalize() or world.finalize() function for clean termination.

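import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_map

# (model, X, Y, rank, num_processes and num_samples_per_process are defined earlier in the script)

# We can also force evaluate all parameters to initialize the model
mx.eval(model.parameters())

optimizer = optim.Adam(learning_rate=0.0001)

# A simple loss function.
def l2_loss(model, x, y):
    y_hat = model(x)
    return (y_hat - y).square().mean()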

def all_reduce_grads(grads, N):
    if N == 1:
        return grads
    return tree_map(
            lambda x: mx.distributed.all_sum(x) / N,
            grads)

def step(model, x, y):
    loss_and_grad_fn = nn.value_and_grad(model, l2_loss)
    loss, grads = loss_and_grad_fn(model, x, y)
    return loss, grads

# Training loop with mini-batch processing
batch_size = 10
num_batches = num_samples_per_process // batch_size

for epoch in range(20):  # Number of epochs
    epoch_loss = mx.array([0.0])
    for i in range(num_batches):
        x_batch = X[i*batch_size:(i+1)*batch_size]
        y_batch = Y[i*batch_size:(i+1)*batch_size]
        #print(f"Batch {i}")

        # Accumulate gradients over the mini-batch
        batch_loss = 0
        batch_grads = None
        for j in range(batch_size):
            #print(f"Sample {j}")
            x = x_batch[j]
            y = y_batch[j]
            loss, grads = step(model, x, y)
            batch_loss += loss
            if batch_grads is None:
                batch_grads = grads
            else:
                batch_grads = tree_map(lambda g1, g2: g1 + g2, batch_grads, grads)
        # Average the gradients over the batch
        batch_grads = tree_map(lambda g: g / batch_size, batch_grads)

        # All-reduce to average gradients across all processes
        batch_grads = all_reduce_grads(batch_grads, num_processes)

        # Update the model with the averaged gradients
        optimizer.update(model, batch_grads)

        epoch_loss += batch_loss

    print(f"Rank: {rank} about to print")

    # Trying to print causes the code to hang
    #print(f"Epoch {epoch + 1}, Loss: {mx.distributed.all_sum(epoch_loss) / num_batches}")

print(f"Rank: {rank} finished")
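The loop above first averages the per-sample gradients within a batch and then divides the all_sum of those averages by the number of processes. A pure-Python sketch (no MLX; nested dicts stand in for parameter trees, and `toy_tree_map` and `tree_mean` are hypothetical helpers) checks that, when every rank uses the same batch size, this two-step average equals the mean gradient over all samples in the global batch. With unequal local batch sizes the two would differ.

```python
def toy_tree_map(fn, *trees):
    """Leafwise map over nested dicts of floats (a stand-in for mlx.utils.tree_map)."""
    first = trees[0]
    if isinstance(first, dict):
        return {k: toy_tree_map(fn, *(t[k] for t in trees)) for k in first}
    return fn(*trees)


def tree_mean(trees):
    """Leafwise mean of a list of trees."""
    total = trees[0]
    for t in trees[1:]:
        total = toy_tree_map(lambda a, b: a + b, total, t)
    return toy_tree_map(lambda a: a / len(trees), total)


# Hypothetical per-sample gradients: two ranks, two samples each.
rank0 = [{"w": 1.0, "b": 0.0}, {"w": 3.0, "b": 2.0}]
rank1 = [{"w": 5.0, "b": 4.0}, {"w": 7.0, "b": 6.0}]

# Step 1 (on each rank): average over the local batch, like `g / batch_size`.
local = [tree_mean(rank0), tree_mean(rank1)]
# Step 2 (across ranks): sum then divide by N, like `all_sum(x) / N`.
synced = tree_mean(local)

print(synced)                    # {'w': 4.0, 'b': 3.0}
print(tree_mean(rank0 + rank1))  # {'w': 4.0, 'b': 3.0}, the mean over all four samples
```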
sck-at-ucy commented 4 months ago

To me this looks like a synchronization issue, because printing also works when I run through MPI with a single process, on either the local or the remote machine.

awni commented 4 months ago

That kind of hang is most likely a deadlock. Without seeing the rest of your code it's hard to know, but usually this sort of thing happens when the processes are out of sync (e.g. because one of them is processing more examples).

If you can provide a full repro I can take a look and see if there is a deeper issue.

sck-at-ucy commented 4 months ago
  1. You can comment out (ignore) the barrier() line; I was attempting to create an ad hoc barrier implementation, but it is not functional.

  2. The model is not intended to do anything useful, so don't expect anything interesting there except perhaps bugs :)

  3. This runs in a single process, either through PyCharm (no MPI) or through MPI with "-np 1".

MLX_mpi_example3.py.txt

awni commented 4 months ago

This is the problem:

    if rank == 0:
        print(f"Epoch {epoch + 1}, Loss: {mx.distributed.all_sum(epoch_loss) / num_batches}")

Every process has to participate in an all_sum, otherwise there is a hang (one process waits for another that never calls all_sum).

Rearrange it like this:

    loss = mx.distributed.all_sum(epoch_loss) / num_batches
    if rank == 0:
        print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
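The requirement that every process participates can be reproduced without MPI. In this hypothetical toy, `threading.Barrier` stands in for the synchronization inside an all-reduce, and a timeout replaces the indefinite wait so the example terminates. `ToyGroup` and `run` are illustrative names; the sketch models the requirement, not how MLX or MPI implement all_sum.

```python
import threading


class ToyGroup:
    """Hypothetical in-process 'communication group'; not MLX's or MPI's API."""

    def __init__(self, size, timeout_s=0.5):
        self.values = [0.0] * size
        self.timeout_s = timeout_s
        self.barrier = threading.Barrier(size)

    def all_sum(self, rank, value):
        self.values[rank] = value
        try:
            # Every rank must reach this point before anyone can proceed.
            self.barrier.wait(timeout=self.timeout_s)
        except threading.BrokenBarrierError:
            # A real all-reduce would simply block forever here.
            raise TimeoutError(f"rank {rank}: peers never joined all_sum")
        total = sum(self.values)
        # Second rendezvous so no rank can overwrite `values` while another still sums.
        self.barrier.wait(timeout=self.timeout_s)
        return total


def run(participating, size=2):
    """Start one thread per rank; only ranks in `participating` call all_sum."""
    group = ToyGroup(size)
    results = {}

    def worker(rank):
        if rank not in participating:
            results[rank] = "skipped all_sum"
            return
        try:
            results[rank] = group.all_sum(rank, float(rank + 1))
        except TimeoutError as err:
            results[rank] = str(err)

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


print(run({0, 1}))  # both ranks get 3.0
print(run({0}))     # rank 0 times out waiting for rank 1
```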
sck-at-ucy commented 4 months ago

This is still not working properly for me. If I replace with

    loss = mx.distributed.all_sum(epoch_loss) / num_batches
    if rank == 0:
        print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

then Rank 1 proceeds through all the iterations while Rank 0 gets stuck at the print, which it never completes.

Distributed available: True
Hostname: Asilomar: 1
Number of processes: 2
Number of Samples Per Process: 500
Distributed available: True
Hostname: PaloAlto: 0
Number of processes: 2
Number of Samples Per Process: 500
Rank: 1 about to print
Rank: 0 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 about to print
Rank: 1 finished

If I replace with,

    loss = mx.distributed.all_sum(epoch_loss) / num_batches
    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

then both stay stuck:

Distributed available: True
Hostname: Asilomar: 1
Number of processes: 2
Number of Samples Per Process: 500
Distributed available: True
Hostname: PaloAlto: 0
Number of processes: 2
Number of Samples Per Process: 500
Rank: 1 about to print
Rank: 0 about to print
awni commented 4 months ago

Ah interesting, that's a bit of a gotcha; sorry, I think I told you the wrong thing. Try the following (notice I moved the item() call):

    loss = mx.distributed.all_sum(epoch_loss) / num_batches
    loss = loss.item()
    if rank == 0:
        print(f"Epoch {epoch + 1}, Loss: {loss}")
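The gotcha comes from MLX's lazy evaluation: calling `mx.distributed.all_sum` only records the operation, and the communication actually happens when the result is evaluated (for example by `.item()`, `mx.eval`, or printing). If `.item()` sits inside `if rank == 0`, rank 1 never evaluates, so it never really joins the collective. The pure-Python toy below uses a hypothetical `LazyValue` class (not MLX internals) to record which ranks actually run the simulated collective under each ordering.

```python
class LazyValue:
    """Hypothetical stand-in for a lazy array: work runs only when .item() is called."""

    def __init__(self, compute):
        self._compute = compute
        self._evaluated = False
        self._value = None

    def item(self):
        if not self._evaluated:
            self._value = self._compute()
            self._evaluated = True
        return self._value


def end_of_epoch(rank, joined, evaluate_before_branch):
    """Mimic the two orderings; `joined` records ranks that ran the collective."""

    def compute():
        joined.append(rank)  # stands in for this rank taking part in all_sum
        return 1.5

    loss = LazyValue(compute)  # like `mx.distributed.all_sum(epoch_loss) / num_batches`
    if evaluate_before_branch:
        value = loss.item()  # every rank evaluates, so every rank joins
        return value if rank == 0 else None
    if rank == 0:
        return loss.item()  # only rank 0 evaluates; rank 1 never joins
    return None


joined = []
for r in (0, 1):
    end_of_epoch(r, joined, evaluate_before_branch=False)
print("item() inside `if rank == 0`:", joined)  # [0]

joined = []
for r in (0, 1):
    end_of_epoch(r, joined, evaluate_before_branch=True)
print("item() before the branch:", joined)  # [0, 1]
```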
awni commented 4 months ago

Although you say this hangs as well? That's unexpected.

    loss = mx.distributed.all_sum(epoch_loss) / num_batches
    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
sck-at-ucy commented 4 months ago

Yes. Actually, replacing the print statement with mx.eval(loss) produces the same effect.

    print(f"Rank: {rank} about to print")
    # Trying to print causes the code to hang
    loss = mx.distributed.all_sum(epoch_loss) / num_batches
    mx.eval(loss)
awni commented 4 months ago

Hmm, I'm not getting a hang there if I use the code like this:

    loss = mx.distributed.all_sum(epoch_loss) / num_batches
    loss = loss.item()
    print(f"Epoch {epoch + 1}, Loss: {loss}")

Where are you running this? Is it on a large machine? Is it across the network or just locally? Do you get a hang if you run it locally? E.g. just running:

mpirun -n 2 python MLX_mpi_example3.py
sck-at-ucy commented 4 months ago

I am running on two Mac Studios connected via Thunderbolt. I get the same result if I run 2 processes on just one of the machines, and also with both Homebrew MPI and Anaconda MPI.

awni commented 4 months ago

Just to be sure we are running the same thing, could you send exactly what you are running (with the print fix) and the command you used to run it locally?

awni commented 4 months ago

Also what version of MLX are you using?

sck-at-ucy commented 4 months ago
  1. The code
for epoch in range(20):  # Number of epochs
    epoch_loss = mx.array([0.0])
    for i in range(num_batches):
        x_batch = X[i*batch_size:(i+1)*batch_size]
        y_batch = Y[i*batch_size:(i+1)*batch_size]
        #print(f"Batch {i}")

        # Accumulate gradients over the mini-batch
        batch_loss = mx.array([0.0])
        batch_grads = None
        for j in range(batch_size):
            #print(f"Sample {j}")
            x = x_batch[j]
            y = y_batch[j]
            loss, grads = step(model, x, y)
            batch_loss += loss
            if batch_grads is None:
                batch_grads = grads
            else:
                batch_grads = tree_map(lambda g1, g2: g1 + g2, batch_grads, grads)
        # Average the gradients over the batch
        batch_grads = tree_map(lambda g: g / batch_size, batch_grads)

        # All-reduce to average gradients across all processes
        batch_grads = all_reduce_grads(batch_grads, num_processes)

        # Update the model with the averaged gradients
        optimizer.update(model, batch_grads)

        epoch_loss += batch_loss

    print(f"Rank: {rank} about to print")
    # Trying to print causes the code to hang
    loss = mx.distributed.all_sum(epoch_loss) / num_batches
    loss = loss.item()
    print(f"Epoch {epoch + 1}, Loss: {loss}")

print(f"Rank: {rank} finished")
  2. MLX version:
pip show mlx
Name: mlx
Version: 0.15.0.dev20240614+2d6cd47
  3. The command line: here I stand corrected. If I run locally on a single machine, it runs with the command line given below:
/opt/homebrew/bin/mpirun  -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ --host PaloAlto:2 /Users/m2/anaconda3/envs/pythonProject_StreamLit/bin/python /Users/m2/PycharmProjects/pythonProject_StreamLit/MLX_mpi_example3.py
Distributed available: True
Hostname: PaloAlto: 0
Number of processes: 2
Number of Samples Per Process: 500
Distributed available: True
Hostname: PaloAlto: 1
Number of processes: 2
Number of Samples Per Process: 500
Rank: 0 about to print
Rank: 1 about to print
Epoch 1, Loss: 3.787351131439209
Epoch 1, Loss: 3.787351131439209
Rank: 1 about to print
Rank: 0 about to print
Epoch 2, Loss: 16.894447326660156
Epoch 2, Loss: 16.894447326660156
Rank: 1 about to print
Rank: 0 about to print
Epoch 3, Loss: 3.4061923027038574
Epoch 3, Loss: 3.4061923027038574
Rank: 0 about to print
Rank: 1 about to print
Epoch 4, Loss: 2.3777270317077637
Epoch 4, Loss: 2.3777270317077637
Rank: 1 about to print
Rank: 0 about to print
Epoch 5, Loss: 3.5566701889038086
Epoch 5, Loss: 3.5566701889038086
Rank: 1 about to print
Rank: 0 about to print
awni commented 4 months ago

Oh interesting. So it only hangs if you run over the network? It might be good to put an eval after you reduce the grads so we can see if it is actually hanging at the loss reduction or somewhere else:

        # All-reduce to average gradients across all processes
        batch_grads = all_reduce_grads(batch_grads, num_processes)
        mx.eval(batch_grads) # <--- add that
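A complementary, MLX-agnostic way to localize a hang is Python's standard `faulthandler` module, which can dump the Python stack of every thread if a call is still running after a timeout. The helper name `run_stage` and the 120 s default are assumptions for illustration. Only Python frames are shown, so a hang inside a native collective would appear as the Python line that triggered the evaluation.

```python
import faulthandler
import sys


def run_stage(name, fn, timeout_s=120.0, log=print):
    """Run fn(); if it is still running after timeout_s seconds, dump the
    Python stack of every thread to stderr so the stuck line is visible.

    The process is not killed (exit=False); the dump only reports where it is.
    """
    log(f"entering {name}")
    faulthandler.dump_traceback_later(timeout_s, exit=False, file=sys.__stderr__)
    try:
        result = fn()
        log(f"finished {name}")
        return result
    finally:
        # Always disarm the watchdog, whether fn() returned or raised.
        faulthandler.cancel_dump_traceback_later()
```

In the training loop this could wrap the evaluation that forces the all-reduce, e.g. `run_stage(f"rank {rank} grads", lambda: mx.eval(batch_grads))`; if that rank stalls there, the periodic log stops at "entering" and the dump points at the `mx.eval` line.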
sck-at-ucy commented 4 months ago

Indeed, it hangs here too if I add it and run over the network (but it still does not hang on a single machine with 2 processes).

        # All-reduce to average gradients across all processes
        batch_grads = all_reduce_grads(batch_grads, num_processes)
        mx.eval(batch_grads) # <--- add that
awni commented 4 months ago

Have you tried something really simple just to debug the connection? Something like the following:

import mlx.core as mx

world = mx.distributed.init()
x = mx.distributed.all_sum(mx.ones(10))
print(world.rank(), x)
sck-at-ucy commented 4 months ago

And it also runs over the network with 2 processes, as long as both are on the remote machine.

sck-at-ucy commented 4 months ago

Yes, I did try the simple case, but let me repeat it now just to make sure nothing has changed since.

sck-at-ucy commented 4 months ago
import mlx.core as mx
import socket

hostname = socket.gethostname()

world = mx.distributed.init()
x = mx.distributed.all_sum(mx.ones(10))
print(f"Distributed available: {mx.distributed.is_available()}")
print(f"Hostname: {hostname}: {world.rank()}, {x}")
/opt/homebrew/bin/mpirun  -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ --host PaloAlto:1,Asilomar:1 /Users/m2/anaconda3/envs/pythonProject_StreamLit/bin/python /Users/m2/PycharmProjects/pythonProject_StreamLit/MX_mpi.py  
Distributed available: True
Distributed available: True
Hostname: PaloAlto: 0, array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
Hostname: Asilomar: 1, array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
awni commented 4 months ago

Great!

awni commented 4 months ago

So it works fine for me, even over a network; it is just quite slow. Could you try the same script with smaller sizes (e.g. decrease the input and output dimensions by a factor of 100 or 1000) and see if it works?

sck-at-ucy commented 4 months ago

OK, I can confirm that with these sizes it works, but slowly:

total_num_samples = 1000
in_dims = 2000
out_dims = 2000
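Because every all-reduce moves a full copy of each gradient, message size grows with the product of the layer dimensions, which is why shrinking them helps so much over a slow link. A minimal sketch, assuming a single hypothetical `nn.Linear(in_dims, out_dims)` layer with float32 weights and bias; the actual model in MLX_mpi_example3.py may have more layers.

```python
def linear_grad_bytes(in_dims, out_dims, bytes_per_elem=4, bias=True):
    """Bytes in the gradient of one Linear(in_dims, out_dims) layer (float32 by default)."""
    elems = in_dims * out_dims + (out_dims if bias else 0)
    return elems * bytes_per_elem


for dims in (2000, 20000):
    mb = linear_grad_bytes(dims, dims) / 1e6
    print(f"{dims} x {dims}: {mb:,.3f} MB per gradient all-reduce")
```

With 2000 x 2000 this gives about 16 MB per reduction; making both dimensions 10x larger multiplies the message by roughly 100x (about 1.6 GB).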
sck-at-ucy commented 4 months ago

Something interesting to add. The Mac Studios are connected via Thunderbolt and are also on the same LAN via Wi-Fi. The speed increases in the following order:

  1. Specify hostnames with both Wi-Fi and Thunderbolt on = slowest
  2. Specify hostnames but turn Wi-Fi off = 3x faster than option 1.
  3. Specify Thunderbolt Bridge IPs instead of hostnames = at least 10x faster than option 2 (probably more; I still need to benchmark).

Not sure why. Perhaps it avoids some name resolution that otherwise has to be done.

In other words this is pretty fast now:

/opt/homebrew/bin/mpirun -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ --host 10.0.0.2:1,10.0.0.1:1 /Users/m2/anaconda3/envs/pythonProject_StreamLit/bin/python /Users/m2/PycharmProjects/pythonProject_StreamLit/MLX_mpi_example3.py
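One way to probe the name-resolution hypothesis is to check which address each hostname actually resolves to: if a hostname maps to the Wi-Fi address rather than the Thunderbolt Bridge, traffic addressed by that name may take the slower path. A minimal standard-library sketch; the helper names are hypothetical, and the subnet 10.0.0.0/24 is an assumption inferred from the bridge IPs 10.0.0.1 and 10.0.0.2 used above.

```python
import ipaddress
import socket


def resolve_ipv4(host):
    """Return the IPv4 addresses the local resolver gives for `host`."""
    infos = socket.getaddrinfo(host, None, family=socket.AF_INET)
    return sorted({info[4][0] for info in infos})


def on_subnet(ip, cidr):
    """True if `ip` falls inside the network `cidr` (e.g. an assumed bridge range)."""
    return ipaddress.ip_address(ip) in ipaddress.ip_network(cidr)
```

For example, running `resolve_ipv4("Asilomar")` on PaloAlto lists the address(es) the resolver returns, and `on_subnet(addr, "10.0.0.0/24")` tells whether each one is on the assumed bridge network.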

sck-at-ucy commented 4 months ago

I should add that with Option 3 (specifying the Thunderbolt Bridge IPs) it does not matter whether Wi-Fi is on or off.

sck-at-ucy commented 4 months ago

Some rough times (measured with time on the command line, for a quick and dirty comparison):

Option 1: 274.46s (machines connected by both Thunderbolt and Wi-Fi, hostnames specified on cmd line)
Option 2: 40.36s (machines connected by Thunderbolt only, hostnames specified on cmd line)
Option 3: 11.29s (Thunderbolt Bridge IPs specified instead of hostnames on cmd line)

So, Option 3 is roughly 25x faster than Option 1.
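The pairwise ratios follow directly from the three wall-clock times above. This sketch only reuses those numbers; each comes from a single `time` run, so the ratios are rough.

```python
times_s = {
    "Option 1": 274.46,  # hostnames, Thunderbolt + Wi-Fi
    "Option 2": 40.36,   # hostnames, Thunderbolt only
    "Option 3": 11.29,   # Thunderbolt Bridge IPs
}


def speedup(slow, fast):
    """How many times faster `fast` ran than `slow`."""
    return times_s[slow] / times_s[fast]


for slow, fast in [("Option 1", "Option 2"), ("Option 2", "Option 3"), ("Option 1", "Option 3")]:
    print(f"{fast} vs {slow}: {speedup(slow, fast):.1f}x")
```

By these figures Option 2 is about 6.8x faster than Option 1, Option 3 about 3.6x faster than Option 2, and about 24x faster than Option 1 overall.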

awni commented 4 months ago

CC @angeloskath @jagrit06

Those are very interesting results, thanks for sharing. I think we'll need to figure out what's going on there and document the recommended setup a bit.

sck-at-ucy commented 4 months ago

Let me know if you would like me to run any more tests on my hardware setup; I would be happy to help. Over the next week I will try to make this a 4-Studio cluster connected via Thunderbolt to see how it scales.

jagrit06 commented 4 months ago

@sck-at-ucy Good note on the difference between options 1 and 2. OpenMPI is happy to find any way to route messages between the machines, so when calling mpirun it is sometimes useful to pass the --mca btl_tcp_if_include "bridge0,en0,<interface from ifconfig or IP in CIDR notation>" option. There is also a corresponding --mca btl_tcp_if_exclude option, where you can list the IP of the Wi-Fi interface.

If that sounds useful, maybe I could make a note and add it to the documentation
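The interface selection described above can be scripted. This is a hypothetical Python helper (`build_mpirun_cmd` is not part of MLX or OpenMPI) that assembles an argument list like the mpirun commands used earlier in the thread, assuming `mpirun` is on PATH; interface names such as bridge0 come from the comment above and will differ per machine. OpenMPI documents btl_tcp_if_include and btl_tcp_if_exclude as mutually exclusive, so the sketch rejects passing both.

```python
def build_mpirun_cmd(hosts, python, script, include=None, exclude=None,
                     env=("DYLD_LIBRARY_PATH=/opt/homebrew/lib/",)):
    """Return an mpirun argument list; `hosts` is a list of (host, slots) pairs."""
    if include and exclude:
        raise ValueError("btl_tcp_if_include and btl_tcp_if_exclude are mutually exclusive")
    cmd = ["mpirun"]
    for var in env:
        cmd += ["-x", var]  # export an environment variable to every rank
    cmd += ["--host", ",".join(f"{h}:{n}" for h, n in hosts)]
    if include:
        cmd += ["--mca", "btl_tcp_if_include", ",".join(include)]
    if exclude:
        cmd += ["--mca", "btl_tcp_if_exclude", ",".join(exclude)]
    cmd += [python, script]
    return cmd


cmd = build_mpirun_cmd([("10.0.0.2", 1), ("10.0.0.1", 1)], "python",
                       "MLX_mpi_example3.py", include=["bridge0"])
print(" ".join(cmd))
```

The resulting list can be passed to `subprocess.run(cmd, check=True)`. When using an exclude list instead, the OpenMPI FAQ notes that it replaces the default one (which excludes loopback), so the loopback interface (lo0 on macOS) should normally be listed as well.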