havakv / pycox

Survival analysis with PyTorch
BSD 2-Clause "Simplified" License

Output the same for every input (survival curves) #140

Open tomas-silveira opened 2 years ago

tomas-silveira commented 2 years ago

Hello, I've been trying to use CoxTime from your package for churn analysis, but I'm getting the same output (the same survival curve) for every instance. I don't know what's causing this. My guess is that the network's weights are all the same, but I couldn't find a way to check this.

Just for context, my dataset has 5 features and around 200k instances. For my network, I'm using 8 hidden layers with 256 nodes each, batch normalization, a dropout of 0.8, a learning rate of 0.001, and 200 epochs (with early stopping on the validation set). Have you had this problem before? If so, is there something I should check to make sure the network is being trained properly?
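One way to confirm the symptom numerically, rather than by eye, is to measure how much the predicted survival curves vary across individuals. The sketch below assumes the curves have been pulled into a NumPy array of shape (n_times, n_individuals), for example from the `.values` of the DataFrame returned by pycox's `predict_surv_df`. The function name `curve_spread` and the default tolerance are hypothetical choices. If the spread is essentially zero, the model is producing the same output for every input.

```python
import numpy as np


def curve_spread(surv, tol=1e-6):
    """Summarize how much survival curves differ across individuals.

    surv: array of shape (n_times, n_individuals), one column per individual
          (assumed layout, e.g. the .values of a pycox survival DataFrame).
    Returns (max_gap, identical): the largest gap between any two curves at
    any time point, and whether all curves agree to within `tol`.
    """
    surv = np.asarray(surv, dtype=float)
    # Range across individuals (columns) at each time point (row).
    per_time_range = surv.max(axis=1) - surv.min(axis=1)
    max_gap = float(per_time_range.max())
    return max_gap, max_gap <= tol


# Toy example: four time points, three individuals.
times = np.array([1.0, 2.0, 3.0, 4.0])
collapsed = np.tile(np.exp(-0.1 * times)[:, None], (1, 3))  # same curve three times
healthy = np.exp(-np.outer(times, [0.05, 0.1, 0.2]))         # different hazards

print(curve_spread(collapsed))  # (0.0, True)
print(curve_spread(healthy))
```

The same idea can be applied to the network itself: in PyTorch, iterating over `net.named_parameters()` and printing each tensor's standard deviation shows whether any layer's weights have become (nearly) constant.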

Thank you in advance!

tomas-silveira commented 2 years ago

Also, could you help me with another problem? One of the main metrics I'm using is the lift score, so I need to rank clients by their probability/risk of churning within different time windows (1 month, 2 months, etc.). For other models, such as the Cox model, I've been ranking clients using H(t), but I'm not sure what to use with CoxTime. I've thought about computing S(t=x) for all clients and ranking them for month x by lowest survival probability, but since all the survival curves are the same, this method isn't reliable. Do you have any suggestions?
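The ranking idea described above can be sketched as follows. It assumes survival curves are stored as a (n_times, n_individuals) array on a sorted time grid and read as right-continuous step functions. If S(t) = exp(-H(t)), then H(x) = -log S(x) is strictly decreasing in S(x), so ranking by lowest S(x) gives the same order as ranking by highest H(x), which matches the ranking already used for the Cox model. The helper names `survival_at`, `rank_by_risk`, and `lift_at` are hypothetical, and lift is taken here as the churn rate among the top fraction of ranked clients divided by the overall churn rate.

```python
import numpy as np


def survival_at(surv, times, horizon):
    """S(horizon) per individual, reading each step function at the last grid
    time <= horizon (S = 1 before the first grid time)."""
    surv = np.asarray(surv, dtype=float)
    idx = np.searchsorted(times, horizon, side="right") - 1
    if idx < 0:
        return np.ones(surv.shape[1])
    return surv[idx]


def rank_by_risk(surv, times, horizon):
    """Individuals ordered from highest to lowest risk of churning by `horizon`,
    i.e. ascending S(horizon); same order as descending H(horizon) when
    S = exp(-H)."""
    s = survival_at(surv, times, horizon)
    return np.argsort(s, kind="stable")


def lift_at(order, churned, top_frac):
    """Churn rate in the top `top_frac` of the ranking divided by the overall
    churn rate. `churned` is a 0/1 array: churned by the horizon or not."""
    churned = np.asarray(churned, dtype=float)
    overall = churned.mean()
    if overall == 0:
        raise ValueError("no churners in the data; lift is undefined")
    k = max(1, int(round(top_frac * len(order))))
    return churned[order[:k]].mean() / overall


# Toy example: grid in days, 3 time points x 4 clients (made-up numbers).
times = np.array([30.0, 60.0, 90.0])
surv = np.array([
    [0.95, 0.80, 0.99, 0.70],
    [0.90, 0.60, 0.97, 0.50],
    [0.85, 0.40, 0.95, 0.30],
])
order = rank_by_risk(surv, times, horizon=60.0)
print(order)  # [3 1 0 2]
churned_by_60 = np.array([0, 1, 0, 1])
print(lift_at(order, churned_by_60, top_frac=0.5))  # 2.0
```

With collapsed curves, as in this issue, every S(x) ties and the ranking carries no information, so this only becomes useful once the identical-output problem is fixed.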

havakv commented 2 years ago

I know someone else has struggled with the net producing identical output (#137), but I'm not sure what's causing it.

havakv commented 2 years ago

For your second question, I have to say that I don't have a good answer. There have been some discussions in #41 and #97 that touch on this, so you could have a look there.