It was briefly explained in this comment. Each device received the entire embeddings of size 32768 x embed_dim
to be matmul'ed with the local embedding vectors, like:
# image_features.shape = [local_batch_size, embed_dim]
# text_features.shape = [local_batch_size, embed_dim]
all_image_features = all_gather(image_features) # shape = [global_batch_size, embed_dim]
all_text_features = all_gather(text_features) # shape = [global_batch_size, embed_dim]
logits_per_image = logit_scale * image_features @ all_text_features.t() # shape = [local_batch_size, global_batch_size]
logits_per_text = logit_scale * text_features @ all_image_features.t() # shape = [local_batch_size, global_batch_size]
where all_gather is a function that gathers and concatenates the tensors across GPUs. The logits are then used as inputs to the cross entropy loss for 32768-way classification, resulting in a loss value corresponding to local_batch_size inputs on each GPU. DistributedDataParallel takes care of averaging these across GPUs, which becomes equivalent to calculating the loss over the full 32768 x 32768 similarities.
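For concreteness, here is a minimal sketch of how that per-GPU loss could be computed from the logits above; the rank-based label offset and variable names are my assumptions about the implementation, not code from this repo:

import torch
import torch.nn.functional as F
import torch.distributed as dist

# Sample i on this GPU corresponds to column rank * local_batch_size + i of the
# all-gathered features, so the positive labels are offset by the process rank.
rank = dist.get_rank()
labels = rank * local_batch_size + torch.arange(local_batch_size, device=logits_per_image.device)

# Symmetric cross entropy over the [local_batch_size, global_batch_size] logits;
# DistributedDataParallel then averages the gradients across GPUs.
loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2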
I have a question about the all_gather used for training. I know there are two versions of all_gather that can be implemented: one records the gradients from all GPUs, and the other one does a detach. I am wondering which one was used for training.
@mutasem-mattar the all_gather function needs a full, undetached backward operation (calling reduce_scatter or similar operations) for training to work most efficiently.
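To illustrate what an undetached all_gather looks like (a sketch under my own assumptions; the class name and exact structure are not from this repo), one can wrap the collective in a custom autograd function whose backward reduce-scatters the incoming gradient:

import torch
import torch.distributed as dist

class AllGatherWithGrad(torch.autograd.Function):
    # Hypothetical autograd-aware all_gather: forward gathers features from all
    # ranks, backward reduce-scatters the gradient so each rank receives the
    # summed gradient for its own local slice.
    @staticmethod
    def forward(ctx, tensor):
        world_size = dist.get_world_size()
        gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered, tensor)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size()
        # Split the [global_batch_size, embed_dim] gradient back into per-rank chunks.
        grad_chunks = [g.contiguous() for g in grad_output.chunk(world_size, dim=0)]
        grad_input = torch.zeros_like(grad_chunks[dist.get_rank()])
        # Sums each chunk across ranks and returns this rank's summed chunk.
        dist.reduce_scatter(grad_input, grad_chunks)
        return grad_input

def all_gather(tensor):
    return AllGatherWithGrad.apply(tensor)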
I have one follow-up question. How did you manage to all_gather the image embeddings and text embeddings and do the dot product without getting OOM problems, especially since your batch size is around 32768?
The memory taken by the all-gathered features is:
32768 (global batch size) x 1024 (largest CLIP embedding dim) x 2 bytes (fp16) ≈ 64 MB
and the two logit matrices after the dot products (logits_per_image and logits_per_text) would take:
2 x 32768 (global batch size) x 32768 / n_gpu (local batch size) x 2 bytes (fp16) ≈ 4096 MB / n_gpu
so it's manageable.
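A quick sanity check of those numbers (the factor of two for the paired logit matrices and the n_gpu = 8 example are my own assumptions):

GLOBAL_BS = 32768
EMBED_DIM = 1024
FP16_BYTES = 2
n_gpu = 8  # example world size

# One all-gathered feature matrix: [global_batch_size, embed_dim] in fp16
features_mb = GLOBAL_BS * EMBED_DIM * FP16_BYTES / 2**20
# Two logit matrices per GPU: [local_batch_size, global_batch_size] each, in fp16
logits_mb = 2 * (GLOBAL_BS // n_gpu) * GLOBAL_BS * FP16_BYTES / 2**20

print(features_mb)  # 64.0 MB
print(logits_mb)    # 512.0 MB, i.e. 4096 MB / n_gpu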
I'm confused about the training implementation. In contrastive learning the diagonal elements are the positive samples, but if we calculate the loss locally on logits of shape (local_batch_size, global_batch_size), how does each GPU know the positions of its positive samples?
My unconfirmed idea is that we can use the local rank to work out the distributed local batch "index" and hence the positions of the positive samples, but I don't know whether the local rank is exactly that "local batch index".
Within the paper you write:
I've seen many different interpretations of this line that vary from repo to repo. How did the sharding work within the training code? Did each device receive the full 32768 x 32768 matrix containing the inner products? Did each device only compute a local N x N matrix s.t. N << 32768? Did each device receive only two 32768 x Embed_dim matrices?