Closed LinXin04 closed 1 year ago
@jrzaurin I set it up like this:
wd_tab_and_text_model = WideDeep(wide=wide, deeptabular=tab_model, deeptext=text_model, pred_dim=num_class)

from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint, LRHistory

early_stopping = EarlyStopping(
    monitor='val_f1',
    min_delta=0.001,
    patience=2,
    restore_best_weights=True,
    verbose=True,
)

from torch.optim import SGD, lr_scheduler
from pytorch_widedeep.initializers import XavierNormal

tab_and_text_trainer = Trainer(
    wd_tab_and_text_model,
    objective="multiclass",
    callbacks=[early_stopping],
    metrics=[Accuracy, F1Score(average=True)],
)

from pytorch_widedeep.dataloaders import DataLoaderImbalanced, DataLoaderDefault

tab_and_text_trainer.fit(
    X_train=X_train,
    X_val=X_val,
    n_epochs=10,
    batch_size=128,
)
hello @LinXin04, maybe add the ModelCheckpoint callback for saving on the fly
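from pytorch_widedeep import Trainer
from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_widedeep.metrics import Accuracy, Precision, Recall, F1Score

# stop when val_f1 has not improved (mode="max": higher is better)
early_stopping = EarlyStopping(
    monitor="val_f1",
    mode="max",
    min_delta=0.001,
    patience=10,
    restore_best_weights=True,
    verbose=2,
)

# keep only the checkpoint with the best val_f1 seen so far
model_checkpoint = ModelCheckpoint(
    monitor="val_f1",
    mode="max",
    filepath=f"{PROJECT_DIR}/checkpoint/MODEL.pt",
    save_best_only=True,
    max_save=1,
    verbose=2,
)

trainer = Trainer(
    model=wd_model,
    objective="multiclass",
    optimizers=optimizer,
    callbacks=[  # pass both to callbacks
        early_stopping,
        model_checkpoint,
    ],
    metrics=[
        Accuracy,
        Precision,
        Recall,
        F1Score,
    ],
    num_workers=0,
)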
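To illustrate why mode="max" matters when the monitored quantity is an F1 score, here is a minimal, library-independent sketch of "track the best epoch, stop after patience epochs without improvement". This is not pytorch-widedeep's implementation; the function name best_epoch and the val_f1 numbers are made up purely for illustration.

```python
def best_epoch(scores, mode="max", min_delta=0.0, patience=2):
    """Return (best_epoch, last_epoch_run) for a list of per-epoch scores.

    mode="max" treats higher scores as better (e.g. F1, accuracy);
    mode="min" treats lower scores as better (e.g. loss).
    """
    sign = 1.0 if mode == "max" else -1.0
    best_idx, best = 0, scores[0]
    wait = 0
    for epoch, score in enumerate(scores[1:], start=1):
        if sign * (score - best) > min_delta:
            # improvement: remember this epoch and reset the counter
            best_idx, best, wait = epoch, score, 0
        else:
            wait += 1
            if wait >= patience:
                return best_idx, epoch  # stop early
    return best_idx, len(scores) - 1


# illustrative (made-up) validation F1 per epoch
val_f1 = [0.61, 0.70, 0.74, 0.73, 0.72, 0.75]

print(best_epoch(val_f1, mode="max", min_delta=0.001, patience=2))
print(best_epoch(val_f1, mode="min", min_delta=0.001, patience=2))
```

With mode="min" the same scores are read as "getting worse", so the first epoch is kept as the best one and training stops almost immediately. That is why the direction should be stated explicitly for a metric like F1.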
Thanks, @thatalfredh! After solving this problem, another one appeared: the model is overfitting. My setup is as follows:
import numpy as np
from pytorch_widedeep.models import Wide, WideDeep, ContextAttentionMLP, StackedAttentiveRNN

wide = Wide(input_dim=np.unique(wd_X_wide_tr).shape[0], pred_dim=num_class)

tab_model = ContextAttentionMLP(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    cat_embed_activation="relu",
    continuous_cols=continuous_cols,
    cont_embed_activation="relu",
    with_addnorm=True,
    attn_dropout=0.5,
)

text_model = StackedAttentiveRNN(
    vocab_size=len(text_preprocessor.vocab.itos),
    embed_dim=300,
    hidden_dim=128,
    n_blocks=2,
    padding_idx=0,
    rnn_type="gru",
    bidirectional=True,
    attn_concatenate=True,
    with_addnorm=True,
    head_hidden_dims=[128, 64, 32],
    attn_dropout=0.5,
    head_dropout=0.5,
)

wd_tab_and_text_model = WideDeep(wide=wide, deeptabular=tab_model, deeptext=text_model, pred_dim=num_class)
@jrzaurin
Hey @LinXin04
Could you please elaborate a bit more on the "overfitting" problem?
I am not sure whether you mean that there is a problem with the library, or that the metric for the training set keeps increasing while the one for the validation set does not. The latter, in principle, has nothing to do with the library. It is common in many ML problems, and there are some techniques (like using learning rate schedulers) that you could try.
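As a rough illustration of the learning rate scheduler idea, here is a small, library-independent sketch of a "reduce on plateau" rule: when the validation score stops improving for more than patience epochs, the learning rate is multiplied by factor. The function name, the default values and the scores below are assumptions made up for this example; in practice you would use one of the schedulers in torch.optim.lr_scheduler.

```python
def reduce_on_plateau(val_scores, lr=0.01, factor=0.5, patience=1, mode="max"):
    """Return the learning rate used at each epoch.

    The rate is multiplied by `factor` whenever the validation score has
    failed to improve for more than `patience` consecutive epochs.
    """
    sign = 1.0 if mode == "max" else -1.0
    best = val_scores[0]
    wait = 0
    lrs = [lr]
    for score in val_scores[1:]:
        if sign * (score - best) > 0:
            best, wait = score, 0  # improvement: reset the counter
        else:
            wait += 1
            if wait > patience:
                lr *= factor  # plateau detected: lower the rate
                wait = 0
        lrs.append(lr)
    return lrs


# illustrative (made-up) validation scores
scores = [0.60, 0.65, 0.64, 0.64, 0.66, 0.65, 0.65]
print(reduce_on_plateau(scores))
```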
Also, if you think it is a problem with the library, please include the Trainer if possible. :)
Let me know!
Hello @LinXin04, I suppose you are dealing with highly imbalanced data and are now referring to the widening losses. @jrzaurin has made a nice solution for such a case. More technical information on the sampler used behind the scenes can be found here.
from pytorch_widedeep.dataloaders import DataLoaderImbalanced

# you obtain your training set somewhere here
trainer.fit(
    X_tab=X_tab_train,
    X_text=X_text_train,
    X_img=X_img_train,
    target=y_train,
    n_epochs=20,
    batch_size=32,
    val_split=0.20,
    custom_dataloader=DataLoaderImbalanced,
    oversample_mul=20,
)
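For intuition about what sampling for imbalanced data typically does, here is a minimal standard-library sketch of inverse-frequency weighting: each sample gets a weight of 1 / (size of its class), so every class is drawn with roughly equal probability. This is a generic illustration and an assumption on my side; it is not taken from DataLoaderImbalanced, whose exact behaviour should be checked in the library docs. The helper name and the 90/10 toy labels are made up.

```python
import random
from collections import Counter


def inverse_frequency_weights(labels):
    """One weight per sample: 1 / (number of samples in that sample's class)."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


# toy labels: 90 samples of class 0, 10 samples of class 1
labels = [0] * 90 + [1] * 10
weights = inverse_frequency_weights(labels)

# draw with replacement using those weights (fixed seed for repeatability)
rng = random.Random(0)
resampled = rng.choices(labels, weights=weights, k=1000)

print(Counter(labels))     # original, imbalanced class counts
print(Counter(resampled))  # roughly balanced class counts
```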
Hey @thatalfredh , thanks for the help 😉
@LinXin04 if that is the case, apart from the option that @thatalfredh wrote, you can also use a loss that is designed for imbalanced datasets, like the focal loss.
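To show what the focal loss does, here is a small pure-Python sketch for a single sample, using the usual formula FL = -alpha * (1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class. The function names, the default gamma and alpha, and the logits are illustrative assumptions, not the library's code. In pytorch-widedeep the loss is, as far as I know, selected through the Trainer's objective argument (e.g. "multiclass_focal_loss"), but please check the docs for the exact name and parameters.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def focal_loss(logits, target, gamma=2.0, alpha=1.0):
    """Focal loss for one sample; gamma=0 and alpha=1 gives plain cross-entropy."""
    p_t = softmax(logits)[target]
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t)


easy = [4.0, 0.0, 0.0]  # confident and correct for target 0
hard = [0.0, 4.0, 0.0]  # confident and wrong for target 0

print(focal_loss(easy, 0, gamma=0.0), focal_loss(easy, 0))  # cross-entropy vs focal
print(focal_loss(hard, 0, gamma=0.0), focal_loss(hard, 0))
```

The (1 - p_t)^gamma factor shrinks the loss of well-classified samples much more than that of misclassified ones, so the hard (often minority-class) samples dominate the gradient.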
Thanks very much. Another issue is that when I combine tabular data and text data, the contribution of the text data seems to be weakened.
For example, the classification results from this multimodal structure differ greatly from those obtained by training BERT on the text data alone.
Why is this?
Hey @LinXin04
Well, normally, if you have rich tabular data, the text won't play that much of a relevant role unless it adds significant information beyond what is contained in the tabular data. This is because extracting the information from tabular data is normally easier (as the information is logically divided into columns).
Let's go through an example. Say you are building a RecSys in fashion and you are using the text in the item descriptions. In addition, you are using tabular data with columns such as price, category, brand, texture, style, etc. Then it is likely that most of the information in the text is already present in the tabular data and logically divided into those columns. For a dataset like this, adding text to the existing tabular data might help marginally, but surely not significantly (if the text consists of reviews, the situation might change). In this particular case it might be more beneficial to add images, as it is likely that those bring information that is not directly present in the tabular data.
Finally, one has to bear in mind that, when using multimodal models, the relevance or the requirement from each component is shared with the others, and therefore it becomes sometimes harder to understand which component is the most important (without performing a proper ablation study).
For example, think of team sports as opposed to individual sports. In individual sports it is easier to isolate the strengths and weaknesses of the athlete and work towards improving performance. However, in team sports, one athlete might not be very good, but his/her contribution is not that important and is diluted among the contributions of all the other teammates. Something similar can happen here when you start adding model components.
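A minimal sketch of how such an ablation study could be organised: enumerate every non-empty combination of model components and train/evaluate one model per combination. The component names follow the WideDeep arguments used earlier in this thread; run_ablation and the train_and_evaluate callback are hypothetical placeholders that you would implement yourself (build the model with only those components, fit it, return the validation metric).

```python
from itertools import combinations


def ablation_configs(components):
    """All non-empty subsets of the given model components."""
    configs = []
    for k in range(1, len(components) + 1):
        configs.extend(combinations(components, k))
    return configs


def run_ablation(components, train_and_evaluate):
    """Call the user-supplied train_and_evaluate(config) once per subset."""
    return {cfg: train_and_evaluate(cfg) for cfg in ablation_configs(components)}


components = ["wide", "deeptabular", "deeptext"]
for cfg in ablation_configs(components):
    print(cfg)
```

Comparing, say, ("deeptabular",) against ("deeptabular", "deeptext") on the same validation split shows how much the text component actually adds.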
I hope this helps!
got it! Thanks very much!
How do I save the best epoch? The criterion for "best" should be the F1 score, not the loss.