shivamsaboo17 / Overcoming-Catastrophic-forgetting-in-Neural-Networks

Elastic weight consolidation technique for incremental learning.

use torch.gather instead of direct indexing #5

Open afshinrahimi opened 3 years ago

afshinrahimi commented 3 years ago

Instead of this line:

log_liklihoods.append(output[:, target])

use this line:

log_liklihoods.append(torch.gather(output, dim=1, index=target.unsqueeze(-1)))

Why?

Assume output is 100x4, i.e. the batch size is 100 and there are 4 classes, and target is a (100,) vector of class indices. Indexing with output[:, target] broadcasts target across the class dimension and produces a 100x100 matrix, instead of the 100x1 column of per-example log-likelihoods that we actually want.

The torch.gather function does this properly.
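For reference, a minimal sketch of the shape difference (the tensor values here are made up; it assumes output holds per-class log-probabilities as in the repo's code):

```python
import torch

# Hypothetical batch: 100 examples, 4 classes.
batch_size, num_classes = 100, 4
output = torch.log_softmax(torch.randn(batch_size, num_classes), dim=1)
target = torch.randint(0, num_classes, (batch_size,))

# Direct indexing broadcasts target across dim 1:
# every row gets paired with ALL 100 targets.
wrong = output[:, target]
print(wrong.shape)  # torch.Size([100, 100])

# torch.gather picks one entry per row, indexed by that row's target.
right = torch.gather(output, dim=1, index=target.unsqueeze(-1))
print(right.shape)  # torch.Size([100, 1])

# Sanity check: gather matches explicit per-row indexing.
assert torch.equal(right.squeeze(1), output[torch.arange(batch_size), target])
```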