statmasterY closed this issue 1 year ago.
Did you have a look at this issue https://github.com/dreamquark-ai/tabnet/issues/382 ?
I have read it, and I solved the coding problem by using a 'for' loop like the code in #382. I just want to know whether TabNet supports GridSearchCV, which I think would be more convenient.
Another question: I am new to TabNet. It did not work when I added the parameter 'loss_fn=ccc_loss'; every loss value is 0, and I could not figure out how to fix it.
This is my updated code for ccc_loss:
import numpy as np
import torch

def ccc_loss(y_true, y_pred):
    # Concordance correlation coefficient (CCC), computed in NumPy
    if isinstance(y_true, np.ndarray):
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)
    elif torch.is_tensor(y_true):
        # .cpu() is required before .numpy() when the tensors live on a GPU
        y_true = y_true.detach().cpu().numpy().reshape(-1, 1)
        y_pred = y_pred.detach().cpu().numpy().reshape(-1, 1)
    y_true_mean = np.mean(y_true)
    y_pred_mean = np.mean(y_pred)
    y_true_std = np.std(y_true)
    y_pred_std = np.std(y_pred)
    cov = np.mean((y_true - y_true_mean) * (y_pred - y_pred_mean))
    rho = cov / (y_true_std * y_pred_std)
    ccc = 2 * rho * y_true_std * y_pred_std / (
        y_true_std ** 2 + y_pred_std ** 2 + (y_true_mean - y_pred_mean) ** 2
    )
    return ccc
I would appreciate it if you could give me some suggestions.
I think you are confusing the metric function with the loss function. Metrics are scores you want to monitor during evaluation; the loss function is the function you want to minimize during training.
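On the metric side, CCC can be kept in NumPy, because metrics are only evaluated and never differentiated. Below is a minimal sketch of a CCC evaluation metric, assuming the `pytorch_tabnet.metrics.Metric` base class used in the customizing notebook (a subclass sets `_name` and `_maximize` and implements `__call__(y_true, y_score)`). The class name `CCC` is hypothetical, and the `try/except` fallback exists only so the snippet runs where pytorch_tabnet is not installed.

```python
import numpy as np

try:
    # Base class for custom evaluation metrics in pytorch_tabnet
    from pytorch_tabnet.metrics import Metric
except ImportError:  # fallback so this sketch runs without pytorch_tabnet
    Metric = object


class CCC(Metric):
    """Concordance correlation coefficient, used for monitoring / early stopping only."""

    def __init__(self):
        self._name = "ccc"      # name displayed in the training logs
        self._maximize = True   # a higher CCC is better

    def __call__(self, y_true, y_score):
        # Flatten (n, 1) regression outputs to 1-D float arrays
        y_true = np.asarray(y_true, dtype=float).reshape(-1)
        y_score = np.asarray(y_score, dtype=float).reshape(-1)
        true_mean, score_mean = y_true.mean(), y_score.mean()
        cov = np.mean((y_true - true_mean) * (y_score - score_mean))
        # np.var defaults to the population variance, like np.std in the code above
        denom = y_true.var() + y_score.var() + (true_mean - score_mean) ** 2
        return float(2 * cov / denom)
```

It would be passed as a class, e.g. `eval_metric=[CCC]` together with an `eval_set`; as I understand it, the last metric listed is the one that drives early stopping.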
The most important thing about the loss function is that it must be differentiable and written in torch, so that autograd can compute the gradients. Here you are detaching the predictions from the graph, so the model can't compute gradients from the final error. You should have two separate functions, one for the metric and one for the loss. I would advise you to have a look at this notebook, which explains how to customize the different functions: https://github.com/dreamquark-ai/tabnet/blob/develop/customizing_example.ipynb
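To make the loss side concrete, here is a minimal sketch of a differentiable CCC loss written entirely in torch, so nothing is detached and autograd can backpropagate through it. Assumptions: `ccc_loss_torch` is a hypothetical name, pytorch_tabnet calls `loss_fn(y_pred, y_true)` with predictions first (as in the notebook's custom loss), and the `eps` term is my own addition to avoid division by zero. CCC is symmetric in its two arguments, so the order does not change the value. The function returns `1 - CCC` because training minimizes the loss, whereas the function above returns CCC itself, which should be maximized.

```python
import torch


def ccc_loss_torch(y_pred: torch.Tensor, y_true: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Return 1 - CCC as a scalar tensor that stays attached to the autograd graph."""
    y_pred = y_pred.reshape(-1)
    y_true = y_true.reshape(-1).to(y_pred.dtype)
    pred_mean = y_pred.mean()
    true_mean = y_true.mean()
    # Population (biased) variances and covariance, matching the NumPy version
    pred_var = ((y_pred - pred_mean) ** 2).mean()
    true_var = ((y_true - true_mean) ** 2).mean()
    cov = ((y_pred - pred_mean) * (y_true - true_mean)).mean()
    ccc = 2 * cov / (pred_var + true_var + (pred_mean - true_mean) ** 2 + eps)
    return 1.0 - ccc
```

Usage would look like `clf.fit(X_train, y_train, eval_set=[(X_valid, y_valid)], loss_fn=ccc_loss_torch)`, with `y_train` reshaped to `(n_samples, 1)` for TabNetRegressor. Note that the loss is computed per training batch, so it only approximates the CCC over the whole dataset.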
Thanks a lot!
Have you solved your problem now, @statmasterY?
I used the loss function and the scorer I defined, but it still does not work.
This is my loss function:
And this is my training code:
The output is:
It tells me that patience=10 does not work, and I could not find any good example online that combines GridSearchCV with TabNet, so I am raising the issue here.
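Since the output is not shown, this is only a guess: `patience` (like `eval_set` and `max_epochs`) is an argument of TabNet's `fit`, not of the `TabNetRegressor` constructor, so GridSearchCV cannot tune it through `param_grid`. GridSearchCV.fit(X, y, **fit_params) can forward fit-time keywords, but the same `eval_set` would then be reused for every fold. My understanding is also that early stopping is only active when an `eval_set` is given. A plain loop like the one from #382 sidesteps both issues. Below is a minimal, library-agnostic sketch of that loop; `manual_grid_search` and `fit_and_score` are hypothetical names. In real use, `fit_and_score` would build `TabNetRegressor(**params)`, call `fit(X_train, y_train, eval_set=[(X_valid, y_valid)], patience=10, ...)`, and return a validation score.

```python
from itertools import product
from typing import Any, Callable, Dict, List, Tuple


def manual_grid_search(
    param_grid: Dict[str, List[Any]],
    fit_and_score: Callable[[Dict[str, Any]], float],
    maximize: bool = True,
) -> Tuple[Dict[str, Any], float, List[Tuple[Dict[str, Any], float]]]:
    """Evaluate every combination in param_grid; return best params, best score, all results."""
    keys = list(param_grid)
    results: List[Tuple[Dict[str, Any], float]] = []
    # Cartesian product over the candidate values, like sklearn's ParameterGrid
    for values in product(*(param_grid[k] for k in keys)):
        params = dict(zip(keys, values))
        # Hypothetical callback: train one model with these constructor params
        # (fit-time args such as patience stay inside the callback) and score it
        score = fit_and_score(params)
        results.append((params, score))
    pick = max if maximize else min
    best_params, best_score = pick(results, key=lambda r: r[1])
    return best_params, best_score, results
```

Keeping `patience` and `eval_set` inside the callback means each combination is early-stopped on the same held-out set, while only constructor parameters (such as `n_d`, `n_a`, `n_steps`) are varied in the grid.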