charlesq34 / pointnet

PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation

Orthogonality of the 'transform_net1' transform matrix is not included in the loss calculation #273

Status: Open. Opened by hyunjinku 3 years ago.


Dear authors and all,

Hi, I want to double-check whether the orthogonality of the transform matrix from 'transform_net1' (the 3x3 T-Net matrix for the input transform) is included in your loss calculation, since it does not appear to be.

In pointnet/models/pointnet_cls.py, only the feature transform matrix (KxK) is added to end_points to enforce its orthogonality; the input transform matrix is not included.
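For reference, the get_loss() excerpt quoted in this issue computes the following term for the 64x64 feature transform $A_b$ of each sample $b$ in a batch of size $B$ (tf.nn.l2_loss(t) returns sum(t**2)/2). Covering the 3x3 input transform $T_b$ from 'transform_net1', as proposed here, could take the form of the second line; the equal weighting of the two terms is only an illustrative assumption, not something the repository specifies.

```latex
\begin{aligned}
L_{\text{mat}} &= \frac{1}{2} \sum_{b=1}^{B} \left\lVert A_b A_b^{\top} - I_K \right\rVert_F^2,
  && A_b \in \mathbb{R}^{K \times K},\; K = 64 \\
L_{\text{orth}} &= \frac{1}{2} \sum_{b=1}^{B} \left( \left\lVert T_b T_b^{\top} - I_3 \right\rVert_F^2
  + \left\lVert A_b A_b^{\top} - I_K \right\rVert_F^2 \right),
  && T_b \in \mathbb{R}^{3 \times 3}
\end{aligned}
```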

Lines 22-40, function get_model():

end_points = {}

with tf.variable_scope('transform_net1') as sc:
    # Shouldn't this 3x3 transform matrix also be added to end_points so that its orthogonality loss is computed?
    transform = input_transform_net(point_cloud, is_training, bn_decay, K=3)
point_cloud_transformed = tf.matmul(point_cloud, transform)
input_image = tf.expand_dims(point_cloud_transformed, -1)

net = tf_util.conv2d(input_image, 64, [1,3],
                     padding='VALID', stride=[1,1],
                     bn=True, is_training=is_training,
                     scope='conv1', bn_decay=bn_decay)
net = tf_util.conv2d(net, 64, [1,1],
                     padding='VALID', stride=[1,1],
                     bn=True, is_training=is_training,
                     scope='conv2', bn_decay=bn_decay)

with tf.variable_scope('transform_net2') as sc:
    transform = feature_transform_net(net, is_training, bn_decay, K=64)
end_points['transform'] = transform

Lines 82-88, function get_loss():

# Enforce the transformation as orthogonal matrix
transform = end_points['transform'] # BxKxK
K = transform.get_shape()[1].value
mat_diff = tf.matmul(transform, tf.transpose(transform, perm=[0,2,1]))  # A A^T for each sample
mat_diff -= tf.constant(np.eye(K), dtype=tf.float32)  # A A^T - I
mat_diff_loss = tf.nn.l2_loss(mat_diff)  # sum(mat_diff**2) / 2 over the whole batch
tf.summary.scalar('mat loss', mat_diff_loss)
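As a minimal sketch, here is a NumPy mirror of the mat_diff_loss computation above. NumPy is used instead of TensorFlow so the snippet runs on its own. The second function shows one way the proposal could look, assuming get_model() also stored the 3x3 input transform under a hypothetical end_points key 'input_transform'. That key name is made up for illustration and does not exist in the repository.

```python
import numpy as np


def orthogonality_loss(transform: np.ndarray) -> float:
    """NumPy mirror of mat_diff_loss in get_loss().

    transform has shape (B, K, K). tf.nn.l2_loss(t) returns sum(t**2) / 2,
    so the result is 0.5 * sum over the batch of ||A A^T - I||_F^2.
    """
    k = transform.shape[1]
    # Batched A A^T, equivalent to tf.matmul(A, tf.transpose(A, perm=[0, 2, 1]))
    mat_diff = transform @ np.transpose(transform, (0, 2, 1))
    # Subtract the KxK identity (broadcast over the batch dimension)
    mat_diff = mat_diff - np.eye(k)
    return float(np.sum(mat_diff ** 2) / 2)


def total_orthogonality_loss(end_points: dict) -> float:
    """Hypothetical variant that also penalizes the 3x3 input transform.

    'input_transform' is an assumed key that get_model() would have to add;
    the existing code only stores end_points['transform'].
    """
    loss = orthogonality_loss(end_points["transform"])
    if "input_transform" in end_points:
        loss += orthogonality_loss(end_points["input_transform"])
    return loss


if __name__ == "__main__":
    batch = 2
    theta = 0.3
    # A rotation about the z-axis is orthogonal, so its penalty is ~0.
    rot = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                    [np.sin(theta), np.cos(theta), 0.0],
                    [0.0, 0.0, 1.0]])
    input_transform = np.stack([rot] * batch)
    # A perturbed 64x64 identity is generally not orthogonal.
    rng = np.random.default_rng(0)
    feature_transform = np.eye(64) + 0.1 * rng.standard_normal((batch, 64, 64))

    end_points = {"transform": feature_transform,
                  "input_transform": input_transform}
    print("input transform only:", orthogonality_loss(input_transform))
    print("feature transform only:", orthogonality_loss(feature_transform))
    print("both transforms:", total_orthogonality_loss(end_points))
```

A rotation gives a penalty that is zero up to floating-point error, while a perturbed identity does not. The proposal is to encourage this behaviour for the 3x3 input T-Net as well as the 64x64 feature T-Net.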

Thanks in advance for taking the time to look into this!