pronobis / libspn-keras

Library for learning and inference with Sum-product Networks utilizing TensorFlow 2.x and Keras

DGC-SPN paper architecture #22

Closed: yannickl96 closed this issue 3 years ago

yannickl96 commented 3 years ago

Hello,

I am currently trying to replicate the SPN architecture from the DGC-SPN paper for image classification; however, I am struggling to apply the input dropout. Could you give me a hint as to whether my code is on the right track?

import libspn_keras as spnk
from tensorflow import keras


def def_spn():
    return keras.Sequential([
      spnk.layers.NormalizeStandardScore(input_shape=(28, 28, 1)),
      spnk.layers.NormalLeaf(
          num_components=32, 
          location_trainable=True,
          location_initializer=spnk.initializers.Equidistant(minval=-1.5, maxval=1.5)
      ),
      keras.layers.Dropout(0.2),
      # Non-overlapping products
      spnk.layers.Conv2DProduct(
          depthwise=True, 
          strides=[2, 2], 
          dilations=[1, 1], 
          kernel_size=[2, 2],
          padding='valid'
      ),
      spnk.layers.LogDropout(0.2),
      spnk.layers.Local2DSum(num_sums=64),
      # Non-overlapping products
      spnk.layers.Conv2DProduct(
          depthwise=True, 
          strides=[2, 2], 
          dilations=[1, 1], 
          kernel_size=[2, 2],
          padding='valid'
      ),
      spnk.layers.LogDropout(0.2),
      spnk.layers.Local2DSum(num_sums=64),
      # Overlapping products, starting at dilations [1, 1]
      spnk.layers.Conv2DProduct(
          depthwise=True, 
          strides=[1, 1], 
          dilations=[1, 1], 
          kernel_size=[2, 2],
          padding='full'
      ),
      spnk.layers.LogDropout(0.2),
      spnk.layers.Local2DSum(num_sums=64),
      # Overlapping products, with dilations [2, 2] and full padding
      spnk.layers.Conv2DProduct(
          depthwise=True, 
          strides=[1, 1], 
          dilations=[2, 2], 
          kernel_size=[2, 2],
          padding='full'
      ),
      spnk.layers.LogDropout(0.2),
      spnk.layers.Local2DSum(num_sums=128),
      # Overlapping products, with dilations [4, 4] and full padding
      spnk.layers.Conv2DProduct(
          depthwise=True, 
          strides=[1, 1], 
          dilations=[4, 4], 
          kernel_size=[2, 2],
          padding='full'
      ),
      spnk.layers.LogDropout(0.2),
      spnk.layers.Local2DSum(num_sums=128),
      # Overlapping products, with dilations [8, 8] and 'final' padding to combine
      # all scopes
      spnk.layers.Conv2DProduct(
          depthwise=True, 
          strides=[1, 1], 
          dilations=[8, 8], 
          kernel_size=[2, 2],
          padding='final'
      ),
      spnk.layers.LogDropout(0.2),
      spnk.layers.SpatialToRegions(),
      # Class roots
      spnk.layers.DenseSum(num_sums=10),
      spnk.layers.RootSum(return_weighted_child_logits=True)
    ])

Any advice would be greatly appreciated!

Best regards!

jostosh commented 3 years ago

Hi Yannick,

Happy to help out in that case!

You're very close with the architecture. One important detail, though, is that the dropout mask needs a specific shape: one that corresponds to marginalizing out all components of a dropped-out variable simultaneously. In this case, that means passing noise_shape=(28, 28, 1) to keras.layers.Dropout. The final axis corresponds to the components, and since the 'noise tensor' in the dropout layer has just one value along that axis, it is broadcast to all channels/components. Hope this helps :)
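
Concretely, the leaf part of your model would then look something like this (just a sketch of the first few layers, reusing your settings from above; the name leaf_stack is only for illustration and the rest of your stack stays the same):

import libspn_keras as spnk
from tensorflow import keras

leaf_stack = keras.Sequential([
    spnk.layers.NormalizeStandardScore(input_shape=(28, 28, 1)),
    spnk.layers.NormalLeaf(
        num_components=32,
        location_trainable=True,
        location_initializer=spnk.initializers.Equidistant(minval=-1.5, maxval=1.5)
    ),
    # One mask value per spatial variable; the size-1 final axis is broadcast
    # over all 32 components, so a dropped variable is marginalized out of
    # every component at once
    keras.layers.Dropout(0.2, noise_shape=(28, 28, 1))
])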

yannickl96 commented 3 years ago

Hi Jos,

Thanks for your quick reply! Just one further question: in the paper you mention using the default Keras settings for the Adam optimizer (in particular a learning rate of 1e-4); however, the default learning rate in Keras is actually 1e-3. Did that change over time, or is there a typo in the paper?

Cheers!

jostosh commented 3 years ago

Hi Yannick, thanks for pointing that out. The value we actually used was 1e-4, but there's a good chance you'll get similar performance with 1e-3.
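
In case it's useful, here's a rough sketch of the compile step with that learning rate. The loss and metric are just one reasonable pairing for the per-class logits that RootSum(return_weighted_child_logits=True) returns; only the 1e-4 value itself is from the paper:

from tensorflow import keras

model = def_spn()
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),  # value used in the paper
    # RootSum(return_weighted_child_logits=True) outputs one logit per class,
    # so a from_logits cross-entropy is a natural fit here
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()]
)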

yannickl96 commented 3 years ago

Thanks for the reply, it worked like a charm! :)

jostosh commented 3 years ago

I'll close this for now, but feel free to re-open it if another related problem comes up.