bdusell / semiring-einsum

Generic PyTorch implementation of einsum that supports different semirings
https://bdusell.github.io/semiring-einsum/
MIT License

document log_einsum and (renamed) log_viterbi_einsum #5

Closed: davidweichiang closed this issue 2 years ago

davidweichiang commented 2 years ago
bdusell commented 2 years ago

I think I agree with the point about _forward and _backward. The logic behind the naming was that a _forward suffix means non-differentiable (forward pass only), while no suffix means a normal, differentiable function. The module still needs to export log_viterbi_einsum_forward to maintain API compatibility. Can we add a docstring to log_viterbi_einsum mentioning that it's a non-differentiable alias of log_viterbi_einsum_forward?
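For example, the alias and its docstring might look like this (a hypothetical sketch; the package name and the exact signature of log_viterbi_einsum_forward are assumptions):

```python
from semiring_einsum import log_viterbi_einsum_forward

def log_viterbi_einsum(*args, **kwargs):
    """Non-differentiable alias of log_viterbi_einsum_forward.

    Calling .backward() on an expression that depends on the output
    of this function will not propagate gradients through it.
    """
    return log_viterbi_einsum_forward(*args, **kwargs)
```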

davidweichiang commented 2 years ago

Will do. But log_viterbi_einsum actually is (sub)differentiable, right?

bdusell commented 2 years ago

That's right; the backward pass just hasn't been implemented. The idea was to add log_viterbi_einsum in case log_viterbi_einsum_backward is ever implemented.

Actually, because of this I think it would be better not to make log_viterbi_einsum an alias of log_viterbi_einsum_forward. If someone calls backward on an expression that depends on both the output of log_viterbi_einsum and another differentiable function, the gradient will silently be "wrong" in the sense that the error signal is not propagated through log_viterbi_einsum, and some users might never notice if they don't read the documentation carefully.
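A minimal illustration of this failure mode, using a plain torch.no_grad() computation as a stand-in for the forward-only function:

```python
import torch

x = torch.randn(3, requires_grad=True)

# Differentiable path.
a = x.sum()

# Stand-in for a forward-only function like log_viterbi_einsum_forward:
# computed without autograd tracking, so it contributes no gradient.
with torch.no_grad():
    b = x.max()

loss = a + b
loss.backward()

# x.grad is all ones: the contribution of b was silently dropped,
# and no error was raised.
print(x.grad)
```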

Or, an acceptable alternative would be to make log_viterbi_einsum a torch.autograd.Function that raises NotImplementedError in its backward pass.
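A sketch of that alternative. The class name, and the assumption that log_viterbi_einsum_forward takes the equation followed by the input tensors, are hypothetical:

```python
import torch
import semiring_einsum

class LogViterbiEinsum(torch.autograd.Function):
    # Wraps the forward-only implementation so that attempting to
    # backpropagate through it fails loudly instead of silently.

    @staticmethod
    def forward(ctx, equation, *args):
        return semiring_einsum.log_viterbi_einsum_forward(equation, *args)

    @staticmethod
    def backward(ctx, *grad_outputs):
        raise NotImplementedError(
            'the backward pass of log_viterbi_einsum is not implemented')

def log_viterbi_einsum(equation, *args):
    return LogViterbiEinsum.apply(equation, *args)
```

Non-tensor arguments such as the equation pass through apply unchanged; only the tensor arguments participate in autograd, so backward (and hence the error) is triggered exactly when a user tries to differentiate through the result.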

davidweichiang commented 2 years ago

OK, I added NotImplementedError.