Open afshinrahimi opened 3 years ago
Instead of this line:
log_liklihoods.append(output[:, target])
have this line:
log_liklihoods.append(torch.gather(output, dim=1, index=target.unsqueeze(-1)))
Why?
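Assume our output is 100x4, meaning the batch size is 100 and we have 4 classes. Target is a (100,) vector of class indices. Indexing with output[:, target] selects, for every row, all the columns listed in target, so it creates a 100x100 matrix instead of the 100x1 column of log-likelihoods that we want.
The torch.gather function does this properly: with index=target.unsqueeze(-1) of shape (100, 1), it picks output[i, target[i]] for each row i.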
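A minimal sketch of the shape difference, using random data rather than the repository's model. The names batch_size, num_classes, wrong, right, and same are made up for illustration, and the log-softmax output is only assumed to resemble what the original code produces:

```python
import torch

torch.manual_seed(0)
batch_size, num_classes = 100, 4  # illustrative sizes from the example above

# Stand-in for the model output: per-class log-probabilities, shape (100, 4).
output = torch.log_softmax(torch.randn(batch_size, num_classes), dim=1)
# One class index per example, shape (100,).
target = torch.randint(0, num_classes, (batch_size,))

# Indexing the columns with the whole target vector: every row receives
# all 100 selected columns, so the result is (100, 100).
wrong = output[:, target]
print(wrong.shape)

# gather picks output[i, target[i]] for each row i, so the result is (100, 1).
right = torch.gather(output, dim=1, index=target.unsqueeze(-1))
print(right.shape)

# The desired values sit on the diagonal of the oversized matrix.
print(torch.equal(wrong.diagonal().unsqueeze(-1), right))

# Pairing row indices with target gives the same result as gather.
same = output[torch.arange(batch_size), target].unsqueeze(-1)
print(torch.equal(same, right))
```

Only the diagonal of the 100x100 result holds the wanted values; the other entries pair each example with another example's target, so any sum or mean taken over that matrix mixes in unrelated log-probabilities.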