Cerenaut / sparse-unsupervised-capsules

Sparse unsupervised capsules
https://arxiv.org/abs/1804.06094
Apache License 2.0

how to understand the reshape operation in calculate_routing_coefficients? #1

Closed zz12375 closed 5 years ago

zz12375 commented 6 years ago

Hi! I have been reading your paper and code. In layers/layers.py, line 460, I noticed that you use a reshape operation to split out the Height and Width dimensions of the routing coefficients. But in capsule_model.py, lines 67 and 68, the Height and Width dimensions are combined in this order: [num_prime_capsules, Height, Width], so I believe the routing coefficients would keep that order. Is it right to do tf.reshape(route, [-1, dim, dim, num_prime_capsules, route.shape[2].value]), or should it be tf.reshape(route, [-1, num_prime_capsules, dim, dim, route.shape[2].value])?

drawlinson commented 6 years ago

Hi, really sorry for the slow response on this.

I will try to explain what we were doing. In models/layers/layers.py, lines 460-464, it says:

    routing_coeffs = tf.reshape(route, [-1, dim, dim, num_prime_capsules, route.shape[2].value])
    capsule_routing = tf.reduce_max(routing_coeffs, axis=3)
    capsule_routing = tf.reduce_sum(capsule_routing, axis=2)
    return tf.reduce_sum(capsule_routing, axis=1)

The intent here is to find the total routing "support" for each latent / digit capsule, so we sum over the routing from primary capsules. Between a pair of capsule layers, some capsules will be in agreement and others not. So we took the max over all the inbound routing coeffs from primary capsules, so that only the "most agreeing" capsule-pair relations would be considered. This is the intuition for why it's a max (over primary caps) first, and then a sum over all the conv x,y locations.
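To make the max-then-sum reduction concrete, here is a minimal NumPy sketch of the same three steps. It is not the repository's code: the function name `routing_support` and the tensor sizes are made up for illustration, and it assumes the input already has the layout [batch, dim, dim, num_prime_capsules, num_latent].

```python
import numpy as np


def routing_support(routing_coeffs):
    """Total routing support per latent capsule.

    Assumes routing_coeffs has shape
    [batch, dim, dim, num_prime_capsules, num_latent] (hypothetical layout).
    """
    # Keep only the strongest primary capsule at each conv location.
    strongest = routing_coeffs.max(axis=3)  # [batch, dim, dim, num_latent]
    # Add up that strongest agreement over all conv x,y locations.
    return strongest.sum(axis=(1, 2))  # [batch, num_latent]


# Illustrative sizes only; they are not taken from the paper or the repo.
rng = np.random.default_rng(0)
coeffs = rng.random((2, 6, 6, 32, 10))
support = routing_support(coeffs)
print(support.shape)  # (2, 10)
```

Summing over both spatial axes in one call is equivalent to the two successive `reduce_sum` calls in the quoted snippet.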

If there is a bug here, and it isn't caught by reshape()'s own shape checks, I'd expect it to badly affect the routing consensus and make the results much worse. But we tracked consensus pretty closely and didn't observe that problem.
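On the point about reshape() rules: a reshape only checks that the total number of elements matches, so both of the target shapes from the question are accepted without error. The NumPy sketch below uses tiny made-up sizes and assumes, as the question suggests, that the flattened axis was built in [num_prime_capsules, H, W] order. Under that assumption the two reshapes give different support values.

```python
import numpy as np

# Tiny hypothetical sizes so the numbers can be checked by hand.
dim, num_prime_capsules, num_latent = 2, 3, 1

# Assumption: axis 1 of `route` was flattened in [num_prime_capsules, H, W] order.
route = np.arange(num_prime_capsules * dim * dim, dtype=float)
route = route.reshape(1, -1, num_latent)  # [batch, caps*H*W, num_latent]

# Reshape as in the quoted code: [batch, H, W, caps, num_latent].
a = route.reshape(-1, dim, dim, num_prime_capsules, num_latent)
support_a = a.max(axis=3).sum(axis=(1, 2))

# Reshape as proposed in the question: [batch, caps, H, W, num_latent].
b = route.reshape(-1, num_prime_capsules, dim, dim, num_latent)
support_b = b.max(axis=1).sum(axis=(1, 2))

# Both reshapes succeed, but the reductions disagree.
print(support_a)  # [[26.]]
print(support_b)  # [[38.]]
```

Whether the difference matters in practice depends on the real flattening order in capsule_model.py, which this sketch does not verify.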

Having said that, I spent a little time trying to figure out the dimensions of these tensors from looking at the code. I couldn't be 100% sure there's not a bug here, so you might be right. Unfortunately, I don't have time to debug it, as we're not actively using this code anymore.

Sorry I can't answer more completely. I think the only way to be sure is to debug and check the shapes all the way through step by step.
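One way to do that step-by-step check is to fill a tensor with its own indices, push it through the same flatten and reshape, and see whether each element ends up at the position its tag names. The sketch below is a NumPy illustration with hypothetical sizes and a hypothetical helper `axes_match`. It assumes the flattening order [num_prime_capsules, H, W] described in the question, so the same check would need to be applied to the actual tensors in the repository to settle the matter.

```python
import numpy as np

dim, num_prime_capsules = 3, 4  # hypothetical sizes

# Tag every element with its own (caps, row, col) index.
c, h, w = np.meshgrid(
    np.arange(num_prime_capsules), np.arange(dim), np.arange(dim), indexing="ij"
)
tags = np.stack([c, h, w], axis=-1)  # [caps, H, W, 3]

# Assumption: the model flattens in [caps, H, W] order.
flat = tags.reshape(-1, 3)

candidate_a = flat.reshape(dim, dim, num_prime_capsules, 3)  # [H, W, caps]
candidate_b = flat.reshape(num_prime_capsules, dim, dim, 3)  # [caps, H, W]


def axes_match(candidate, axis_names):
    """True if every element's tag equals its position under axis_names."""
    grids = np.meshgrid(*[np.arange(n) for n in candidate.shape[:3]], indexing="ij")
    position = np.stack(grids, axis=-1)
    # Reorder the (caps, row, col) tag to follow the claimed axis order.
    order = [("caps", "row", "col").index(name) for name in axis_names]
    return bool(np.array_equal(candidate[..., order], position))


print(axes_match(candidate_a, ("row", "col", "caps")))  # False
print(axes_match(candidate_b, ("caps", "row", "col")))  # True
```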