Open KeremTurgutlu opened 3 years ago
Hi, @KeremTurgutlu!
May I suggest another improvement to this PR? (I could have added another PR of my own, but resolving conflicts would be quite difficult then.) The main thing that slows the computation down here is the use of a batch size of 1. Do you think we could get rid of this `data_test_loader2` altogether and do something like this (incorporating both your changes and larger batches):
```python
import numpy as np
import torch
from tqdm import tqdm

imgs = []
intermediate_activations = []
total_correct = 0
model_loaded.eval()
with torch.no_grad():
    for i, (images, labels) in tqdm(enumerate(data_test_loader), total=len(data_test_loader)):  # <- note the data_test_loader here
        imgs.append(images.view(images.shape[0], -1))
        x = model_loaded.convnet(images)
        intermediate_activations.append(x.view(x.shape[0], -1))

np.save("images", torch.cat(imgs).numpy())
np.save("intermediate_act", torch.cat(intermediate_activations).numpy())
```
This should make the code run roughly another four times faster.
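The pattern suggested above (accumulate per-batch arrays in a list, concatenate once, save once at the end) can be sketched framework-free with plain numpy; the batch shapes below are made up for illustration:

```python
import numpy as np

# Hypothetical "batches" standing in for DataLoader output:
# 10 batches of 32 flattened 28x28 images.
batches = [np.random.rand(32, 784).astype(np.float32) for _ in range(10)]

chunks = []
for images in batches:
    # Per-batch work happens here; we only append, never save inside the loop.
    chunks.append(images.reshape(images.shape[0], -1))

all_images = np.concatenate(chunks)  # one concatenation at the end
assert all_images.shape == (320, 784)
# np.save("images", all_images)  # one write at the end, not one per iteration
```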
Another suggestion is to add
```python
np.random.seed(0)  # or any number you like
```
before each of the two cells that calculate MI, because as far as I can see the estimation algorithm is randomized, so the results can be different each time, especially if you only use 1000 data points. You might also want to change the expected results in the last markdown cell accordingly.
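To illustrate why fixing the seed matters here, a minimal numpy sketch (the `noisy_estimate` function below is a made-up stand-in for any randomized estimator, not the notebook's actual MI code):

```python
import numpy as np

def noisy_estimate(data):
    # Stand-in for a randomized estimator: it subsamples the data,
    # so the result depends on the global random state.
    idx = np.random.choice(len(data), size=len(data) // 2, replace=False)
    return data[idx].mean()

data = np.arange(1000, dtype=float)

np.random.seed(0)
a = noisy_estimate(data)
np.random.seed(0)
b = noisy_estimate(data)
assert a == b  # same seed before each run -> identical, reproducible result
```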
Thanks for the suggestion, I will update the PR once I have time!
Makes sense. I actually wanted to find another library that implements MI. If you know any, feel free to share; the library used in the notebook requires a manual download, I guess, and I couldn't find it on pip or GitHub either.
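For what it's worth, a basic histogram (plug-in) MI estimate needs only numpy; this is a rough sketch rather than a replacement for the notebook's estimator, and the bin count and test data are made up:

```python
import numpy as np

def mutual_information(x, y, bins=20):
    # Joint histogram -> joint probability table.
    joint, _, _ = np.histogram2d(x, y, bins=bins)
    pxy = joint / joint.sum()
    px = pxy.sum(axis=1)  # marginal over x bins
    py = pxy.sum(axis=0)  # marginal over y bins
    # MI = sum over nonzero cells of p(x,y) * log(p(x,y) / (p(x) p(y)))
    nz = pxy > 0
    return float((pxy[nz] * np.log(pxy[nz] / (px[:, None] * py[None, :])[nz])).sum())

rng = np.random.default_rng(0)
x = rng.normal(size=10000)
noise = rng.normal(size=10000)
# A variable shares much more information with itself than with independent noise.
assert mutual_information(x, x) > mutual_information(x, noise)
```

Note this plug-in estimator is biased upward for small samples, which is another reason the seeded results on 1000 points can wobble.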
Description
- Calling `numpy()` and `detach()` in a for loop over and over again is very slow and bad practice.
- Added a `torch.no_grad()` context to speed things up, since there is no need to calculate gradients.
- Stopped saving a `.npy` file at each for loop iteration; instead saved the final array at the end.

This change allows a speed up from 3 mins -> 10 secs.

Affected Dependencies
None.
How has this been tested?