kimhc6028 / relational-networks

PyTorch implementation of "A simple neural network module for relational reasoning" (Relational Networks)
https://arxiv.org/pdf/1706.01427.pdf
BSD 3-Clause "New" or "Revised" License
812 stars, 160 forks

Object coordinates missing #4

Closed: thomashenn closed this issue 7 years ago

thomashenn commented 7 years ago

From the article in the "Dealing with pixels" case:

So, after convolving the image, each of the d^2 k-dimensional cells in the d × d feature maps was tagged with an arbitrary coordinate indicating its relative spatial position, and was treated as an object for the RN.

Also, the author (/u/asantoro) confirmed on Reddit that objects were of the form [x, y, v_1, v_2, ..., v_k], where k is the number of kernels, and that the range of the coordinates x, y doesn't matter. (Reddit link)
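A minimal PyTorch sketch of that object construction, assuming a square (B, k, d, d) conv output. The function name `feature_map_to_objects` is hypothetical (not from model.py), and scaling the coordinates to [-1, 1] is an arbitrary choice, since the range reportedly doesn't matter. Cell (r, c) becomes object r*d + c, i.e. row-major order.

```python
import torch


def feature_map_to_objects(x: torch.Tensor) -> torch.Tensor:
    """Turn a (B, k, d, d) conv feature map into (B, d*d, k+2) objects.

    Each object is [x, y, v_1, ..., v_k]. Coordinates are scaled to [-1, 1],
    an arbitrary choice (hypothetical helper, not the repository's code).
    """
    b, k, d, d2 = x.shape
    assert d == d2, "expects a square d x d feature map"
    coords = torch.linspace(-1.0, 1.0, d, dtype=x.dtype, device=x.device)
    # With indexing="ij": ys[r, c] = coords[r] (row), xs[r, c] = coords[c] (column).
    ys, xs = torch.meshgrid(coords, coords, indexing="ij")
    xy = torch.stack((xs, ys), dim=0)            # (2, d, d)
    xy = xy.unsqueeze(0).expand(b, -1, -1, -1)   # (B, 2, d, d), shared by the batch
    objs = torch.cat((xy, x), dim=1)             # (B, k+2, d, d): coords first
    # Flatten the grid so cell (r, c) becomes object r*d + c.
    return objs.flatten(2).transpose(1, 2)       # (B, d*d, k+2)
```

For example, a (2, 24, 5, 5) feature map yields objects of shape (2, 25, 26): 25 cells, each with 2 coordinates followed by 24 feature values.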

So I think the model should concatenate object coordinates to oi and oj here: https://github.com/kimhc6028/relational-networks/blob/master/model.py#L53

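# excerpt from the RN forward pass; x is the (B, k, 5, 5) conv output
for i in range(25):
    oi = x[:,:,i//5,i%5]  # cell (i // 5, i % 5): k features, no coordinates
    for j in range(25):
        oj = x[:,:,j//5,j%5]
        x_ = torch.cat((oi,oj,qst), 1)  # [o_i, o_j, question] fed to g
        x_ = self.g_fc1(x_)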

I believe this should improve performance on questions where the spatial relationship between objects is important (closest, furthest, ...).
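One way to sketch the pairing step with coordinates included, assuming each object already carries its [x, y] coordinates (shape (B, n, k+2)) and the question embedding has shape (B, q). `build_pairs` is a hypothetical helper, not the repository's code: instead of the nested Python loop it builds every ordered pair at once with broadcasting, and pair (i, j) lands at row i*n + j, the same order the loop visits them.

```python
import torch


def build_pairs(objs: torch.Tensor, qst: torch.Tensor) -> torch.Tensor:
    """Return (B, n*n, 2c+q) inputs for g: every ordered pair [o_i, o_j, q].

    objs: (B, n, c) objects that are assumed to already include coordinates.
    qst:  (B, q) question embedding, appended to every pair.
    """
    b, n, c = objs.shape
    q_dim = qst.shape[1]
    oi = objs.unsqueeze(2).expand(b, n, n, c)    # oi[:, i, j] = objs[:, i]
    oj = objs.unsqueeze(1).expand(b, n, n, c)    # oj[:, i, j] = objs[:, j]
    q = qst[:, None, None, :].expand(b, n, n, q_dim)
    pairs = torch.cat((oi, oj, q), dim=3)        # (B, n, n, 2c+q)
    # Row-major flatten: pair (i, j) -> row i*n + j, matching the loop order.
    return pairs.reshape(b, n * n, 2 * c + q_dim)
```

Because each object grows from k to k+2 values, the first layer of g has to accept 2(k+2) + q inputs instead of 2k + q.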

kimhc6028 commented 7 years ago

Oops, thank you for pointing out my mistake. I am working on it as a top priority. Please wait a bit...

kimhc6028 commented 7 years ago

Adding object coordinates seems to improve accuracy. The code is now fixed. Thank you!

thomashenn commented 7 years ago

Glad to hear that! Thank you :+1: