luyug / GradCache

Run Effective Large Batch Contrastive Learning Beyond GPU/TPU Memory Constraint
Apache License 2.0

Questions about training #28

Open MikeDean2367 opened 3 months ago

MikeDean2367 commented 3 months ago

Hi, this is great work!

We have three inputs, designated i1, i2, and i3, which are processed by llama-7b. For input i1, I extract the hidden states at two distinct positions and label them p11 and p12, respectively. For the remaining inputs, i2 and i3, I select a single hidden state from each, denoted n21 and n31, respectively.
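A minimal sketch of the extraction step (the values, hidden size, and positions 0 and 2 are toy assumptions, not the actual extraction locations):

```python
# Toy stand-in for llama-7b hidden states of input i1:
# seq_len x hidden_dim, here 3 tokens with dim 2.
hidden_i1 = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]

# Two distinct positions are picked for input i1;
# inputs i2 and i3 would each contribute a single position.
p11 = hidden_i1[0]  # hidden state at the first chosen position
p12 = hidden_i1[2]  # hidden state at the second chosen position
```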

In this setup, p11 paired with n21 constitutes a positive pair, whereas p11 paired with n31 forms a negative pair. Meanwhile, p12 paired with n31 constitutes a positive pair, whereas p12 paired with n21 forms a negative pair. My objective is to compute the InfoNCE loss between these pairs.
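For reference, the per-anchor InfoNCE over the two pairs described above can be sketched in plain Python. The similarity scores and the temperature are made-up illustrative values; in practice the similarities would come from the extracted hidden states (e.g. scaled dot products):

```python
import math

def info_nce(sim_pos, sim_negs, temperature=0.07):
    """InfoNCE loss for one anchor: negative log-softmax of the
    positive similarity against positive + negative similarities."""
    logits = [sim_pos / temperature] + [s / temperature for s in sim_negs]
    m = max(logits)  # subtract the max for numerical stability
    denom = sum(math.exp(l - m) for l in logits)
    return -(logits[0] - m - math.log(denom))

# Anchor p11: positive n21, negative n31 (per the setup above).
loss_p11 = info_nce(sim_pos=0.8, sim_negs=[0.1])
# Anchor p12: positive n31, negative n21.
loss_p12 = info_nce(sim_pos=0.7, sim_negs=[0.2])
total = (loss_p11 + loss_p12) / 2
```

When the positive and negative similarities are equal, the loss reduces to log(2); when the positive similarity dominates, the loss approaches zero.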

So I set get_rep_fn in the GradCache class to handle the different cases. Here is a simplified example:

def get_rep_fn(x):
    # x is the model output; x.label marks inputs that
    # contribute two hidden states (here, input i1)
    if x.label == 2:
        return [x.e1, x.e2]  # both hidden states, i.e. p11 and p12
    else:
        return [x.e1]        # single hidden state, i.e. n21 or n31

At the same time, I changed the following lines from append to extend:

https://github.com/luyug/GradCache/blob/0c33638cb27c2519ad09c476824d550589a8ec38/src/grad_cache/grad_cache.py#L187
https://github.com/luyug/GradCache/blob/0c33638cb27c2519ad09c476824d550589a8ec38/src/grad_cache/grad_cache.py#L270

I'd like to ask about the correctness of the gradient computation under these changes. Could you please confirm whether it is still accurate?
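To illustrate why the append-to-extend change matters once get_rep_fn returns a list: the collected representations are later concatenated into one tensor, so per-call lists need to be flattened. The dicts below are toy stand-ins for model outputs, not the actual grad_cache data structures:

```python
# Each string stands in for a hidden-state tensor.
outputs = [
    {"label": 2, "reps": ["p11", "p12"]},  # input i1 yields two reps
    {"label": 1, "reps": ["n21"]},         # input i2
    {"label": 1, "reps": ["n31"]},         # input i3
]

appended, extended = [], []
for out in outputs:
    appended.append(out["reps"])   # nested: one sub-list per call
    extended.extend(out["reps"])   # flat: one entry per representation

# appended -> [['p11', 'p12'], ['n21'], ['n31']]
# extended -> ['p11', 'p12', 'n21', 'n31']
```

With append, the nested sub-lists would break a downstream concatenation that expects a flat sequence of tensors, which is presumably why the change to extend was needed.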

Thanks!