jrzaurin / pytorch-widedeep

A flexible package for multimodal deep learning to combine tabular data with text and images using Wide and Deep models in PyTorch
Apache License 2.0
1.3k stars, 191 forks

How to save the best epoch #188

Closed: LinXin04 closed this issue 1 year ago

LinXin04 commented 1 year ago

[image]

How do I save the best epoch? The metric I want to select it by is not the loss but the F1 score.

jrzaurin commented 1 year ago

Hey @LinXin04

Check here, and if you don't find a solution I will post an example later.

Thanks!

LinXin04 commented 1 year ago

@jrzaurin I set it up like this:

from torch.optim import SGD, lr_scheduler
from pytorch_widedeep import Trainer
from pytorch_widedeep.models import WideDeep
from pytorch_widedeep.metrics import Accuracy, F1Score
from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint, LRHistory
from pytorch_widedeep.initializers import XavierNormal
from pytorch_widedeep.dataloaders import DataLoaderImbalanced, DataLoaderDefault

wd_tab_and_text_model = WideDeep(
    wide=wide, deeptabular=tab_model, deeptext=text_model, pred_dim=num_class
)

early_stopping = EarlyStopping(
    monitor="val_f1",
    min_delta=0.001,
    patience=2,
    restore_best_weights=True,
    verbose=True,
)

tab_and_text_trainer = Trainer(
    wd_tab_and_text_model,
    objective="multiclass",
    callbacks=[early_stopping],
    metrics=[Accuracy, F1Score(average=True)],
)

tab_and_text_trainer.fit(X_train=X_train, X_val=X_val, n_epochs=10, batch_size=128)

thatalfredh commented 1 year ago

Hello @LinXin04, maybe add the ModelCheckpoint callback to save the best model on the fly:

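from pytorch_widedeep import Trainer
from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_widedeep.metrics import Accuracy, Precision, Recall, F1Score

# wd_model, optimizer and PROJECT_DIR are assumed to be defined earlier
early_stopping = EarlyStopping(
    monitor="val_f1",
    mode="max",  # a higher F1 is better
    min_delta=0.001,
    patience=10,
    restore_best_weights=True,
    verbose=2,
)
model_checkpoint = ModelCheckpoint(
    monitor="val_f1",
    mode="max",
    # root of the file name only: the epoch number and extension are appended
    filepath=f"{PROJECT_DIR}/checkpoint/MODEL",
    save_best_only=True,
    max_save=1,
    verbose=2,
)
trainer = Trainer(
    model=wd_model,
    objective="multiclass",
    optimizers=optimizer,
    callbacks=[  # pass both to callbacks
        early_stopping,
        model_checkpoint,
    ],
    metrics=[
        Accuracy,
        Precision,
        Recall,
        F1Score,
    ],
    num_workers=0,
)
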
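
Why `mode="max"` matters for an F1 monitor can be shown with a small, library-independent sketch. The hypothetical helper `best_epoch` below mimics the kind of comparison that `monitor`, `mode` and `min_delta` control. It assumes the per-epoch validation F1 values are available as a plain list, and the numbers are made up. It is not pytorch-widedeep's implementation.

```python
# Hypothetical helper, NOT pytorch-widedeep's code: it only illustrates the
# idea behind the monitor / mode / min_delta arguments.
def best_epoch(history, mode="max", min_delta=0.0):
    """Return (epoch, value) of the best epoch (epochs numbered from 1).

    An epoch only counts as an improvement if it beats the current best by
    more than min_delta in the direction given by mode.
    """
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    best_idx, best_value = None, None
    for epoch, value in enumerate(history, start=1):
        if best_value is None:
            improved = True
        elif mode == "max":
            improved = value > best_value + min_delta
        else:
            improved = value < best_value - min_delta
        if improved:
            best_idx, best_value = epoch, value
    return best_idx, best_value


val_f1 = [0.61, 0.68, 0.70, 0.695, 0.69]  # invented per-epoch validation F1
print(best_epoch(val_f1, mode="max", min_delta=0.001))  # (3, 0.7)
print(best_epoch(val_f1, mode="min", min_delta=0.001))  # (1, 0.61)
```

With the wrong direction (`mode="min"`), the "best" epoch would be the one with the lowest F1.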
LinXin04 commented 1 year ago

Thanks @thatalfredh! After solving this problem, another one appeared, overfitting, which shows up as follows: [image]

import numpy as np
from pytorch_widedeep.models import Wide, WideDeep, ContextAttentionMLP, StackedAttentiveRNN

# wd_X_wide_tr, tab_preprocessor, text_preprocessor, continuous_cols and num_class are defined earlier
wide = Wide(input_dim=np.unique(wd_X_wide_tr).shape[0], pred_dim=num_class)

tab_model = ContextAttentionMLP(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    cat_embed_activation="relu",
    continuous_cols=continuous_cols,
    cont_embed_activation="relu",
    with_addnorm=True,
    attn_dropout=0.5,
)

text_model = StackedAttentiveRNN(
    vocab_size=len(text_preprocessor.vocab.itos),
    embed_dim=300,
    hidden_dim=128,
    n_blocks=2,
    padding_idx=0,
    rnn_type="gru",
    bidirectional=True,
    attn_concatenate=True,
    with_addnorm=True,
    head_hidden_dims=[128, 64, 32],
    attn_dropout=0.5,
    head_dropout=0.5,
)

wd_tab_and_text_model = WideDeep(
    wide=wide, deeptabular=tab_model, deeptext=text_model, pred_dim=num_class
)

LinXin04 commented 1 year ago

@jrzaurin

jrzaurin commented 1 year ago

Hey @LinXin04

Could you please elaborate a bit more on the "overfitting" problem?

I am not sure whether you mean that there is a problem with the library, or that the metric on the training set keeps improving while the one on the validation set does not. The latter, in principle, has nothing to do with the library. It is common in many ML problems, and there are some techniques (like using learning rate schedulers) that you could try.
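
As an illustration of what such a scheduler does, here is a simplified, library-independent sketch of a reduce-on-plateau schedule. The learning rate is cut by `factor` once the validation score has not improved for more than `patience` consecutive epochs. This mirrors the core rule of PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau`, but it omits that class's threshold and cooldown options. The function name and the toy scores are hypothetical. In pytorch-widedeep a real PyTorch scheduler would typically be handed to the `Trainer` through its `lr_schedulers` argument.

```python
# Simplified reduce-on-plateau schedule: a sketch of the idea, not PyTorch's code.
def reduce_on_plateau(val_scores, lr=1e-3, factor=0.1, patience=2, mode="max", min_lr=1e-6):
    """Return the learning rate used in each epoch, given per-epoch validation scores."""
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    best = None
    bad_epochs = 0  # consecutive epochs without improvement
    lrs = []
    for score in val_scores:
        lrs.append(lr)  # LR used while training this epoch
        if best is None or (score > best if mode == "max" else score < best):
            best, bad_epochs = score, 0
        else:
            bad_epochs += 1
            if bad_epochs > patience:  # plateau detected: shrink the LR for the next epoch
                lr = max(lr * factor, min_lr)
                bad_epochs = 0
    return lrs


val_f1 = [0.60, 0.65, 0.64, 0.64, 0.63, 0.66, 0.65]  # invented validation F1 values
print(reduce_on_plateau(val_f1, lr=1e-3, factor=0.1, patience=2))
```

Here the LR stays at 1e-3 for the first five epochs. It then drops to about 1e-4, after epochs 3, 4 and 5 all fail to beat the 0.65 reached in epoch 2.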

Also, if you think it is a problem with the library, please include the Trainer setup if possible. :)

Let me know!

thatalfredh commented 1 year ago

Hello @LinXin04, I suppose you are dealing with highly imbalanced data and are referring to the widening gap between the training and validation losses. @jrzaurin has implemented a nice solution for such cases. More technical information on the sampler used behind the scenes can be found here.

from pytorch_widedeep.dataloaders import DataLoaderImbalanced

# you obtain your training set somewhere here

trainer.fit(
    X_tab=X_tab_train, 
    X_text=X_text_train, 
    X_img=X_img_train, 
    target=y_train, 
    n_epochs=20, 
    batch_size=32,
    val_split=0.20,
    custom_dataloader=DataLoaderImbalanced,
    oversample_mul=20,
)
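
The idea behind this kind of sampler can be sketched in plain Python. Each sample gets a weight inversely proportional to its class frequency, and indices are drawn with replacement, so minority-class samples appear about as often as majority-class ones. This illustrates weighted random sampling under that assumption; it is not the code of `DataLoaderImbalanced`. How the library actually sets the weights and uses `oversample_mul` to size each epoch is described in its documentation. The helper names below are hypothetical.

```python
import random
from collections import Counter


def inverse_frequency_weights(labels):
    """One weight per sample: 1 / (number of samples in that sample's class)."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


def oversample_indices(labels, num_samples, seed=0):
    """Draw sample indices with replacement, proportionally to their weights,
    so every class is drawn with roughly the same probability."""
    rng = random.Random(seed)
    weights = inverse_frequency_weights(labels)
    return rng.choices(range(len(labels)), weights=weights, k=num_samples)


labels = [0] * 90 + [1] * 10  # toy target with a 9:1 class imbalance
idx = oversample_indices(labels, num_samples=10_000)
print(Counter(labels[i] for i in idx))  # roughly half of the draws per class
```
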
jrzaurin commented 1 year ago

Hey @thatalfredh, thanks for the help 😉

@LinXin04 if that is the case, apart from the option that @thatalfredh wrote, you can also use a loss designed for imbalanced datasets, like the Focal Loss.
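
The focal loss (Lin et al., 2017) scales the cross-entropy of each sample by (1 - p_t)^γ, where p_t is the predicted probability of the true class. Well-classified samples therefore contribute little, and training focuses on hard examples, which in imbalanced data are often from the minority class. The minimal pure-Python sketch below covers the per-sample formula only. It is not pytorch-widedeep's implementation, and it uses a single scalar `alpha`, whereas the original formulation also allows per-class weights. For the library's own option, see the `objective` choices in the `Trainer` documentation.

```python
import math


def focal_loss(probs, target, gamma=2.0, alpha=1.0):
    """Focal loss for a single sample.

    probs: predicted class probabilities (e.g. a softmax output)
    target: index of the true class
    FL = -alpha * (1 - p_t) ** gamma * log(p_t), with p_t = probs[target].
    gamma = 0 and alpha = 1 give the ordinary cross-entropy.
    """
    p_t = probs[target]
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t)


easy = [0.9, 0.05, 0.05]  # confident and correct for class 0
hard = [0.2, 0.5, 0.3]    # mostly wrong for class 0

for name, probs in [("easy", easy), ("hard", hard)]:
    ce = focal_loss(probs, 0, gamma=0.0)
    fl = focal_loss(probs, 0, gamma=2.0)
    print(f"{name}: cross-entropy={ce:.4f}  focal={fl:.4f}  ratio={fl / ce:.2f}")
```

With γ = 2, the well-classified sample (p_t = 0.9) keeps only (1 - 0.9)^2 = 0.01 of its cross-entropy, while the poorly classified one (p_t = 0.2) keeps 0.64 of it.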

LinXin04 commented 1 year ago

Thanks very much. Another issue is that when I combine tabular data and text data, the effect of the text data on the model is weakened.

For example, the results of training a classifier with this multimodal architecture differ greatly from those of training BERT on the text data alone.

Why is this?

jrzaurin commented 1 year ago

Hey @LinXin04

Well, normally, if you have rich tabular data, the text won't play a very relevant role unless it adds significant information beyond what is already contained in the tabular data. This is because extracting information from tabular data is normally easier (the information is logically divided into columns).

Let's go through an example. Say you are building a RecSys in fashion and you are using the text in the item descriptions. In addition, you are using tabular data with columns such as price, category, brand, texture, style, etc. Then it is likely that most of the information in the text is already present in the tabular data, logically divided into those columns. For a dataset like this, adding text to the existing tabular data might help marginally, but surely not significantly (if the texts are reviews, the situation might change). In this particular case it might be more beneficial to add images, as they are likely to bring information that is not directly present in the tabular data.

Finally, bear in mind that in multimodal models the relevance of each component is shared with the others. It therefore becomes harder to tell which component matters most without a proper ablation study.

For example, think of team sports as opposed to individual sports. In individual sports it is easier to isolate the strengths and weaknesses of the athlete and work on improving performance. In team sports, however, one athlete might not be very good, but his/her contribution is diluted among the contributions of all the other teammates. Something similar can happen here when you start adding model components.
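
One simple, if expensive, way to run the ablation study mentioned above is to score every combination of model components. The sketch below assumes a hypothetical `evaluate` function that you would write yourself. It would build a `WideDeep` model with only the given components, train it, and return a validation metric such as F1. The toy evaluator and its numbers are invented purely so the example runs.

```python
from itertools import combinations


def ablation_study(components, evaluate):
    """Score every non-empty subset of model components, best first.

    evaluate: hypothetical user-supplied function that trains/validates a model
    built only from the given components and returns a validation score.
    """
    results = {}
    for r in range(1, len(components) + 1):
        for subset in combinations(components, r):
            results[subset] = evaluate(subset)
    return dict(sorted(results.items(), key=lambda kv: kv[1], reverse=True))


# Toy stand-in for a real train-and-validate run (numbers invented)
toy_gain = {"wide": 0.02, "deeptabular": 0.60, "deeptext": 0.05}


def toy_evaluate(subset):
    return round(sum(toy_gain[c] for c in subset), 2)


results = ablation_study(["wide", "deeptabular", "deeptext"], toy_evaluate)
for subset, score in results.items():
    print(subset, score)
```

Comparing, for instance, `("deeptabular",)` with `("deeptabular", "deeptext")` shows how much the text component adds on top of the tabular one.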

I hope this helps!

LinXin04 commented 1 year ago

Got it! Thanks very much!