diprism / fggs

Factor Graph Grammars in Python
MIT License

PatternedTensor breaks sum-product backward #173

Closed ccshan closed 10 months ago

ccshan commented 12 months ago
```
$ perpl/perplc perpl/examples/von_neumann.ppl | PYTHONPATH=fggs fggs/bin/sum_product.py -d -w unfair '[0.1,0.4]' -n unfair 0 /dev/stdin
[0.5000000000000003, 0.5000000000000003]

$ perpl/perplc perpl/examples/von_neumann.ppl | PYTHONPATH=fggs fggs/bin/sum_product.py -d -w unfair '[0.1,0.4]' -n unfair 0 -g /dev/stdin
[0.5000000000000003, 0.5000000000000003]
grad[unfair]: [-3.552713678800501e-15, 0.0]

$ perpl/perplc perpl/examples/von_neumann_pair.ppl | PYTHONPATH=fggs fggs/bin/sum_product.py -d -w unfair '[0.1,0.4]' -n unfair 0 /dev/stdin
[0.5000000000000003, 0.5000000000000003]

$ perpl/perplc perpl/examples/von_neumann_pair.ppl | PYTHONPATH=fggs fggs/bin/sum_product.py -d -w unfair '[0.1,0.4]' -n unfair 0 -g /dev/stdin
[0.5000000000000003, 0.5000000000000003]
Traceback (most recent call last):
  File "/home/ccshan/u/rational/fmitf/fggs/bin/sum_product.py", line 107, in <module>
    f.backward()
  File "/home/ccshan/u/rational/fmitf/python/lib/python3.11/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/home/ccshan/u/rational/fmitf/python/lib/python3.11/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function SumProductBackward returned an invalid gradient at index 4 - got [4] but expected shape compatible with [2, 2]
```

I guess SumProduct.backward needs to do something like calling fggs.indices.project?
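[Editor's note: the shape mismatch appears to be a layout bookkeeping issue. The backward pass hands autograd a gradient in the parameter's flat physical layout (`[4]`), while autograd expects the virtual shape of the corresponding forward input (`[2, 2]`). A toy, pure-Python sketch of the two layouts and the project-like conversion between them; the names and the particular embedding are made up for illustration and are not fggs internals:]

```python
# Toy sketch of a "patterned" tensor: a flat physical buffer plus a map
# from each physical slot to a coordinate in the virtual (dense) shape.
# This embedding is made up; fggs derives it from index patterns.
virtual_shape = (2, 2)
embedding = [(0, 0), (0, 1), (1, 0), (1, 1)]  # physical slot -> virtual coord

def to_dense(physical):
    """Expand the flat physical buffer into the full virtual shape."""
    dense = [[0.0] * virtual_shape[1] for _ in range(virtual_shape[0])]
    for slot, (i, j) in enumerate(embedding):
        dense[i][j] = physical[slot]
    return dense

def project(dense_grad):
    """Gather a dense gradient back into physical layout (the inverse
    bookkeeping of to_dense), so its shape matches the stored buffer."""
    return [dense_grad[i][j] for (i, j) in embedding]

physical = [0.1, 0.4, 0.3, 0.2]
dense = to_dense(physical)                    # virtual view, shape [2, 2]
grad = [[1.0, 2.0], [3.0, 4.0]]               # gradient w.r.t. the dense view
assert project(grad) == [1.0, 2.0, 3.0, 4.0]  # physical layout, shape [4]
```

Whichever direction the real fix needs, the point is that a gradient must be converted to the layout of the tensor it is the gradient of before being returned from backward.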

chihyang commented 11 months ago

@ccshan See https://github.com/diprism/fggs/actions/runs/6463666334 — there are 5 test cases failing.

ccshan commented 11 months ago

> @ccshan See https://github.com/diprism/fggs/actions/runs/6463666334 — there are 5 test cases failing.

Thanks! Wow! I fixed that.

Now to "optimize project(to_dense()) composition"...

chihyang commented 11 months ago

> > @ccshan See https://github.com/diprism/fggs/actions/runs/6463666334 — there are 5 test cases failing.
>
> Thanks! Wow! I fixed that.
>
> Now to "optimize project(to_dense()) composition"...

Will work on that this weekend.

ccshan commented 11 months ago

> > Now to "optimize project(to_dense()) composition"...
>
> Will work on that this weekend.

I hope I'll get to it before you do!

I think Axis.unify needs to be used. And unfortunately, the best example of how to use Axis.unify is in PatternedTensor.einsum (which is called in PatternedTensor.mv, PatternedTensor.mm, and TestPatternedTensor.test_einsum).

chihyang commented 11 months ago

> > > Now to "optimize project(to_dense()) composition"...
> >
> > Will work on that this weekend.
>
> I hope I'll get to it before you do!
>
> I think Axis.unify needs to be used. And unfortunately, the best example of how to use Axis.unify is in PatternedTensor.einsum (which is called in PatternedTensor.mv, PatternedTensor.mm, and TestPatternedTensor.test_einsum).

Here is the latest CI result: https://github.com/diprism/fggs/actions/runs/6522580767

I didn't use Axis.unify but wrote another function, reproject. For each element whose virtual coordinate corresponds to one of the elements in the second argument's physical storage, it takes that element's value from the first argument; otherwise the result defaults to 0. Likewise, if the virtual coordinate of one of the second argument's physical elements does not appear in the first argument's physical storage, the result defaults to 0.

https://github.com/diprism/fggs/blob/f0bd5c88ba31c13c6995720a1d6d51ba3e042502/fggs/indices.py#L526-L562
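[Editor's note: a stripped-down model of the reproject behavior described above, with made-up names and a plain dict standing in for the real Axis/embedding machinery, not the fggs implementation:]

```python
# Stripped-down model of reproject (illustrative only): t1's physical
# storage is modeled as a dict from virtual coordinates to values, and
# t2's physical storage by the list of virtual coordinates it backs.
def reproject(t1_physical, t2_coords):
    """For each virtual coordinate backing t2's physical storage, take
    t1's value at that coordinate if t1's physical storage has one;
    otherwise default to 0.  Coordinates present only in t1 are dropped."""
    return [t1_physical.get(coord, 0.0) for coord in t2_coords]

t1 = {(0, 0): 5.0, (1, 1): 7.0}       # t1's physical values at virtual coords
t2_coords = [(0, 0), (0, 1), (1, 1)]  # virtual coords of t2's physical slots
assert reproject(t1, t2_coords) == [5.0, 0.0, 7.0]
```

As the later comments note, the same effect can fall out of Axis.unify by composing the coordinate maps directly, without materializing a dense intermediate.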

ccshan commented 11 months ago

I think unification allows something like reproject to take place without computing out_indices...

ccshan commented 11 months ago

I don't think t2.physical should be involved at all in the returned PatternedTensor, like it is on line 560...

chihyang commented 11 months ago

It shouldn't. I don't know what I was thinking... The idea was to use the indices to filter the values from the gradients. But now that Axis.unify is used, reproject can be removed. Sorry for this stupid mistake!