jrzaurin / pytorch-widedeep

A flexible package for multimodal deep learning: combining tabular data with text and images using Wide and Deep models in PyTorch

Some additional self-supervised features #216

Closed s9021025292140 closed 4 months ago

s9021025292140 commented 4 months ago

Thank you for creating such a wonderful open-source project. I have a few questions:

  1. In my past projects I didn't have labeled data, so I used TabNet (https://github.com/dreamquark-ai/tabnet) for self-supervised learning on my tabular data. I made some modifications to the open-source code so that the SSL-trained TabNet outputs only the encoded latent space at prediction time (i.e. the embeddings obtained after the original data passes through the encoder, without going through the decoder). This helps us capture outliers or retrieve similar records in real-world, unlabeled tabular datasets. Can the TabNet in this project likewise use the encoder to predict directly on unlabeled data after SSL training? Additionally, can the SSL-trained TabNet provide an "explain" function? (Currently it seems only the supervised TabNet supports the explain function and can obtain embeddings from predictions on the original data.)

  2. Your SSL tabular models also include the SAINT architecture. Can SAINT, like TabNet above, be trained on unlabeled data with SSL and then used directly to predict on the original data and obtain embeddings? Currently it seems that only the supervised SAINT can output embeddings.

  3. In the examples you provide, different modalities (tabular, text, and image) can be trained in a multimodal fashion, but this seems to require labeled data (as indicated by the target). My application requires training on tabular and image data without labels, with the aim of finding a shared latent space between the two through SSL. Is it possible to provide a version for unlabeled multimodal training?

model = WideDeep(
    deeptabular=tab_mlp,
    deeptext=models_fuser,
    deepimage=vision,
    deephead=deephead,
)
trainer = Trainer(model, objective="binary")

trainer.fit(
    X_tab=X_tab,
    X_text=[X_text_1, X_text_2],
    X_img=X_img,
    target=df["target"].values,
    n_epochs=1,
    batch_size=32,
)

I apologize for requesting so much. Since I mostly deal with unlabeled data in practical applications, meeting the above requirements would greatly benefit many people. Once again, thank you for providing such a great project! :)

jrzaurin commented 4 months ago

Hey @s9021025292140 sorry for the late reply. I will try to keep up with the conversation this week if you reply :)

Let's go:

  1. "[...] I would like to ask if the TabNet in this project can also use the encoder to directly predict data without labels after SSL training? [...]": I am not sure what you mean. What is it that you want to predict? For example, let's assume the following pipeline: Self Supervised pre-training -> new data comes -> passes through the encoder but not the decoder -> you take the representations -> and then predict...what?
  2. "[...] Additionally, can the SSL-trained TabNet provide an "explanin" function (currently, it seems only the supervised TabNet supports the explain function and can obtain embeddings from the original data predictions)?": SSL does not involve a target per se, so there is no explain method in the classes in this module. If you extract embeddings and you predict given a target, then yeah, explain is available. Also, an scenario in which you pretrain (SS pretraining), you extract the learned embeddings and then use them for a supervised problem is possible.
  3. Finally, no: at the moment the library does not support multimodal SSL; the two methods available only support tabular data (your comment reminds me that I need to include some self-supervised examples in the README, thanks!)

So, in summary: a pipeline where you use SSL, then extract the learned embeddings and do whatever you would like with them, is possible :) You just have to train using either of the two training classes, ContrastiveDenoisingTrainer or EncoderDecoderTrainer, then extract the weights and "proceed as usual" (see the sketch right below). I might include a Pipeline object to do this automatically in the next release.
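
Off the top of my head (untested sketch), the ContrastiveDenoisingTrainer route, e.g. with SAINT, would look something like this. Note that attention-based models need the TabPreprocessor instantiated with with_attention=True:

import torch

from pytorch_widedeep.models import SAINT
from pytorch_widedeep.datasets import load_adult
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.self_supervised_training import ContrastiveDenoisingTrainer

df = load_adult(as_frame=True)
df.columns = [c.replace("-", "_") for c in df.columns]

cat_embed_cols = ["workclass", "education", "occupation", "relationship", "race"]
continuous_cols = ["age", "hours_per_week"]

# attention-based models (SAINT, TabTransformer, etc.) need with_attention=True
tab_preprocessor = TabPreprocessor(
    cat_embed_cols=cat_embed_cols,
    continuous_cols=continuous_cols,
    with_attention=True,
)
X_tab = tab_preprocessor.fit_transform(df)

saint = SAINT(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
)

contrastive_denoising_trainer = ContrastiveDenoisingTrainer(
    model=saint,
    preprocessor=tab_preprocessor,
)
contrastive_denoising_trainer.pretrain(X_tab, n_epochs=5, batch_size=256)

# saint was pre-trained in place: pass (preprocessed) data through it in
# eval mode to get the learned representations
with torch.no_grad():
    embeddings = saint.eval()(torch.tensor(X_tab).float())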

And regarding images: widedeep at the moment does not support SSL for images and tabular data. One thing you could do is encode/embed the images separately (using a resnet architecture or something light) and then treat the encoded images as categorical cols, along the lines of the sketch below.
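
For example, something along these lines (untested sketch; imgs is assumed to be a batch of image tensors and df your dataframe, and the k-means discretisation is just one way of turning the embeddings into a categorical col):

import torch
import torchvision.models as tvm

from sklearn.cluster import KMeans

# small pretrained backbone with the classification head removed
backbone = tvm.resnet18(weights=tvm.ResNet18_Weights.DEFAULT)
backbone.fc = torch.nn.Identity()
backbone.eval()

with torch.no_grad():
    img_embeddings = backbone(imgs)  # imgs: (N, 3, 224, 224) float tensor, assumed

# cluster the embeddings and use the cluster id as a categorical column
img_codes = KMeans(n_clusters=50).fit_predict(img_embeddings.numpy())
df["img_code"] = img_codes  # usable as a categorical col in the TabPreprocessor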

Let me know if this helps and I will try to add a bit more information during the day.

And thanks for using the library! (or considering it :) )

jrzaurin commented 4 months ago

Take a look here: https://github.com/jrzaurin/pytorch-widedeep/blob/8057360b712f7064b6669828d21d95361172b93e/examples/scripts/adult_census_enc_dec_full_example.py#L72

or here: https://github.com/jrzaurin/pytorch-widedeep/blob/8057360b712f7064b6669828d21d95361172b93e/examples/scripts/adult_census_cont_den_full_example.py#L89

for how to extract the encoder; from there you can do anything you want with it.

s9021025292140 commented 4 months ago

Oh! Thank you very much for your reply!

  1. I apologize for not making myself clear. What I want is: self-supervised pre-training -> new data comes in -> passes through the encoder but not the decoder -> get the representations (that is what I want; the whole process is just self-supervised pre-training). How can I modify your code to achieve this?

  2. The reason I asked is that in pytorch-tabnet I can directly use the explain() function on new data after SSL pre-training. I haven't read the author's code carefully, so I'm not sure why this is possible...😅

  3. Got it! 👌

Finally, thank you very much for answering each of my questions with care😊.
I hope to complete my work through your library💪!

jrzaurin commented 4 months ago

hi @s9021025292140

here you have a fully functioning example for what you want:

import torch
import pandas as pd

from pytorch_widedeep.models import TabMlp
from pytorch_widedeep.datasets import load_adult
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.self_supervised_training import EncoderDecoderTrainer

use_cuda = torch.cuda.is_available()

if __name__ == "__main__":

    # load the data and some preprocessing you probably don't need
    df: pd.DataFrame = load_adult(as_frame=True)
    df.columns = [c.replace("-", "_") for c in df.columns]
    df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
    df.drop("income", axis=1, inplace=True)

    # define the categorical and continuous columns, as well as the target
    cat_embed_cols = [
        "workclass",
        "education",
        "marital_status",
        "occupation",
        "relationship",
        "race",
        "gender",
        "capital_gain",
        "capital_loss",
        "native_country",
    ]
    continuous_cols = ["age", "hours_per_week"]
    target_col = "income_label"

    # instantiate the TabPreprocessor that will be used throughout the experiment
    tab_preprocessor = TabPreprocessor(
        cat_embed_cols=cat_embed_cols, continuous_cols=continuous_cols, scale=True
    )
    X_tab = tab_preprocessor.fit_transform(df)
    target = df[target_col].values

    # We define a model that will act as the encoder in the encoder/decoder
    # architecture. This could be any of: TabMlp, TabResnet or TabNet
    tab_mlp = TabMlp(
        column_idx=tab_preprocessor.column_idx,
        cat_embed_input=tab_preprocessor.cat_embed_input,
        continuous_cols=tab_preprocessor.continuous_cols,
    )

    # If we do not pass a custom decoder, which is perfectly possible via the
    # decoder param (see the docs or the example notebooks), the
    # EncoderDecoderTrainer will automatically build a decoder that is the
    # 'mirror image' of the encoder
    encoder_decoder_trainer = EncoderDecoderTrainer(encoder=tab_mlp)
    encoder_decoder_trainer.pretrain(X_tab, n_epochs=5, batch_size=256)

    # New data comes
    new_data = df.sample(32)

    # Preprocess the new data in the exact same way as the data used during
    # the pre-training (note: transform, not fit_transform, so that the
    # encodings learned during fit are reused)
    new_X_tab_arr = tab_preprocessor.transform(new_data)

    # Normally, the transformation to tensor happens inside the Trainer.
    # However, for what you want you just have to do it here
    new_X_tab_tnsr = torch.tensor(new_X_tab_arr).float()

    # And pass the tensor to the encoder (in eval mode) to get the embeddings.
    # Here 'ed_model' stands for 'encoder_decoder_model'
    encoder = encoder_decoder_trainer.ed_model.encoder.eval()

    # # If you choose to save the pretrained model then
    # encoder_decoder_trainer.save(
    #     path="pretrained_weights", model_filename="encoder_decoder_model.pt"
    # )

    # # some time has passed, we load the model with torch as usual:
    # encoder_decoder_model = torch.load("pretrained_weights/encoder_decoder_model.pt")
    # encoder = encoder_decoder_model.encoder

    # wrapping the forward pass in torch.no_grad() since we only want the
    # representations, not gradients
    with torch.no_grad():
        embeddings_1 = encoder(new_X_tab_tnsr)

        # or simply use tab_mlp, since as you remember, it was our encoder:
        # 'encoder=tab_mlp'
        embeddings_2 = tab_mlp.eval()(new_X_tab_tnsr)
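
And once you have the embeddings, "proceeding as usual" could be, for instance, the similarity retrieval you mentioned, e.g. with scikit-learn:

from sklearn.neighbors import NearestNeighbors

# fit a nearest-neighbour index on the extracted embeddings and retrieve
# the 5 rows most similar to the first one
emb = embeddings_1.detach().numpy()
nn_index = NearestNeighbors(n_neighbors=5, metric="cosine").fit(emb)
distances, indices = nn_index.kneighbors(emb[:1])
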
jrzaurin commented 4 months ago

And regarding your point 2 on explain(): I will look at the source code of their implementation. It has been ages since I last did, so I have to remind myself. But explain is normally related to explaining a prediction. I will have a look :)