frederick0329 / TracIn

Implementation of Estimating Training Data Influence by Tracing Gradient Descent (NeurIPS 2020)
Apache License 2.0
219 stars · 15 forks

influence calculation #2

Closed · karasepid closed this 3 years ago

karasepid commented 3 years ago

I think the equations below only hold for one checkpoint.

```python
lg_sim = np.sum(trackin_train['loss_grads'][i] * loss_grad)
a_sim = np.sum(trackin_train['activations'][i] * activation)
scores.append(lg_sim * a_sim)
```

For cases with more than one checkpoint, the equations need to be changed to

```python
lg_sim = np.sum(trackin_train['loss_grads'][i] * loss_grad, axis=0)
a_sim = np.sum(trackin_train['activations'][i] * activation, axis=0)
scores.append(np.sum(lg_sim * a_sim))
```

where lg_sim and a_sim have dimension c (the number of checkpoints).
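For concreteness, here is a minimal numpy sketch of the shapes in this proposal (toy sizes, random stand-ins for one training/test pair; not the repo's code):

```python
import numpy as np

d_lg, d_a, c = 10, 5, 3                   # toy sizes: classes, activations, checkpoints
lg_train = np.random.random((d_lg, c))    # stand-in for trackin_train['loss_grads'][i]
lg_test = np.random.random((d_lg, c))     # stand-in for loss_grad
a_train = np.random.random((d_a, c))      # stand-in for trackin_train['activations'][i]
a_test = np.random.random((d_a, c))       # stand-in for activation

lg_sim = np.sum(lg_train * lg_test, axis=0)  # per-checkpoint dot products, shape (c,)
a_sim = np.sum(a_train * a_test, axis=0)     # shape (c,)
score = np.sum(lg_sim * a_sim)               # sum over checkpoints -> scalar
```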
frederick0329 commented 3 years ago

Thanks for checking this. I stacked the vectors from the checkpoints here:

```python
@tf.function
def run(inputs):
  imageids, images, labels = inputs
  # ignore bias for simplicity
  loss_grads = []
  activations = []
  # one (penultimate, last-layer) model pair per checkpoint
  for mp, ml in zip(models_penultimate, models_last):
    h = mp(images)                                # penultimate-layer activations
    logits = ml(h)
    probs = tf.nn.softmax(logits)
    loss_grad = tf.one_hot(labels, 1000) - probs  # loss gradient w.r.t. the logits (up to sign)
    activations.append(h)
    loss_grads.append(loss_grad)

  # Using probs from last checkpoint
  probs, predicted_labels = tf.math.top_k(probs, k=1)

  # stack the per-checkpoint tensors along a trailing checkpoint axis
  return imageids, tf.stack(loss_grads, axis=-1), tf.stack(activations, axis=-1), labels, probs, predicted_labels
```
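For reference, tf.stack on axis=-1 turns the list of per-checkpoint tensors into a single tensor with a trailing checkpoint axis; a toy check of the resulting shape (assuming a batch of 2, 4 classes, and 3 checkpoints):

```python
import tensorflow as tf

# three checkpoints, each contributing a (batch=2, classes=4) loss_grad
loss_grads = [tf.random.uniform((2, 4)) for _ in range(3)]
stacked = tf.stack(loss_grads, axis=-1)
print(stacked.shape)  # (2, 4, 3): batch, classes, num_checkpoints
```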

Does this make sense?

karasepid commented 3 years ago

My question is: what is the dimension of trackin_train['loss_grads'][i] * loss_grad? My understanding is that it is most likely #checkpoints by #classes (10 for CIFAR-10). In that case, summing over both dimensions (lg_sim) and then multiplying by a_sim does not sound correct to me. It should sum only over the samples, be multiplied by the samples-summed version of a_sim (which would have dimension #checkpoints), and then be summed over the checkpoints:

```python
lg_sim = np.sum(trackin_train['loss_grads'][i] * loss_grad, axis=0)
a_sim = np.sum(trackin_train['activations'][i] * activation, axis=0)
scores.append(np.sum(lg_sim * a_sim))
```
frederick0329 commented 3 years ago

The dimension of trackin_train['loss_grads'] is 2D: num_training_data * (number of weights * num_checkpoints). The dimension of loss_grad is 1D: (number of weights * num_checkpoints).

```python
for i in range(len(trackin_train['image_ids'])):
```

loops over all the images and calculates the similarity one by one.
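A small numpy sketch of the shapes described above (toy sizes; the flattened per-example layout is an assumption based on the description):

```python
import numpy as np

num_training_data, num_weights, num_checkpoints = 100, 1000, 3  # toy sizes

# 2D: one flattened (num_weights * num_checkpoints) vector per training example
train_loss_grads = np.random.random((num_training_data, num_weights * num_checkpoints))
# 1D: the test example's vector, in the same flattened layout
loss_grad = np.random.random(num_weights * num_checkpoints)

i = 0
lg_sim = np.sum(train_loss_grads[i] * loss_grad)  # scalar similarity for image i
```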

philferriere commented 3 years ago

Hi @karasepid ,

I was a little confused by this part of the code as well, as I couldn't immediately recognize Equation (1) from the paper (the TracInCP equation in Section 3.3). Indeed, it has to do with the stacking trick @frederick0329 is using. I had to write the following code to convince myself, though ;)

```python
# From the ResNet50 model definition used by Frederick:
# trackin_train["activations"]: activations (N, 2048 floats, num_ckpts)
# trackin_train["loss_grads"]: loss gradients (N, 1000 floats, num_ckpts)
#
# Below, use (5 floats, num_ckpts of 3) to validate the math:
import numpy as np

lg1 = np.random.random((5, 3))
lg2 = np.random.random((5, 3))
lg1_times_lg2 = lg1 * lg2  # (5, 3)
np.sum(lg1_times_lg2)
# Out[25]: 3.6332565631348537
np.dot(lg1[:, 0], lg2[:, 0]) + np.dot(lg1[:, 1], lg2[:, 1]) + np.dot(lg1[:, 2], lg2[:, 2])
# Out[26]: 3.6332565631348537
```

The sum of dot products is really Equation (1), and the results are the same.

frederick0329 commented 3 years ago

Thank you @philferriere! Note that the stacking trick is important for leveraging downstream nearest-neighbor libraries, which usually take a single vector.
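To illustrate why a single flat vector matters downstream, here is a sketch using scikit-learn's NearestNeighbors as a stand-in for whatever nearest-neighbor library is actually used (an assumption, not the repo's code):

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

# toy data: N training examples, each one flat vector concatenating
# the per-checkpoint blocks (the "stacking trick")
N, d, c = 1000, 128, 3
train_vecs = np.random.random((N, d * c))
test_vec = np.random.random((1, d * c))

nn = NearestNeighbors(n_neighbors=5).fit(train_vecs)
distances, indices = nn.kneighbors(test_vec)  # indices of the 5 closest training examples
```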