greentfrapp / lucent

Lucid library adapted for PyTorch
Apache License 2.0

fix parameters for LocalResponseNorm0 and LocalResponseNorm1 #51

Open liv0617 opened 3 months ago

liv0617 commented 3 months ago

The alpha and k values for LocalResponseNorm0 and LocalResponseNorm1 are incorrect. The code originally used to construct lucent/modelzoo/inceptionv1/InceptionV1.py has since been updated to correct this, but that change isn't reflected here.

To verify which values are correct, do the following:

  1. Download the original InceptionV1 protobuf file used in Lucid.

  2. Open and inspect the contents:

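import tensorflow as tf

# Parse the frozen GraphDef (binary protobuf) that Lucid's InceptionV1 was built from.
with open("path/to/saved_model.pb", "rb") as f:
    graph_def = tf.compat.v1.GraphDef.FromString(f.read())

# Print only the two LRN nodes, including all of their attributes.
for node in graph_def.node:
    if node.name in ("localresponsenorm0", "localresponsenorm1"):
        print(node)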

This prints each LocalResponseNorm node together with its attributes. Example output for localresponsenorm1:

name: "localresponsenorm1"
op: "LRN"
input: "conv2d2"
device: "/cpu:0"
attr {
  key: "alpha"
  value {
    f: 9.999999747378752e-05
  }
}
attr {
  key: "beta"
  value {
    f: 0.5
  }
}
attr {
  key: "bias"
  value {
    f: 2.0
  }
}
attr {
  key: "depth_radius"
  value {
    i: 5
  }
}

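LocalResponseNorm layers are parameterized slightly differently in TensorFlow and in PyTorch, so the TensorFlow attribute values cannot be copied into PyTorch's LocalResponseNorm unchanged.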
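For reference, here is a sketch of the two definitions as given in the tf.nn.lrn and torch.nn.LocalResponseNorm documentation, for channel c of an input x with N channels at a fixed spatial position (each sum only includes channels 0 to N-1 that exist). Writing r for TensorFlow's depth_radius and n for PyTorch's size, the two layers compute the same thing when the conditions on the last line hold; with the localresponsenorm1 attributes above, that would give size 11, k = 2.0 and alpha of about 1.1e-3.

```latex
\text{TensorFlow (tf.nn.lrn):}\quad y_c = x_c \Big/ \Big(\text{bias} + \alpha_{\mathrm{tf}} \sum_{c'=c-r}^{c+r} x_{c'}^2\Big)^{\beta}

\text{PyTorch (nn.LocalResponseNorm):}\quad y_c = x_c \Big/ \Big(k + \frac{\alpha_{\mathrm{pt}}}{n} \sum_{c'=c-\lfloor n/2 \rfloor}^{c+\lfloor (n-1)/2 \rfloor} x_{c'}^2\Big)^{\beta}

n = 2r + 1, \qquad k = \text{bias}, \qquad \alpha_{\mathrm{pt}} = n \, \alpha_{\mathrm{tf}}
```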

greentfrapp commented 3 months ago

Thanks @liv0617 for the update! I haven't worked on this in a long time. I'm wondering whether there is a minimal example that tests the effect of this change?

liv0617 commented 1 month ago

Sorry for missing this @greentfrapp!

--

It's a bit hard to test this because (1) the bug changes the model's behavior without completely breaking it, and (2) a rigorous test would need a PyTorch implementation we already trust to validate against, which we don't have if we don't trust this one.
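One layer-level check that does not require trusting the full model is to compare the two documented LRN formulas directly under a parameter mapping. The sketch below is a minimal pure-Python illustration, assuming a single spatial position (a plain list of channel activations); lrn_tf, lrn_torch and tf_to_torch_lrn_args are hypothetical helper names, not Lucid or Lucent functions, and lrn_torch re-implements PyTorch's documented formula rather than calling torch.nn.LocalResponseNorm itself.

```python
import random


def lrn_tf(x, depth_radius, bias, alpha, beta):
    """TensorFlow-style LRN (tf.nn.lrn formula) over the channels at one position."""
    n_ch = len(x)
    out = []
    for c in range(n_ch):
        lo, hi = max(0, c - depth_radius), min(n_ch - 1, c + depth_radius)
        sqr_sum = sum(x[j] ** 2 for j in range(lo, hi + 1))
        out.append(x[c] / (bias + alpha * sqr_sum) ** beta)
    return out


def lrn_torch(x, size, alpha, beta, k):
    """PyTorch-style LRN (nn.LocalResponseNorm formula): alpha is divided by size."""
    n_ch = len(x)
    out = []
    for c in range(n_ch):
        # PyTorch pads size // 2 channels before and (size - 1) // 2 after.
        lo, hi = max(0, c - size // 2), min(n_ch - 1, c + (size - 1) // 2)
        sqr_sum = sum(x[j] ** 2 for j in range(lo, hi + 1))
        out.append(x[c] / (k + alpha / size * sqr_sum) ** beta)
    return out


def tf_to_torch_lrn_args(depth_radius, bias, alpha, beta):
    """Map tf.nn.lrn attributes to nn.LocalResponseNorm(size, alpha, beta, k) arguments."""
    size = 2 * depth_radius + 1
    return {"size": size, "alpha": alpha * size, "beta": beta, "k": bias}


# Attribute values printed above for localresponsenorm1.
tf_attrs = {"depth_radius": 5, "bias": 2.0, "alpha": 9.999999747378752e-05, "beta": 0.5}
torch_args = tf_to_torch_lrn_args(**tf_attrs)

random.seed(0)
x = [random.uniform(-50.0, 50.0) for _ in range(64)]  # 64 arbitrary channel values
ref = lrn_tf(x, **tf_attrs)
conv = lrn_torch(x, **torch_args)
print(torch_args)
print("max abs difference:", max(abs(a - b) for a, b in zip(ref, conv)))
```

This only checks the parameter mapping for a single layer, not the whole network, so it does not replace the end-to-end comparison described below. With PyTorch installed, lrn_torch could be swapped for torch.nn.LocalResponseNorm(**torch_args) applied to an (N, C, H, W) tensor to check the real layer.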

The experiment that made me fairly confident the new version is correct (besides inspecting InceptionV1's parameters as shown above) was reproducing the curve detector synthetic data plots from Cammarata et al. https://distill.pub/2020/circuits/curve-detectors/ . With the current parameters, the plots look very different; with this fix, I can reproduce them.