openai / CLIP

CLIP (Contrastive Language-Image Pretraining): predict the most relevant text snippet given an image
MIT License

Batch Sharding Details #132

Closed · Zasder3 closed this 3 years ago

Zasder3 commented 3 years ago

Within the paper you write:

The calculation of embedding similarities was also sharded with individual GPUs computing only the subset of the pairwise similarities necessary for their local batch of embeddings.

I've seen many different interpretations of this line that vary between repos. How did the sharding function within the training code? Did each device receive the full 32768 x 32768 matrix containing the inner products? Did each device only compute a local N x N matrix s.t. N << 32768? Did each device receive only two 32768 x embed_dim matrices?

jongwook commented 3 years ago

It was briefly explained in this comment. Each device received the entire embeddings of size 32768 x embed_dim to be matmul'ed with the local embedding vectors, like:

# image_features.shape = [local_batch_size, embed_dim]
# text_features.shape = [local_batch_size, embed_dim]

all_image_features = all_gather(image_features)  # shape = [global_batch_size, embed_dim]
all_text_features = all_gather(text_features)    # shape = [global_batch_size, embed_dim]

logits_per_image = logit_scale * image_features @ all_text_features.t()  # shape = [local_batch_size, global_batch_size]
logits_per_text = logit_scale * text_features @ all_image_features.t()   # shape = [local_batch_size, global_batch_size]

where all_gather is a function gathering and concatenating the tensors across GPUs. The logits are then used as inputs to the cross entropy loss for 32768-way classification, resulting in a loss value corresponding to local_batch_size inputs in each GPU. DistributedDataParallel takes care of averaging these across GPUs, which becomes equivalent to calculating the loss over 32768 x 32768 similarities.
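
The training code itself isn't included in this repo, so the following is only a minimal sketch of how those per-GPU logits could feed the loss, assuming PyTorch, torch.distributed, and that all_gather concatenates the per-rank tensors in rank order; the target for each local sample is its own position in the gathered global batch.

import torch
import torch.nn.functional as F
import torch.distributed as dist

# logits_per_image, logits_per_text: [local_batch_size, global_batch_size]
local_batch_size = logits_per_image.shape[0]

# Local sample i sits at global index rank * local_batch_size + i
# in the all-gathered features, so that index is its positive target.
labels = dist.get_rank() * local_batch_size + torch.arange(
    local_batch_size, device=logits_per_image.device)

# Symmetric cross entropy over the 32768-way classification;
# DistributedDataParallel then averages the gradients across GPUs.
loss = (F.cross_entropy(logits_per_image, labels) +
        F.cross_entropy(logits_per_text, labels)) / 2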

mutasem-mattar commented 3 years ago

I have a question about the all_gather used for training. I know there are two versions of all_gather that can be implemented: one that records the gradients from all GPUs, and one that does a detach. I am wondering which one was used for training.

jongwook commented 3 years ago

@mutasem-mattar the all_gather function needs a full, undetached backward operation (calling reduce_scatter or similar operations) for training to work most efficiently.
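
That version isn't shown in this repo, but for reference, a gradient-preserving all_gather is commonly written as a custom autograd function whose backward uses reduce_scatter, roughly like the sketch below (an assumption on my part, not the code used to train CLIP):

import torch
import torch.distributed as dist

class AllGatherWithGrad(torch.autograd.Function):
    # Gathers tensors from all ranks while keeping them in the autograd graph.

    @staticmethod
    def forward(ctx, tensor):
        world_size = dist.get_world_size()
        gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered, tensor)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Split the gradient of the gathered tensor back into per-rank chunks
        # and reduce-scatter them, so each rank receives the gradient for its
        # own local embeddings summed over all ranks.
        world_size = dist.get_world_size()
        grad_chunks = [g.contiguous() for g in grad_output.chunk(world_size, dim=0)]
        grad_local = torch.empty_like(grad_chunks[dist.get_rank()])
        dist.reduce_scatter(grad_local, grad_chunks)
        return grad_local

def all_gather(tensor):
    return AllGatherWithGrad.apply(tensor)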

mutasem-mattar commented 3 years ago

I have one follow-up question. How did you manage to all_gather the image embeddings and text embeddings and do the dot product without getting OOM problems, especially since your batch size is around 32768?

jongwook commented 3 years ago

The memory that the all-gathered features take is:

32768 (global batch size) x 1024 (largest CLIP embedding dim) x 2 bytes (fp16) ≈ 64 MB per modality

and the logits after the dot product would take:

32768 (global batch size) x 32768 / n_gpu (local batch size) x 2 bytes (fp16) ≈ 2048 MB / n_gpu per logits matrix, i.e. ≈ 4096 MB / n_gpu for both logits_per_image and logits_per_text,

so it's manageable.
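
As a quick back-of-the-envelope check of those numbers (my own arithmetic; the n_gpu value below is only an example):

global_batch = 32768
embed_dim = 1024
fp16_bytes = 2

# All-gathered features, per modality (image or text)
print(global_batch * embed_dim * fp16_bytes / 2**20)         # 64.0 MB

# Logits on one GPU, e.g. with n_gpu = 256 workers
n_gpu = 256
local_batch = global_batch // n_gpu                          # 128
one_logits = global_batch * local_batch * fp16_bytes / 2**20
print(one_logits)        # 8.0 MB  (= 2048 / n_gpu)
print(2 * one_logits)    # 16.0 MB (= 4096 / n_gpu, counting both logits matrices)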

zsnoob commented 11 months ago

Quoting @jongwook's reply above ("Each device received the entire embeddings of size 32768 x embed_dim to be matmul'ed with the local embedding vectors ... DistributedDataParallel takes care of averaging these across GPUs, which becomes equivalent to calculating the loss over 32768 x 32768 similarities"):

I'm confused about the training implementation. When we do contrastive learning, the diagonal elements are the positive samples. But if we calculate the loss locally with shape (local_batch_size, global_batch_size), how does each GPU know where its positive samples are?

I have an unconfirmed idea that we could use the local_rank to find the distributed local batch "index" and thus the positions of the positive samples, but I don't know whether the local rank is exactly the "local batch index".
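
For what it's worth, that is essentially how the sketch after @jongwook's first reply handles it, assuming dist.all_gather concatenates the per-rank features in rank order: rank r's samples occupy rows r * local_batch_size through (r + 1) * local_batch_size - 1 of the gathered tensor, so the offset comes from the global rank (dist.get_rank()), not the per-node local rank.

import torch
import torch.distributed as dist

# Positive pair for local sample i sits at global column
# rank * local_batch_size + i of the [local_batch_size, global_batch_size] logits.
labels = dist.get_rank() * local_batch_size + torch.arange(local_batch_size)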