havakv / pycox

Survival analysis with PyTorch
BSD 2-Clause "Simplified" License

TypeError: __init__() takes 1 positional argument but 4 were given #119

Closed: yangxulin closed this issue 2 years ago

yangxulin commented 2 years ago

Hello, havakv. First of all, thank you very much for developing such a useful package. However, when I run the Cox-Time code you provided, I get the error shown in the title. It seems the problem is in the line that calls `CoxTime.label_transform()`. My code is as follows (the same as the code you provided):

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper
import torch
import torchtuples as tt
from pycox.datasets import metabric
from pycox.models import CoxTime
from pycox.models.cox_time import MLPVanillaCoxTime
from pycox.evaluation import EvalSurv

np.random.seed(1234)
_ = torch.manual_seed(123)

# Split METABRIC into train / validation / test sets
df_train = metabric.read_df()
df_test = df_train.sample(frac=0.2)
df_train = df_train.drop(df_test.index)
df_val = df_train.sample(frac=0.2)
df_train = df_train.drop(df_val.index)

# Standardize some covariates and pass the remaining columns through unchanged
cols_standardize = ['x0', 'x1', 'x2', 'x3', 'x8']
cols_leave = ['x4', 'x5', 'x6', 'x7']
standardize = [([col], StandardScaler()) for col in cols_standardize]
leave = [(col, None) for col in cols_leave]
x_mapper = DataFrameMapper(standardize + leave)
x_train = x_mapper.fit_transform(df_train).astype('float32')
x_val = x_mapper.transform(df_val).astype('float32')
x_test = x_mapper.transform(df_test).astype('float32')

# Label transform for Cox-Time: the TypeError is raised at this call
labtrans = CoxTime.label_transform()
get_target = lambda df: (df['duration'].values, df['event'].values)
y_train = labtrans.fit_transform(*get_target(df_train))
y_val = labtrans.transform(*get_target(df_val))
durations_test, events_test = get_target(df_test)
val = tt.tuplefy(x_val, y_val)
```

I would appreciate it if you could help me.

havakv commented 2 years ago

I think this has actually been fixed already (https://github.com/havakv/pycox/blob/master/pycox/preprocessing/label_transforms.py#L28), but the fix hasn't been pushed to PyPI yet. I'll make a new release for it.
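The thread doesn't say what exactly was wrong at that line. One plausible explanation, and it is only an assumption here, is that the old label transform passed arguments positionally to scikit-learn's `StandardScaler`, whose parameters recent scikit-learn releases accept only as keywords (`StandardScaler(*, copy=True, with_mean=True, with_std=True)`). The sketch below reproduces the same kind of TypeError with a hypothetical class `KeywordOnlyScaler`; it does not use pycox or scikit-learn.

```python
# Hypothetical stand-in for a scaler whose constructor takes keyword-only
# arguments: the bare `*` forbids passing them positionally.
class KeywordOnlyScaler:
    def __init__(self, *, copy=True, with_mean=True, with_std=True):
        self.copy = copy
        self.with_mean = with_mean
        self.with_std = with_std


def build_positional():
    # Three values passed positionally; together with `self` that is
    # 4 positional arguments where only 1 (`self`) is accepted.
    return KeywordOnlyScaler(True, True, False)


def build_keyword():
    # The same settings passed by keyword are accepted.
    return KeywordOnlyScaler(copy=True, with_mean=True, with_std=False)


try:
    build_positional()
except TypeError as err:
    # Message contains "takes 1 positional argument but 4 were given"
    print(err)

scaler = build_keyword()
print(scaler.with_std)  # False
```

Passing the arguments by keyword is the usual way to keep such a call working across library versions that tighten their signatures.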

havakv commented 2 years ago

This should be fixed now, as long as you're using version 0.2.3.
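A quick way to confirm which pycox version is installed is sketched below, assuming Python 3.8+ for `importlib.metadata`. The helper names `parse_version`, `is_at_least` and `pycox_has_fix` are hypothetical, and the comparison only handles plain numeric versions such as `0.2.3`. If the check fails, `pip install --upgrade pycox` upgrades to the latest release on PyPI.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_version(text):
    # Simplified: handles plain numeric versions like "0.2.3" only
    # (no pre-release or local tags such as "0.3.0rc1").
    return tuple(int(part) for part in text.split("."))


def is_at_least(installed, required):
    # Tuple comparison is element-wise, so (0, 2, 10) >= (0, 2, 3).
    return parse_version(installed) >= parse_version(required)


def pycox_has_fix(required="0.2.3"):
    # Returns None if pycox is not installed in this environment.
    try:
        installed = version("pycox")
    except PackageNotFoundError:
        return None
    return is_at_least(installed, required)


print(is_at_least("0.2.2", "0.2.3"))  # False
print(is_at_least("0.2.3", "0.2.3"))  # True
print(pycox_has_fix())
```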

yangxulin commented 2 years ago


Thank you very much for your reply. The problem was solved when I upgraded pycox to 0.2.3. Thanks again for your patient reply.