davidweichiang closed this 2 years ago.
On second thought, maybe the best solution is to modify HRG so that it can have more than one start symbol.
Why is that a better solution? Are there other situations where we would want multiple start symbols?
@davidweichiang said that multiple start symbols will be useful in the adjoint operation that he implemented (#74).
Yeah, it's good to discuss this because it's not yet clear what the right thing to do is. The cases I can think of right now are:
A very minor argument would be that currently when an HRG is created, it has no start symbol (None). If someone forgets to set it, then errors can result. But if an HRG is actually defined to have zero or more start symbols, then it's not an error to leave the start symbol unset. (On the other hand, if we stay with exactly one start symbol, I think the HRG constructor should be changed to require a start symbol.)
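Here is a minimal Python sketch of the two options. `ToyHRG` and `ToyHRGSingleStart` are hypothetical stand-ins, not the real `HRG` class, and they model only how start symbols are handled. The first allows zero or more start symbols, so an unset start symbol is just an empty set. The second makes the single start symbol a required constructor argument.

```python
from dataclasses import dataclass, field
from typing import Set


@dataclass
class ToyHRG:
    """Hypothetical stand-in for HRG; only start-symbol handling is modeled."""
    nonterminals: Set[str] = field(default_factory=set)
    # Zero or more start symbols: an empty set is a valid state, not an error.
    start_symbols: Set[str] = field(default_factory=set)

    def add_start_symbol(self, label: str) -> None:
        # A start symbol must be one of the grammar's nonterminals.
        if label not in self.nonterminals:
            raise ValueError(f"unknown nonterminal: {label}")
        self.start_symbols.add(label)


@dataclass
class ToyHRGSingleStart:
    """Hypothetical alternative: exactly one start symbol, required up front."""
    start: str  # no default, so omitting it raises TypeError at construction
    nonterminals: Set[str] = field(default_factory=set)

    def __post_init__(self) -> None:
        # Keep the start symbol registered as a nonterminal.
        self.nonterminals.add(self.start)
```

In the first design, code that loops over `start_symbols` does nothing when the set is empty instead of failing. In the second, forgetting the start symbol is caught immediately as a `TypeError` when the object is constructed.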
If we stick with exactly one start symbol and modify just sum_product to return multiple sum-products, what would the right interface be? Do we like functions that change their return type based on their arguments?
I would err on the side of not changing the return type based on a parameter (e.g. `return_all`). I think we would just need the main `sum_product` function and a wrapper function for the `return_all` behavior (e.g. `sum_product_all`).
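A minimal sketch of that split, under some loud assumptions. The grammar is a hypothetical toy: a dict from each nonterminal to weighted rules whose right-hand sides are nonterminals, which is closer to a weighted CFG than to an HRG. The sum-products are computed by simple fixed-point iteration from zero (with a hypothetical `iterations` parameter), which only stands in for whatever solver the real `sum_product` uses. What matters here is the interface: each function always returns the same type. In this sketch `sum_product` is the thin wrapper that picks the start symbol's entry out of `sum_product_all`; the wrapping could just as well go the other way.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical toy grammar: each nonterminal maps to a list of rules,
# and each rule is (weight, right-hand-side nonterminals).
Grammar = Dict[str, List[Tuple[float, Tuple[str, ...]]]]


def sum_product_all(g: Grammar, iterations: int = 1000) -> Dict[str, float]:
    """Always returns a dict mapping every nonterminal to its sum-product."""
    z = {x: 0.0 for x in g}  # start every estimate at zero
    for _ in range(iterations):
        # Fixed-point step: Z_X = sum over X's rules of weight * product of Z_Y,
        # computed entirely from the previous estimates (assumes convergence).
        z = {x: sum(w * math.prod(z[y] for y in rhs) for w, rhs in rules)
             for x, rules in g.items()}
    return z


def sum_product(g: Grammar, start: str) -> float:
    """Always returns a single number: the start symbol's sum-product."""
    return sum_product_all(g)[start]
```

For example, with `g = {"S": [(1.0, ("X", "X"))], "X": [(0.5, ()), (0.25, ("X",))]}`, the equations are Z_X = 0.5 + 0.25 * Z_X and Z_S = Z_X * Z_X, so `sum_product(g, "S")` approaches 4/9 and `sum_product_all(g)` also reports Z_X = 2/3.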
OK, on further reflection, we need a low-level interface to `sum_product` that has the following properties:
So I think we need a low-level interface like this:
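```python
from __future__ import annotations  # HRG and EdgeLabel are defined elsewhere in the project
from typing import Sequence, Tuple
from torch import Tensor

def _sum_product(g: HRG, in_labels: Sequence[EdgeLabel], out_labels: Sequence[EdgeLabel], *in_values: Tensor) -> Tuple[Tensor, ...]:
    """Compute the sum-products of the nonterminals in out_labels, given the sum-products in_values of the terminals and/or nonterminals in in_labels."""
```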
Then the real sum_product can be defined in terms of this. I can work on this, but does it seem like a reasonable interface to you (@kennethsible)?
When computing gradients (#66), we are going to need the sum-product ($\psi_X$) of all the nonterminals, not just the start nonterminals. Probably there should be a new function called, I don't know, `sum_product_all`, that returns a dict mapping from nonterminals to their sum-products. Alternatively, `sum_product` could take an option `return_all=True`, which would cause it to return that dict.