greentfrapp / lucent

Lucid library adapted for PyTorch
Apache License 2.0

Generating a batch of optimal stimuli, one for each unit in a layer #28

Closed arnaghosh closed 3 years ago

arnaghosh commented 3 years ago

Hi, I was trying to use Lucent to generate optimal stimuli for several units/neurons of a layer in parallel, so I figured I would leverage batch processing. As illustrated in the neuron interaction tutorial notebook, I was passing a sum of objectives to the render.render_vis() function. Here is a toy example of what I want and my approach:

Units to be visualized = [10,20,30]
Layer = 'readout_fc'

tot_objective = objectives.channel("readout_fc", 10, batch=0) + objectives.channel("readout_fc", 20, batch=1) + objectives.channel("readout_fc", 30, batch=2)
param_f = lambda: param.image(135, batch=3)
imgs = render.render_vis(model, tot_objective, param_f=param_f, preprocess=False, fixed_input_image_size=135)

The parameter settings work beautifully when I try one unit. 😄 However, I wasn't sure whether this is the correct way to handle multiple units in parallel (it does give me separate images for each unit). Also, when the number of units is larger, I was hoping to avoid writing the terms out individually or running an explicit for loop to compute the objective. I tried using reduce as below:

neurons = [10,20,30]
tot_objective = reduce(lambda x,y: x+objectives.channel("readout_fc",y[0],batch=y[1]), list(zip(neurons,np.arange(len(neurons)))), 0)

Doing so gives me the same image 3 times. So I was wondering whether there is something wrong in how I am using the objective function to generate optimal stimuli for multiple units in parallel. Thanks in advance.

zitkat commented 3 years ago

Hi, I have been rendering optimal stimuli in batches without any issues. Instead of reduce you can use sum, like so:

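from lucent.optvis import objectives, param

image_size = (50,)           # unpacked into the arguments of param.image
indcs_batch = [10, 20, 30]   # one unit per batch item
layer = "readout_fc"

# One image per unit in the batch
batch_param_f = lambda: param.image(*image_size, batch=len(indcs_batch))
# Unit n is optimized on batch item b; sum() adds the per-unit objectives
obj = sum([objectives.channel(layer, n, batch=b) for b, n in enumerate(indcs_batch)])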

This should be equivalent to your code though, so I'm not sure what is wrong in your case.
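To illustrate the equivalence claim, here is a minimal sketch using a hypothetical `ToyObjective` class and `toy_channel` helper. Both are invented for illustration and are not Lucent's own classes. It assumes only that an objective supports `+` with other objectives and with the integer 0, which both `sum` and `reduce(..., 0)` rely on. It does not explain why the reduce version in the original report produced identical images.

```python
from functools import reduce


class ToyObjective:
    """Hypothetical stand-in for an objective: wraps a function of activations."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, acts):
        return self.fn(acts)

    def __add__(self, other):
        # Adding a plain number shifts the value; adding an objective sums both.
        if isinstance(other, (int, float)):
            return ToyObjective(lambda acts: self(acts) + other)
        return ToyObjective(lambda acts: self(acts) + other(acts))

    # sum() starts from the integer 0, so `0 + objective` must work too.
    __radd__ = __add__


def toy_channel(n, batch):
    """Toy objective: negative activation of unit n in batch item `batch`."""
    return ToyObjective(lambda acts: -acts[batch][n])


units = [10, 20, 30]

# Pair each unit with its own batch index, then add the per-unit objectives.
via_sum = sum(toy_channel(n, b) for b, n in enumerate(units))
via_reduce = reduce(
    lambda total, bn: total + toy_channel(bn[1], bn[0]),
    enumerate(units),
    0,
)

# Fake activations: 3 batch items x 40 units, value = 100 * batch + unit.
acts = [[100 * b + u for u in range(40)] for b in range(len(units))]

sum_value = via_sum(acts)
reduce_value = via_reduce(acts)
print(sum_value, reduce_value)
```

In this toy setting both constructions read unit 10 from batch item 0, unit 20 from item 1, and unit 30 from item 2, so they evaluate to the same number.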

arnaghosh commented 3 years ago

Great! Thank you. I'll try it out and get back to you.

arnaghosh commented 3 years ago

Awesome. Using sum instead of reduce worked! Thank you so much @zitkat