Hello,

I think the current implementation of CoxLoss is not accurate. Below I discuss the two issues I identified and their solutions.
The first issue is in this line:

loss_cox = -torch.mean((theta - torch.log(torch.sum(exp_theta * R_mat, dim=1))) * censor)

The expression (theta - torch.log(torch.sum(exp_theta * R_mat, dim=1))) returns a 1d tensor. If censor is a 2d tensor (which is the case most of the time, as it is never squeezed inside CoxLoss), then PyTorch has to broadcast, and the multiplication returns a 2d matrix.
Toy example:
import numpy as np
import torch

theta = torch.tensor([-0.282, -0.1411, -0.1039, -0.0255])
exp_theta = torch.exp(theta)
R_mat = np.array([[1, 0, 1, 1], [1, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1]])
R_mat = torch.FloatTensor(R_mat)
censor = torch.tensor([[0, 1, 0, 0]]).T
print(R_mat)
tensor([[1., 0., 1., 1.],
        [1., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.]])
print(censor.size())
torch.Size([4, 1])
print((theta - torch.log(torch.sum(exp_theta*R_mat, dim=1))))
tensor([-1.2491e+00, -1.3935e+00, -7.3312e-01, -2.2352e-08])
# cool, a 1d vector, as expected. However, since censor here is a 2d tensor, it broadcasts:
print((theta - torch.log(torch.sum(exp_theta*R_mat, dim=1))) * censor)
tensor([[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
        [-1.2491e+00, -1.3935e+00, -7.3312e-01, -2.2352e-08],
        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]])
# taking torch.mean over this 4x4 matrix averages 16 entries instead of 4,
# so the loss calculation is inaccurate
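To make the inaccuracy concrete, here is a quick comparison (my own addition, using the toy tensors above) of the broadcast loss, the batch-mean loss with a flattened censor, and the mean over observed events only:

vec = theta - torch.log(torch.sum(exp_theta * R_mat, dim=1))
print(-torch.mean(vec * censor))                         # ~0.2110: broadcast, mean over 16 entries
print(-torch.mean(vec * censor.view(-1)))                # ~0.3484: 1d censor, but still divides by batch size 4
print(-torch.sum(vec * censor.view(-1)) / censor.sum())  # ~1.3935: mean over observed events only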
The fix is to ensure censor is a 1d tensor, as I propose here.
The second issue: here you calculate the loss by averaging over the whole batch. I believe one should average over the subjects who experienced the event only; censored subjects contribute zero terms to the sum, so dividing by the batch size scales the loss down by the fraction of observed events. A sketch combining both fixes is below.
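Putting the two fixes together, a minimal sketch of how the loss computation could look (the function name and the eps guard are my own additions; I assume theta, R_mat, and censor are built exactly as in the current implementation):

def cox_loss_fixed(theta, R_mat, censor, eps=1e-8):
    # theta:  (batch,) predicted log-risk scores
    # R_mat:  (batch, batch) risk-set indicator; R_mat[i, j] = 1 iff subject j
    #         is still at risk at subject i's event time
    # censor: event indicator (1 = event observed, 0 = censored)
    censor = censor.view(-1).float()  # fix 1: make censor 1d before multiplying
    exp_theta = torch.exp(theta)
    log_risk = torch.log(torch.sum(exp_theta * R_mat, dim=1))
    partial_ll = (theta - log_risk) * censor
    # fix 2: average over observed events only; eps guards against a batch
    # with no observed events
    return -torch.sum(partial_ll) / (torch.sum(censor) + eps)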
Thanks