Hi, while running main.py for ResNet-50 with 365 output classes, I ran into the following error:
torch.Size([1, 2048]) torch.Size([365, 2048]) torch.Size([1, 1000])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-3-02abbf51b5f9> in <module>
7 # Res = model.relprop(R = output * T, alpha= 1).sum(dim=1, keepdim=True)
8 #else:
----> 9 RAP = model.RAP_relprop(R=T)
10 Res = (RAP).sum(dim=1, keepdim=True)
11 # Check relevance value preserved
/home/SharedData3/ushasi/tub/gan/modules/resnet.py in RAP_relprop(self, R)
324 return R
325 def RAP_relprop(self, R):
--> 326 R = self.fc.RAP_relprop(R)
327 R = R.reshape_as(self.avgpool.Y)
328 R = self.avgpool.RAP_relprop(R)
/home/SharedData3/ushasi/tub/gan/modules/layers.py in RAP_relprop(self, R_p)
372 pd = R_p
373
--> 374 Rp_tmp = first_prop(pd, px, nx, pw, nw)
375 A = redistribute(Rp_tmp)
376
/home/SharedData3/ushasi/tub/gan/modules/layers.py in first_prop(pd, px, nx, pw, nw)
317 #print(px,pw)
**318 print(px.shape,pw.shape,pd.shape)**
--> 319 Rpp = F.linear(px, pw) * pd
320 Rpn = F.linear(px, nw) * pd
321 Rnp = F.linear(nx, pw) * pd
RuntimeError: The size of tensor a (365) must match the size of tensor b (1000) at non-singleton dimension 1
The top three shapes are the result of the print statement I added (in bold).
If I change T = (T[:, np.newaxis] == np.arange(1000)) * 1.0 to T = (T[:, np.newaxis] == np.arange(365)) * 1.0 in the compute_pred function, the error goes away.
I just want to confirm that this is indeed the correct fix, and that I am not silencing the error in a way that produces wrong output.
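For what it's worth, the change matches the error: the one-hot target must have the same number of columns as the classifier's output (365 here, not the ImageNet default of 1000). A minimal sketch of the idea, where the helper name `make_one_hot` is my own and the class count would ideally come from the model's fc layer (e.g. `model.fc.out_features`) rather than being hard-coded:

```python
import numpy as np

def make_one_hot(T, num_classes):
    # Convert integer class labels of shape [batch] into a one-hot
    # float matrix of shape [batch, num_classes], matching the
    # pattern used in compute_pred.
    return (T[:, np.newaxis] == np.arange(num_classes)) * 1.0

labels = np.array([2, 0])
onehot = make_one_hot(labels, 365)
# onehot has shape (2, 365); each row contains a single 1.0 at the label index
```

With `num_classes=365`, the resulting T has shape [batch, 365] and broadcasts cleanly against the [batch, 365] logits inside RAP_relprop, which is exactly what the size-mismatch error was complaining about.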