diprism / fggs

Factor Graph Grammars in Python
MIT License

Small fixes to bin/sum_product.py #146

Closed davidweichiang closed 2 years ago

davidweichiang commented 2 years ago

This should be a really really quick review.

davidweichiang commented 2 years ago

Oh wait, I messed up and the 2nd commit didn't take.

davidweichiang commented 2 years ago

OK, this should be ready for review. I think it fixes the crash @HerbertMcSnout mentioned and should enable @ccshan to write and run programs.

colin-mcd commented 2 years ago

Hmm, I still get this stack trace when running sum_product on an FGG:

Traceback (most recent call last):
  File "bin/sum_product.py", line 28, in <module>
    print(json.dumps(fggs.formats.weights_to_json(fggs.sum_product(fgg, method=args.method))))
  File "/home/herbert/diprism/fgg-implementation/fggs/sum_product.py", line 369, in sum_product
    comp_values = SumProduct.apply(fgg, comp_opts, inputs.keys(), comp_labels, *inputs.values())
  File "/home/herbert/diprism/fgg-implementation/fggs/sum_product.py", line 289, in forward
    x0.copy_(F(fgg, x0, inputs, semiring))
  File "/home/herbert/diprism/fgg-implementation/fggs/sum_product.py", line 80, in F
    tau_rule = sum_product_edges(interp, rule.rhs.nodes(), rule.rhs.edges(), rule.rhs.ext, x, inputs, semiring=semiring)
  File "/home/herbert/diprism/fgg-implementation/fggs/sum_product.py", line 200, in sum_product_edges
    out = semiring.einsum(compiled, *tensors)
  File "/home/herbert/diprism/fgg-implementation/fggs/semirings.py", line 121, in einsum
    return torch_semiring_einsum.semiring_einsum_forward(equation, args, 1, callback)
  File "/home/herbert/anaconda3/lib/python3.8/site-packages/torch_semiring_einsum/extend.py", line 110, in semiring_einsum_forward
    equation.validate_sizes(args)
  File "/home/herbert/anaconda3/lib/python3.8/site-packages/torch_semiring_einsum/equation.py", line 21, in validate_sizes
    size = args[i].size(j)
AttributeError: 'list' object has no attribute 'size'

davidweichiang commented 2 years ago

That's strange, I was getting that error but this PR is supposed to fix that. Are you sure you're running the new version?

colin-mcd commented 2 years ago

Oh right, I checked out the wrong branch! Sorry about that, it definitely works now.
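
For anyone who hits the same AttributeError: the last frame of the traceback shows that one of the operands handed to the einsum call was a plain Python list, which has no .size method. Below is a minimal sketch of a pre-check that reports which operand is the offender. It is not code from this PR or from fggs; find_args_without_size and FakeTensor are hypothetical names, and FakeTensor is only a stand-in for torch.Tensor so that the sketch runs without torch installed.

```python
def find_args_without_size(args):
    """Return (index, type name) for every operand lacking a callable .size."""
    return [
        (i, type(arg).__name__)
        for i, arg in enumerate(args)
        if not callable(getattr(arg, "size", None))
    ]


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor, used only for illustration."""

    def __init__(self, shape):
        self.shape = tuple(shape)

    def size(self, dim):
        # Mirrors the tensor.size(dim) call seen in the traceback.
        return self.shape[dim]


# One tensor-like operand and one plain list, as in the failing call.
operands = [FakeTensor((2, 3)), [0.1, 0.9]]
offenders = find_args_without_size(operands)
print(offenders)
```

Running a check like this just before the einsum call would point at the operand to convert, rather than failing deep inside the library.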