Hi, great work!

We have three inputs, `i1`, `i2`, and `i3`, which are processed by llama-7b. For input `i1`, I extract two hidden states at two distinct positions and label them `p11` and `p12`. For the remaining inputs, `i2` and `i3`, I select a single hidden state each, denoted `n21` and `n31` respectively.

In this setup, `p11` paired with `n21` is a positive pair, whereas `p11` paired with `n31` is a negative pair. Likewise, `p12` paired with `n31` is a positive pair, whereas `p12` paired with `n21` is a negative pair. My objective is to compute the InfoNCE loss over these pairs.

So I set `get_rep_fn` in the `GradCache` class to handle the different cases. Here is a sample snippet:

At the same time, I changed the following code from `append` to `extend`:

https://github.com/luyug/GradCache/blob/0c33638cb27c2519ad09c476824d550589a8ec38/src/grad_cache/grad_cache.py#L187
https://github.com/luyug/GradCache/blob/0c33638cb27c2519ad09c476824d550589a8ec38/src/grad_cache/grad_cache.py#L270

Could you please confirm whether the gradient computation is still correct with these changes?

Thanks!
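
To make the intended pairing concrete, here is a minimal pure-Python sketch of the loss described above. It is an illustration only, not the code used with GradCache: the toy 2-d vectors, the cosine similarity, the `temperature` value, and the helper names `cosine` and `info_nce` are all assumptions, and it assumes that the positive for `p12` is the hidden state taken from `i3`.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def info_nce(anchor, candidates, positive_index, temperature=0.1):
    """InfoNCE for one anchor: cross-entropy over similarity logits.

    The candidate at positive_index is the positive; every other
    candidate is treated as a negative.
    """
    logits = [cosine(anchor, c) / temperature for c in candidates]
    # log-sum-exp with the max subtracted for numerical stability
    m = max(logits)
    log_denominator = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_denominator - logits[positive_index]


# Toy stand-ins for the hidden states (assumed values, not model outputs).
p11 = [1.0, 0.0]  # first hidden state taken from i1
p12 = [0.0, 1.0]  # second hidden state taken from i1
n21 = [0.9, 0.1]  # hidden state taken from i2
n31 = [0.1, 0.9]  # hidden state taken from i3

candidates = [n21, n31]

# p11: positive is n21 (index 0), negative is n31.
# p12: positive is n31 (index 1), negative is n21.
loss = (info_nce(p11, candidates, 0) + info_nce(p12, candidates, 1)) / 2
print(f"InfoNCE loss: {loss:.6f}")
```

In actual training the vectors would be llama-7b hidden states and the same computation would be done on tensors so that gradients can flow back through the model; the sketch only checks the pairing logic.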