snorkel-team / snorkel

A system for quickly generating training data with weak supervision
https://snorkel.org
Apache License 2.0

question about `cross_entropy_with_probs` #1646

Closed ain-soph closed 3 years ago

ain-soph commented 3 years ago

The function `cross_entropy_with_probs` computes cross-entropy from logits (`input`) and a probability vector (`target`).
If we ignore the `weight` argument and focus only on `input` and `target`, the formula is

```
-target * log(softmax(input))
```

So a natural thought is to calculate it directly:

```python
(-target * F.log_softmax(input, dim=1)).sum(1)
```
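
For concreteness, here is a minimal self-contained version of that direct computation (my own sketch, with the hypothetical name `soft_cross_entropy`; it is not part of Snorkel's API):

```python
import torch
import torch.nn.functional as F

def soft_cross_entropy(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Cross-entropy between logits `input` (N, C) and soft targets `target` (N, C)."""
    # per-example loss: -sum_c target[n, c] * log_softmax(input)[n, c]
    return (-target * F.log_softmax(input, dim=1)).sum(1)

logits = torch.randn(4, 3)
probs = torch.softmax(torch.randn(4, 3), dim=1)  # each row sums to 1
print(soft_cross_entropy(logits, probs))         # shape (4,): one loss per example
```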

But judging from the current implementation (Docs and Source Code), it seems to call `F.cross_entropy` once per class, computing the `log_softmax` term for each class in turn, and then sum the weighted results, which seems pretty weird. (I know the result is still correct.)
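
For readers following along, the per-class computation looks roughly like this (a sketch reconstructed from the description above, under the hypothetical name `per_class_cross_entropy`; not the verbatim Snorkel source):

```python
import torch
import torch.nn.functional as F

def per_class_cross_entropy(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Same soft cross-entropy, computed one class at a time via F.cross_entropy."""
    num_points, num_classes = input.shape
    cum_losses = input.new_zeros(num_points)
    for y in range(num_classes):
        # hard-label loss as if every example belonged to class y:
        # F.cross_entropy returns -log_softmax(input)[n, y] for each example n
        target_y = input.new_full((num_points,), y, dtype=torch.long)
        y_loss = F.cross_entropy(input, target_y, reduction="none")
        # weight it by the probability mass the soft target puts on class y
        cum_losses += target[:, y] * y_loss
    return cum_losses
```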

Could anyone explain the advantage of doing it that way? I suspect it might come down to some advantage of PyTorch's built-in `F.cross_entropy` over a DIY function, where the latter refers to

```python
-F.log_softmax(input, dim=1).gather(dim=1, index=target.unsqueeze(1)).flatten()
```
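
As a quick sanity check (my own snippet, reusing `per_class_cross_entropy` from the sketch above), the two formulations agree numerically:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 5)
probs = torch.softmax(torch.randn(8, 5), dim=1)

direct = (-probs * F.log_softmax(logits, dim=1)).sum(1)
per_class = per_class_cross_entropy(logits, probs)  # sketch from above
assert torch.allclose(direct, per_class, atol=1e-6)
```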
github-actions[bot] commented 3 years ago

This issue is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 7 days.