ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License
16.23k stars 925 forks

[BUG] all_reduce_grads() fails with a Transformer model for number of nodes > 1 #1226

Closed sck-at-ucy closed 2 weeks ago

sck-at-ucy commented 2 months ago

Describe the bug: While all_reduce_grads, defined as per the documentation example

def all_reduce_grads(grads):
    N = mx.distributed.init().size()
    if N == 1:
        return tree_map(
            lambda x: mx.distributed.all_sum(x), grads )
    else:
        return tree_map(
            lambda x: mx.distributed.all_sum(x), grads ) # !!!!! TESTING Normally should  / N

worked for a simple model, when I try to use it with a fairly large Transformer model it only works if I am running on a single node (i.e. N=1). If I try to use it with N>1, the code goes into a zombie state: the GPU stops being utilized effectively and the code never recovers. To debug, I tried to print the reduced grads returned by all_reduce_grads(). With a single node, I get what I expected and it looks fine; see attached file: GradsforSingleNode.txt.
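
For comparison, the averaging form that the TESTING comment above refers to can be sketched as follows. This is a minimal sketch, not MLX's implementation: `tree_map_simple` is a hypothetical stand-in for `mlx.utils.tree_map`, and the `all_sum` argument is a placeholder for `mx.distributed.all_sum`, so that the logic can be run without a cluster. With a group size of 1 the gradients are returned untouched and no communication is issued.

```python
def tree_map_simple(fn, tree):
    """Apply fn to every leaf of a nested dict/list tree.

    Hypothetical stand-in for mlx.utils.tree_map, for illustration only.
    """
    if isinstance(tree, dict):
        return {k: tree_map_simple(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map_simple(fn, v) for v in tree)
    return fn(tree)


def all_reduce_grads(grads, all_sum, world_size):
    """Average every gradient leaf across world_size processes."""
    if world_size == 1:
        # Single process: nothing to reduce, skip communication entirely.
        return grads
    return tree_map_simple(lambda g: all_sum(g) / world_size, grads)


# Simulated 2-process run in which both ranks hold identical gradients,
# so the all_sum of each leaf is simply 2 * leaf (an assumption of this demo).
fake_all_sum = lambda g: 2 * g
grads = {"layer": {"weight": 4.0, "bias": 2.0}, "mask": 1.0}
averaged = all_reduce_grads(grads, fake_all_sum, world_size=2)
```

In a real run one would presumably pass `mx.distributed.all_sum` and the size of the group returned by `mx.distributed.init()`.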

With N>2, the print operation itself causes the code to crash with

Host in loop train & validate:  PaloAlto
[Asilomar][[59595,1],1][btl_tcp_frag.c:228:mca_btl_tcp_frag_recv] mca_btl_tcp_frag_recv: readv error (0x63fe87700, 97569536)
    Bad address(1)

[Asilomar:00000] *** An error occurred in Socket closed
[Asilomar:00000] *** reported by process [3905617921,1]
[Asilomar:00000] *** on a NULL communicator
[Asilomar:00000] *** Unknown error
[Asilomar:00000] *** MPI_ERRORS_ARE_FATAL (processes in this communicator will now abort,
[Asilomar:00000] ***    and MPI will try to terminate your MPI job as well)
--------------------------------------------------------------------------
prterun has exited due to process rank 1 with PID 10708 on node Asilomar calling
"abort". This may have caused other processes in the application to be
terminated by signals sent by prterun (as reported here).
--------------------------------------------------------------------------

To Reproduce

Include code snippet


            if num_train_batches == 1049:
                total_batch_grads_reduced = all_reduce_grads(batch_grads)
                if rank == 0:
                    print(total_batch_grads_reduced)

Expected behavior: I expected to get back the averaged grads; the next step would have been to call optimizer.update() with the all-reduced grads.

Desktop (please complete the following information):

Additional context: After successfully running a simpler model on up to 4 nodes, I tried to do the same with the Transformer model, but then I ran into this problem, which prevents me from reducing the grads across nodes. The issue happens with any number of nodes > 2.

awni commented 2 months ago

Could you please share the code or some simplified version to reproduce this?

sck-at-ucy commented 2 months ago

Did that via direct email (the code is too complex to provide context here).

sck-at-ucy commented 1 month ago

One more piece of information. I wanted to make sure that the dictionaries of batch_grads that are fed to all_reduce_grads() are valid even in the case when there are multiple nodes. I put together this little check

def check_for_bad_values(grads):
    def recursive_check(value):
        if isinstance(value, dict):
            for sub_key, sub_value in value.items():
                if not recursive_check(sub_value):
                    return False
        elif isinstance(value, list):
            for item in value:
                if not recursive_check(item):
                    return False
        else:
            if mx.any(mx.isnan(value)) or mx.any(mx.isinf(value)):
                return False
        return True

    for key, value in grads.items():
        print(f"Checking key: {key}, type of value: {type(value)}")
        if not recursive_check(value):
            return False, key, value
    return True, None, None

def validate_grads(grads1):
    # Check for bad values
    valid1, key1, value1 = check_for_bad_values(grads1)

    if not valid1:
        print(f"Invalid value found in grads1 at key {key1}: {value1}")
        return False

    return True

and then inside the training loop:

            if num_train_batches == 1049:
                if validate_grads(batch_grads):
                    print(f'{hostname}: Gradients are valid')
                else:
                    print(f'{hostname}: Gradients are INVALID')

The result seems to be clean for both nodes:

Checking key: projection_spatial_enc, type of value: <class 'dict'>
Checking key: projection_spatial_enc, type of value: <class 'dict'>
Checking key: positional_encoding_y, type of value: <class 'dict'>
Checking key: positional_encoding_x, type of value: <class 'dict'>
Checking key: positional_encoding_t, type of value: <class 'dict'>
Checking key: transformer_encoder, type of value: <class 'dict'>
Checking key: positional_encoding_y, type of value: <class 'dict'>
Checking key: positional_encoding_x, type of value: <class 'dict'>
Checking key: positional_encoding_t, type of value: <class 'dict'>
Checking key: transformer_encoder, type of value: <class 'dict'>
Checking key: output_projection, type of value: <class 'dict'>
Checking key: output_projection, type of value: <class 'dict'>
Checking key: diffusivity_embedding, type of value: <class 'dict'>
Checking key: diffusivity_embedding, type of value: <class 'dict'>
Checking key: layer_normalizer, type of value: <class 'dict'>
Checking key: layer_normalizer, type of value: <class 'dict'>
Checking key: mask, type of value: <class 'mlx.core.array'>
Checking key: mask, type of value: <class 'mlx.core.array'>
Asilomar: Gradients are valid
PaloAlto: Gradients are valid
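
The check above reports only the top-level key of a bad entry. A variant that flattens the tree first can report the full path of each offending leaf. The sketch below is an illustration under stated assumptions: `flatten` and `find_bad_leaves` are hypothetical helpers (MLX ships `mlx.utils.tree_flatten` for the same purpose), plain floats stand in for arrays, and the `is_bad` predicate would need to wrap `mx.isnan` / `mx.isinf` for real gradients.

```python
import math


def flatten(tree, prefix=""):
    """Yield (dotted_path, leaf) pairs from nested dicts and lists."""
    if isinstance(tree, dict):
        for key, value in tree.items():
            yield from flatten(value, f"{prefix}.{key}" if prefix else str(key))
    elif isinstance(tree, (list, tuple)):
        for index, value in enumerate(tree):
            yield from flatten(value, f"{prefix}.{index}" if prefix else str(index))
    else:
        yield prefix, tree


def find_bad_leaves(grads, is_bad):
    """Return the dotted paths of all leaves for which is_bad(leaf) is true."""
    return [path for path, leaf in flatten(grads) if is_bad(leaf)]


# Plain floats stand in for gradient arrays in this demo.
example = {"encoder": {"layers": [{"w": 0.5}, {"w": float("nan")}]}, "mask": 1.0}
bad = find_bad_leaves(example, lambda x: not math.isfinite(x))
```

An empty list means every leaf passed the check, which corresponds to the "Gradients are valid" message above.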

awni commented 1 month ago

I don't know if it's related but the code @sck-at-ucy shared does not run on a single machine (M2 Ultra) for the multiprocess case. I believe it's a GPU timeout issue. I haven't tried it on multiple machines yet though. See related issue https://github.com/ml-explore/mlx/issues/1231

sck-at-ucy commented 1 month ago

I have made some progress with the code. all_reduce_grads() still fails, but this pushed me towards a different strategy that works. The idea is to aggregate the loss from all nodes (the dataset is divided equally among them, and each node feeds the transformer batches of the same size). Using mx.distributed.all_sum(), each node aggregates the entire loss and computes the gradients based on that. I assume the gradients should be the same to within round-off, since all nodes compute them from the same aggregated loss. Thus I avoid averaging gradients entirely. If I keep the same batch size as on a single node, I think I am essentially training the model with a larger effective batch size (by a factor equal to the number of nodes). This means either re-adjusting the learning rate schedule, or reducing the batch size and balancing the per-iteration speed-up against learning quality. I have been able to train the model, but I am still figuring out the best combination. So mx.distributed.all_sum() can handle the loss aggregation. This of course does not address the issue with all_reduce_grads(), but the problem has led me to this alternative way of doing distributed training, which seems to work fairly well, at least for my physics problem.
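
The effective batch size remark can be illustrated with a toy example. The sketch below is not the Transformer from this issue: it assumes a hypothetical one-parameter model `w * x` with a mean squared error loss, written in plain Python. It shows that averaging the gradients of equally sized per-node shards gives the gradient of the combined batch, which is why N nodes with local batch size B behave like one node with batch size N * B.

```python
def grad_mse(w, xs, ys):
    """Gradient with respect to w of mean((w * x - y) ** 2) over the batch."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

# Two simulated nodes, each holding half of the data.
shards = [(xs[:2], ys[:2]), (xs[2:], ys[2:])]
local_grads = [grad_mse(w, sx, sy) for sx, sy in shards]

# Averaging the local gradients reproduces the full-batch gradient.
averaged_grad = sum(local_grads) / len(shards)
full_batch_grad = grad_mse(w, xs, ys)
```

Note that the two local gradients differ from each other, since each shard sees different data; they only agree with the full-batch gradient after averaging. Whether differentiating through an all_sum of the loss produces the combined gradient on every node depends on how the collective is differentiated, which this thread does not establish.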

angeloskath commented 1 month ago

@sck-at-ucy I think this may be a network instability or something along those lines. I can train over Ethernet for hundreds of thousands of iterations, so I can't reproduce this. Let me know if there is more information on the original bug report; otherwise I am inclined to close this.

sck-at-ucy commented 1 month ago

@angeloskath My attention has been on the issues I have been having with continuing training after reloading states from file, so I have been running on a single node trying to understand that. The status with distributed computing as of a couple of weeks ago was the following: if I significantly decreased the size of the Transformer model, all_reduce_grads() would work (in the sense that the training loop continued, as opposed to getting stuck as with the larger model). GPU utilization was impacted (as expected), but importantly the training performance (i.e. the loss decrease trend) was degraded compared to running the same size problem on a single node. So technically it worked for smaller models, but putting everything together, distributed computing did not seem to offer a clear benefit for my Transformer model. I might invest in an Ethernet switch and try over Ethernet.

angeloskath commented 2 weeks ago

I think this failure had to do with IP over Thunderbolt across many machines, so there is nothing we can do about it from MLX. Over Ethernet it works quite reliably.

If this is not the case then feel free to reopen.
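
For anyone retrying over Ethernet, one way to keep Open MPI's TCP transport on a specific interface is its `btl_tcp_if_include` MCA parameter. The sketch below only builds the launch command and does not run it. The interface name `en0` and the script name `train.py` are assumptions for illustration; the host names are the two that appear in the log above, and they may need to be adjusted to whatever names resolve on the network in question.

```python
import shlex


def build_mpirun_command(hosts, script, interface):
    """Build an Open MPI launch command restricted to one network interface."""
    return [
        "mpirun",
        "-np", str(len(hosts)),                    # one process per host
        "--host", ",".join(hosts),
        "--mca", "btl_tcp_if_include", interface,  # TCP traffic only on this interface
        "python", script,
    ]


# "en0" and "train.py" are placeholders, not values taken from this issue.
command = build_mpirun_command(["Asilomar", "PaloAlto"], "train.py", "en0")
print(shlex.join(command))
```

The list form can be passed to `subprocess.run` directly; `shlex.join` is used here only to show the equivalent shell command.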