diprism / fggs

Factor Graph Grammars in Python
MIT License

Compute sum-product of multiple nonterminals #71

Closed by davidweichiang 2 years ago

davidweichiang commented 3 years ago

When computing gradients (#66) we are going to need the sum-product (\psi_X) of every nonterminal X, not just the start nonterminal. Probably there should be a new function, called something like sum_product_all, that returns a dict mapping each nonterminal to its sum-product.

Alternatively, sum_product could take an option return_all=True that makes it return that dict instead.

davidweichiang commented 3 years ago

On second thought, maybe the best solution is to modify HRG so that it can have more than one start symbol.

darcey commented 3 years ago

Why is that a better solution? Are there other situations where we would want multiple start symbols?

kennethsible commented 3 years ago

> Why is that a better solution? Are there other situations where we would want multiple start symbols?

@davidweichiang said that multiple start symbols will be useful in the adjoint operation that he implemented (#74).

davidweichiang commented 3 years ago

Yeah, it's good to discuss this because it's not yet clear what the right thing to do is. The cases I can think of right now are:

  1. Someone is interested in the sum-products of multiple nonterminals, as in (2).
  2. In backpropagation, we want the gradient of the loss with respect to all the factors. adjoint_hrg() constructs a new FGG such that for each terminal label \ell, there's a new nonterminal \bar\ell whose sum-product is the gradient with respect to \ell. So this is an instance of (1).
  3. In EM, similarly, we want the expected counts of all the factors, which can also be computed using adjoint_hrg().
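To make cases 2 and 3 concrete, here is a toy scalar sketch. It does not reproduce adjoint_hrg(); it only computes the quantities that the \bar\ell nonterminals are meant to represent. It assumes a hypothetical grammar with one nonterminal S and two rules, S → S S (weight a) and S → b, so \psi_S is the least solution of \psi = a\psi^2 + b. Differentiating that fixed-point equation implicitly gives the gradients, and the EM expected count of each rule is w · (∂\psi/∂w) / \psi.

```python
from typing import Tuple

# Hypothetical toy grammar (not an fggs HRG): one nonterminal S with rules
# S -> S S (weight a) and S -> terminal (weight b), so psi = a*psi**2 + b.


def least_fixed_point(a: float, b: float, tol: float = 1e-12, max_iter: int = 10_000) -> float:
    """Kleene iteration from 0; for nonnegative a, b it reaches the least solution if one exists."""
    psi = 0.0
    for _ in range(max_iter):
        new = a * psi * psi + b
        if abs(new - psi) < tol:
            return new
        psi = new
    raise RuntimeError("fixed-point iteration did not converge")


def gradients(a: float, b: float) -> Tuple[float, float, float]:
    """Return psi, d(psi)/da and d(psi)/db by implicit differentiation.

    Differentiating psi = a*psi**2 + b gives
    d(psi) = psi**2 da + db + 2*a*psi d(psi),
    so each derivative is its direct term divided by (1 - 2*a*psi).
    """
    psi = least_fixed_point(a, b)
    denom = 1.0 - 2.0 * a * psi
    return psi, psi * psi / denom, 1.0 / denom


def expected_counts(a: float, b: float) -> Tuple[float, float]:
    """EM-style expected rule counts: w * (d psi / d w) / psi for each weight w."""
    psi, dpsi_da, dpsi_db = gradients(a, b)
    return a * dpsi_da / psi, b * dpsi_db / psi
```

As a sanity check, the expected count of S → b always comes out exactly one more than that of S → S S, because every derivation that uses the binary rule n times uses the terminal rule n + 1 times.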

A very minor argument: currently, a newly created HRG has no start symbol (None), so if someone forgets to set it, errors can result later. But if an HRG is defined to have zero or more start symbols, then leaving them unset is not an error. (On the other hand, if we stay with exactly one start symbol, I think the HRG constructor should be changed to require a start symbol.)
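The two designs in this comment can be contrasted with a pair of hypothetical dataclasses (illustrative only, not fggs's HRG class). One requires exactly one start symbol at construction, so it cannot be forgotten; the other allows zero or more start symbols and treats "none set" as a valid, empty state.

```python
from dataclasses import dataclass
from typing import FrozenSet


@dataclass(frozen=True)
class OneStartGrammar:
    """Hypothetical: exactly one start symbol, required by the constructor."""
    nonterminals: FrozenSet[str]
    start: str  # no default, so omitting it is a TypeError at construction time

    def __post_init__(self) -> None:
        if self.start not in self.nonterminals:
            raise ValueError(f"start symbol {self.start!r} is not a nonterminal")


@dataclass(frozen=True)
class MultiStartGrammar:
    """Hypothetical: zero or more start symbols; the empty set is a legal default."""
    nonterminals: FrozenSet[str]
    starts: FrozenSet[str] = frozenset()

    def __post_init__(self) -> None:
        unknown = self.starts - self.nonterminals
        if unknown:
            raise ValueError(f"start symbols {sorted(unknown)} are not nonterminals")
```

With the first class the mistake is caught immediately, at construction; with the second there is no mistake to catch, and a sum-product routine would simply return results for however many start symbols are set.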

davidweichiang commented 3 years ago

If we stick with exactly one start symbol and modify just sum_product to return multiple sum-products, what would the right interface be? Do we like functions that change their return type based on their arguments?

kennethsible commented 3 years ago

I would err on the side of not changing the return type based on a parameter (e.g. return_all). I think we would just need the main sum_product function and a wrapper function for the return_all behavior (e.g. sum_product_all).
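A minimal sketch of this arrangement, assuming a hypothetical toy grammar rather than fggs's HRG: each nonterminal has no external nodes (so its sum-product is a scalar), rules are stored as (weight, right-hand-side nonterminals), and \psi_X is the least fixed point found by Kleene iteration. Both function names follow the thread, but the rule format and solver are illustrative only.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical toy grammar: nonterminal -> list of (rule weight, right-hand-side nonterminals).
Rules = Dict[str, List[Tuple[float, List[str]]]]


def sum_product_all(rules: Rules, tol: float = 1e-12, max_iter: int = 10_000) -> Dict[str, float]:
    """For nonnegative weights, least fixed point of
    psi_X = sum over rules r of X of w_r * prod_{Y in rhs_r} psi_Y, for every X."""
    psi = {x: 0.0 for x in rules}
    for _ in range(max_iter):
        # Update every nonterminal from the previous estimates (Kleene iteration from 0).
        new = {
            x: sum(w * math.prod(psi[y] for y in rhs) for w, rhs in rs)
            for x, rs in rules.items()
        }
        if all(abs(new[x] - psi[x]) < tol for x in rules):
            return new
        psi = new
    raise RuntimeError("fixed-point iteration did not converge")


def sum_product(rules: Rules, start: str) -> float:
    """Always returns a single value: a thin wrapper over sum_product_all."""
    return sum_product_all(rules)[start]
```

Here sum_product wraps sum_product_all, but the wrapping could just as well go the other way. Either way, each function has exactly one return type. If a return_all flag were kept instead, a type checker could only describe its two return types through typing.overload with separate Literal[True] and Literal[False] signatures.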

davidweichiang commented 2 years ago

OK, on further reflection, we need a low-level interface to sum_product that has the following properties:

So I think we need a low-level interface like this:

```python
from typing import Sequence, Tuple
from torch import Tensor  # HRG and EdgeLabel are fggs classes

def _sum_product(g: HRG, in_labels: Sequence[EdgeLabel], out_labels: Sequence[EdgeLabel], *in_values: Tensor) -> Tuple[Tensor, ...]:
    """Compute the sum-products of the nonterminals in out_labels, given the sum-products
    (in_values, one per label) of the terminals and/or nonterminals in in_labels."""
```

Then the real sum_product can be defined in terms of this. I can work on this, but does it seem like a reasonable interface to you (@kennethsible)?