frederick0329 / TracIn

Implementation of Estimating Training Data Influence by Tracing Gradient Descent (NeurIPS 2020)
Apache License 2.0

Q: Difference between the paper and the notebook #5

Closed (philferriere closed this issue 3 years ago)

philferriere commented 3 years ago

Hello again, @frederick!

I have a question about a difference between the paper and the notebook you've kindly shared with us. To make that difference easier to see, I've rewritten the notebook's find() function:

import numpy as np

def find_topk(tracin_train, eval_loss_grad=None, eval_activation=None, topk=50):
  """
  Find the topk proponents and opponents in the training set for ONE sample in the validation set
  - Use proponents to build a high-value training set much smaller than the original training set
  - Inspect opponents to surface samples that may have mislabeling issues
  Args:
    tracin_train: tracked training-set values
      tracin_train["loss_grads"]: loss gradients (N, 1000 floats, num_ckpts)
      tracin_train["activations"]: activations (N, 2048 floats, num_ckpts)
    eval_loss_grad: loss gradients for this validation sample (1000 floats, num_ckpts)
    eval_activation: activations for this validation sample (2048 floats, num_ckpts)
    topk: number of opponents and proponents to return
  Returns:
    op, pp: topk opponents, topk proponents
  Sample use:
    op, pp = find_topk(tracin_train, tracin_eval['loss_grads'][idx], tracin_eval['activations'][idx])
  """
  # In the original notebook, the authors multiply two scores (loss_grad_sim and activation_sim)
  if eval_loss_grad is None and eval_activation is None:
    raise ValueError('loss grad and activation cannot both be None.')

  # Walk through the entire list of training samples and score them
  scores = []
  if eval_loss_grad is not None and eval_activation is not None:
    for i in range(len(tracin_train['loss_grads'])):
      # Compute loss gradient similarity
      lg_sim = np.sum(tracin_train['loss_grads'][i] * eval_loss_grad)
      # Compute activation similarity
      a_sim = np.sum(tracin_train['activations'][i] * eval_activation)
      # Save final score
      scores.append(lg_sim * a_sim)  # not paper implementation
  elif eval_loss_grad is not None:
    for i in range(len(tracin_train['loss_grads'])):
      scores.append(np.sum(tracin_train['loss_grads'][i] * eval_loss_grad))  # paper implementation
  elif eval_activation is not None:
    for i in range(len(tracin_train['activations'])):
      scores.append(np.sum(tracin_train['activations'][i] * eval_activation))  # not paper implementation

  # Order the scores from smallest to largest (most negative to most positive)
  indices = np.argsort(scores)
...
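The loop above scores one training example at a time. As a sketch, the same combined score can be computed for the whole training set at once with NumPy, assuming the tracked values are stacked into arrays shaped as in the docstring (the function names `combined_scores` and `split_topk` and the checkpoint axis `C` are illustrative, not from the notebook). One detail worth flagging: the outer-product factorization holds at each checkpoint separately, so when several checkpoints are tracked this sketch multiplies the two similarities per checkpoint before summing over checkpoints, with the per-checkpoint learning rates assumed equal and dropped. With a single checkpoint it gives the same scores as the loop.

```python
import numpy as np


def combined_scores(train_lg, train_act, eval_lg, eval_act):
    """Score all N training examples against one evaluation example.

    Hypothetical shapes: train_lg (N, n_classes, C), train_act (N, n_feat, C),
    eval_lg (n_classes, C), eval_act (n_feat, C); C = number of checkpoints.
    """
    # Loss-gradient similarity per training example and checkpoint: (N, C)
    lg_sim = np.einsum("nkc,kc->nc", train_lg, eval_lg)
    # Activation similarity per training example and checkpoint: (N, C)
    a_sim = np.einsum("nfc,fc->nc", train_act, eval_act)
    # Multiply per checkpoint, then sum over checkpoints (equal learning rates assumed)
    return np.sum(lg_sim * a_sim, axis=1)


def split_topk(scores, topk):
    """Return (opponents, proponents): most negative first, then most positive first."""
    order = np.argsort(scores)  # ascending: most negative to most positive
    return order[:topk], order[::-1][:topk]


# Toy usage: 5 training examples, 3 classes, 4 features, 2 checkpoints
rng = np.random.default_rng(0)
train_lg = rng.normal(size=(5, 3, 2))
train_act = rng.normal(size=(5, 4, 2))
scores = combined_scores(train_lg, train_act, train_lg[0], train_act[0])
opponents, proponents = split_topk(scores, topk=2)
```

Because `np.argsort` sorts ascending, the first `topk` indices are the strongest opponents and the reversed order puts the strongest proponents first.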

Above, I've added a comment to the line that implements Equation (1) in the paper. What's interesting to me is that the formulation you shared in the notebook is different: it computes both a loss gradient similarity and an activation similarity and multiplies the two.
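For reference, the checkpoint-based score that the paper calls TracInCP sums, over saved checkpoints $w_{t_1}, \dots, w_{t_k}$ with step sizes $\eta_i$, the dot product of the full parameter gradients of the loss for a training example $z$ and a test example $z'$:

$$
\mathrm{TracInCP}(z, z') = \sum_{i=1}^{k} \eta_i \, \nabla_w \ell(w_{t_i}, z) \cdot \nabla_w \ell(w_{t_i}, z')
$$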

Would you mind sharing the motivation for using this formulation instead of the "regular" one?

Thanks again!

-- Phil

frederick0329 commented 3 years ago

The regular one would require storing a huge gradient (size of the penultimate layer × number of classes) for each example, which would make memory and storage blow up. We leverage the chain rule to decompose grad(loss, weights of the last fully connected layer) = outer product(loss grad, outputs of the penultimate layer), where the loss grad is taken with respect to the logits. Details can be found in Appendix F of the paper, "Fast Random Projections for Gradients of Fully-Connected Layers". Note: the random projection is optional and loses information.
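A quick numerical check of that decomposition, as a sketch with random stand-in vectors rather than real tracked values: for a dense layer with logits $z = Wa + b$, the chain rule gives $\partial \ell / \partial W = g a^\top$ with $g = \partial \ell / \partial z$, so the inner product of two examples' weight gradients factors into $(g_1 \cdot g_2)(a_1 \cdot a_2)$. The sizes match the 1000-class head over 2048 pooled features.

```python
import numpy as np

rng = np.random.default_rng(0)
n_classes, n_feat = 1000, 2048  # fc1000 outputs, pooled-feature size

# Stand-ins for two examples' loss gradients w.r.t. the logits and their activations
g1, g2 = rng.normal(size=(2, n_classes))
a1, a2 = rng.normal(size=(2, n_feat))

# "Regular" route: materialize each example's full last-layer weight gradient (n_classes x n_feat)
grad_w1 = np.outer(g1, a1)
grad_w2 = np.outer(g2, a2)
full = np.sum(grad_w1 * grad_w2)

# Factorized route: two small dot products, no (n_classes x n_feat) matrix needed
factored = np.dot(g1, g2) * np.dot(a1, a2)

print(np.isclose(full, factored))  # True
```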

philferriere commented 3 years ago

Oh, stupid me, you're right. It's been staring me in the face all along. It's basically the fast final expression at the bottom of page 14 (Appendix F): O(m + n) instead of O(m × n), where m = 2048 and n = 1000, based on this network definition:

# For the resnet model definition, see https://github.com/frederick0329/TracIn/blob/master/imagenet/resnet50/resnet.py
# Layer[-3]:  
# x = tf.keras.layers.GlobalAveragePooling2D()(x) [2048 floats]
# Layer[-2]:  
# x = tf.keras.layers.Dense(
#     num_classes,
#     kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
#     kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
#     bias_regularizer=_gen_l2_regularizer(use_l2_regularizer),
#     name='fc1000')(x)  [1000 floats]
# Layer[-1]:
# # A softmax that is followed by the model loss cannot be done
# # in float16 due to numeric issues. So we pass dtype=float32.
# x = tf.keras.layers.Activation('softmax', dtype='float32')(x) [1000 probas]
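To make the O(m + n) versus O(m × n) comparison concrete, here is a sketch of what one example contributes under this head. It assumes the model loss is softmax cross-entropy (the snippet above only says the softmax is "followed by the model loss"); under that assumption the gradient with respect to the fc1000 logits is softmax(z) - onehot(y). The helper name `per_example_factors` is hypothetical.

```python
import numpy as np


def softmax(z):
    # Subtract the max for numerical stability
    e = np.exp(z - np.max(z))
    return e / np.sum(e)


def per_example_factors(pooled, logits, label):
    """Return (loss_grad, activation) for one example.

    Assumes softmax cross-entropy: d(loss)/d(logits) = softmax(logits) - onehot(label).
    """
    onehot = np.zeros_like(logits)
    onehot[label] = 1.0
    return softmax(logits) - onehot, pooled


m, n = 2048, 1000  # GlobalAveragePooling2D features, fc1000 outputs
rng = np.random.default_rng(0)
loss_grad, activation = per_example_factors(rng.normal(size=m), rng.normal(size=n), label=3)

factored_floats = loss_grad.size + activation.size  # m + n floats stored per example
full_floats = loss_grad.size * activation.size      # m * n floats for the full weight gradient
print(factored_floats, full_floats)  # 3048 2048000
```

The fc1000 bias gradient equals `loss_grad` itself, so tracking it would not add storage beyond the n floats already kept.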

Brilliant!