IBM / graph2tac

Graph-based neural tactic prediction models for Coq.
Apache License 2.0

Fix small bug in loss #125

Closed jasonrute closed 1 year ago

jasonrute commented 1 year ago

This fixes a subtle bug in the loss which, unfortunately, will change all the test numbers. The main idea is that in one part of the code I used x.shape[0] to get the batch size, but meant to use tf.shape(x)[0]. The first comes out as None for a tensor whose batch dimension is unknown at trace time, and None also happened to be a valid value where I was passing the batch size in. What this effectively meant is that, in some circumstances, the loss for the global and local arguments was constructed as an array missing some of the last batch elements. Those elements would have had loss zero, so it wasn't a big deal, except that the loss is calculated as the average of the list of losses, and in this case there were a few fewer elements in the list. I doubt it will affect the network performance much (although who knows for sure), but it is something we need to fix before the refactor, since the refactor will rewrite this part of the code.
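To make the failure mode concrete, here is a minimal sketch (not the repository's actual loss code) of the two ingredients: `x.shape[0]` returning `None` for a symbolic tensor, and an aggregation op where `None` is a legal length argument that silently drops trailing zero-loss batch elements, shifting the mean. The use of `tf.math.bincount` here is purely illustrative.

```python
import tensorflow as tf

# 1. Static vs. dynamic shape: for a tensor traced with an unknown batch
#    dimension, x.shape[0] is None, while tf.shape(x)[0] is a runtime value.
seen_static = []

def batch_dim(x):
    seen_static.append(x.shape[0])   # static shape: None at trace time
    return tf.shape(x)[0]            # dynamic shape: correct at runtime

fn = tf.function(batch_dim).get_concrete_function(
    tf.TensorSpec([None, 3], tf.float32))
# seen_static[0] is None; fn(tf.zeros([5, 3])) evaluates to 5

# 2. Why a None length silently changes the loss (illustrative only):
# per-argument losses are summed per batch element, but rows 2 and 3
# contribute no arguments, so their (zero) losses only appear in the
# output if the true batch size is supplied as the length.
per_arg_loss = tf.constant([0.5, 0.25, 0.25])  # losses for individual args
batch_index = tf.constant([0, 0, 1])           # batch element of each arg
batch_size = 4                                 # rows 2 and 3 have no args

# Buggy: length None is accepted, output length inferred from the data,
# so the trailing zero-loss elements are missing and the mean is too high.
buggy = tf.math.bincount(batch_index, weights=per_arg_loss, minlength=None)

# Fixed: pad to the real batch size so zero-loss elements count in the mean.
fixed = tf.math.bincount(batch_index, weights=per_arg_loss,
                         minlength=batch_size)
```

With these inputs, `buggy` has only 2 entries and mean 0.5, while `fixed` has all 4 entries and mean 0.25, which is the kind of shift that changes the test numbers.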

I made this its own PR since it changes the tests by a lot in some cases.