A flexible package for multimodal deep learning that combines tabular data with text and images using Wide and Deep models in PyTorch
Documentation: https://pytorch-widedeep.readthedocs.io
Companion posts and tutorials: infinitoml
Experiments and comparison with LightGBM: TabularDL vs LightGBM
Slack: if you want to contribute or just want to chat with us, join slack
The content of this document is organized as follows:
pytorch-widedeep is based on Google's Wide and Deep Algorithm, adjusted for multi-modal datasets.
In general terms, pytorch-widedeep is a package to use deep learning with tabular data. In particular, it is intended to facilitate the combination of text and images with corresponding tabular data using wide and deep models.
With that in mind there are a number of architectures that can be implemented
with the library. The main components of those architectures are shown in the
Figure below:
In math terms, and following the notation in the paper, the expression for the architecture without a deephead component can be formulated as:
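Written in LaTeX, and assuming that each deep component contributes its own weight matrix and final activations (the subscripts below are labels chosen to mirror the component names used in this document, and $(l_f)$ is assumed to mark the last layer), the expression reads roughly as:

```latex
pred = \sigma\left(
    W_{wide}^{T}\,[x, \phi(x)]
    + W_{deeptabular}^{T}\, a_{deeptabular}^{(l_f)}
    + W_{deeptext}^{T}\, a_{deeptext}^{(l_f)}
    + W_{deepimage}^{T}\, a_{deepimage}^{(l_f)}
    + b
\right)
```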
Where σ is the sigmoid function, 'W' are the weight matrices applied to the wide model and to the final activations of the deep models, 'a' are these final activations, φ(x) are the cross-product transformations of the original features 'x', and 'b' is the bias term. In case you are wondering what "cross-product transformations" are, here is a quote taken directly from the paper: "For binary features, a cross-product transformation (e.g., “AND(gender=female, language=en)”) is 1 if and only if the constituent features (“gender=female” and “language=en”) are all 1, and 0 otherwise".
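The quoted definition can be illustrated with a few lines of plain Python. This is only a sketch of the idea, not the way the library computes crossed columns; the helper name `cross_product` and the toy rows are made up for the illustration.

```python
def cross_product(row: dict, constituents: dict) -> int:
    """Return 1 if every constituent feature holds for the row, otherwise 0."""
    return int(all(row.get(col) == value for col, value in constituents.items()))


# Toy rows (illustrative values only)
rows = [
    {"gender": "female", "language": "en"},
    {"gender": "female", "language": "es"},
    {"gender": "male", "language": "en"},
]

# AND(gender=female, language=en)
and_female_en = {"gender": "female", "language": "en"}

crossed = [cross_product(row, and_female_en) for row in rows]
print(crossed)  # [1, 0, 0]
```

Only the first row satisfies both constituent features, so it is the only one mapped to 1.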
It is perfectly possible to use custom models (and not necessarily those in the library) as long as the custom models have a property called output_dim with the size of the last layer of activations, so that WideDeep can be constructed. Examples of how to use custom components can be found in the Examples folder and the section below.
The pytorch-widedeep library offers a number of different architectures. In this section we will show some of them in their simplest form (i.e. with default param values in most cases) with their corresponding code snippets. Note that all the snippets below should run locally. For a more detailed explanation of the different components and their parameters, please refer to the documentation.
For the examples below we will be using a toy dataset generated as follows:
import os
import random

import numpy as np
import pandas as pd
from PIL import Image
from faker import Faker


def create_and_save_random_image(image_number, size=(32, 32)):
    if not os.path.exists("images"):
        os.makedirs("images")

    array = np.random.randint(0, 256, (size[0], size[1], 3), dtype=np.uint8)
    image = Image.fromarray(array)
    image_name = f"image_{image_number}.png"
    image.save(os.path.join("images", image_name))
    return image_name


fake = Faker()

cities = ["New York", "Los Angeles", "Chicago", "Houston"]
names = ["Alice", "Bob", "Charlie", "David", "Eva"]

data = {
    "city": [random.choice(cities) for _ in range(100)],
    "name": [random.choice(names) for _ in range(100)],
    "age": [random.uniform(18, 70) for _ in range(100)],
    "height": [random.uniform(150, 200) for _ in range(100)],
    "sentence": [fake.sentence() for _ in range(100)],
    "other_sentence": [fake.sentence() for _ in range(100)],
    "image_name": [create_and_save_random_image(i) for i in range(100)],
    "target": [random.choice([0, 1]) for _ in range(100)],
}

df = pd.DataFrame(data)
This will create a 100-row dataframe and a directory in your local folder, called images, with 100 random images (or images with just noise).
Perhaps the simplest architecture would be just one component, wide, deeptabular, deeptext or deepimage on their own, which is also possible, but let's start the examples with a standard Wide and Deep architecture. From there, how to build a model comprised only of one component will be straightforward.
Note that the examples shown below would be almost identical using any of the models available in the library. For example, TabMlp can be replaced by TabResnet, TabNet, TabTransformer, etc. Similarly, BasicRNN can be replaced by AttentiveRNN, StackedAttentiveRNN, or HFModel with their corresponding parameters and preprocessor in the case of the Hugging Face models.
1. Wide and Tabular component (aka deeptabular)
from pytorch_widedeep.preprocessing import TabPreprocessor, WidePreprocessor
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.training import Trainer

# Wide
wide_cols = ["city"]
crossed_cols = [("city", "name")]
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
X_wide = wide_preprocessor.fit_transform(df)
wide = Wide(input_dim=np.unique(X_wide).shape[0])

# Tabular
tab_preprocessor = TabPreprocessor(
    embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)

# WideDeep
model = WideDeep(wide=wide, deeptabular=tab_mlp)

# Train
trainer = Trainer(model, objective="binary")

trainer.fit(
    X_wide=X_wide,
    X_tab=X_tab,
    target=df["target"].values,
    n_epochs=1,
    batch_size=32,
)
2. Tabular and Text data
from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep
from pytorch_widedeep.training import Trainer

# Tabular
tab_preprocessor = TabPreprocessor(
    embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)

# Text
text_preprocessor = TextPreprocessor(
    text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text = text_preprocessor.fit_transform(df)
rnn = BasicRNN(
    vocab_size=len(text_preprocessor.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)

# WideDeep
model = WideDeep(deeptabular=tab_mlp, deeptext=rnn)

# Train
trainer = Trainer(model, objective="binary")

trainer.fit(
    X_tab=X_tab,
    X_text=X_text,
    target=df["target"].values,
    n_epochs=1,
    batch_size=32,
)
3. Tabular and text with a FC head on top via the head_hidden_dims param in WideDeep
from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep
from pytorch_widedeep.training import Trainer

# Tabular
tab_preprocessor = TabPreprocessor(
    embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)

# Text
text_preprocessor = TextPreprocessor(
    text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text = text_preprocessor.fit_transform(df)
rnn = BasicRNN(
    vocab_size=len(text_preprocessor.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)

# WideDeep
model = WideDeep(deeptabular=tab_mlp, deeptext=rnn, head_hidden_dims=[32, 16])

# Train
trainer = Trainer(model, objective="binary")

trainer.fit(
    X_tab=X_tab,
    X_text=X_text,
    target=df["target"].values,
    n_epochs=1,
    batch_size=32,
)
4. Tabular and multiple text columns that are passed directly to WideDeep
from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep
from pytorch_widedeep.training import Trainer

# Tabular
tab_preprocessor = TabPreprocessor(
    embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)

# Text
text_preprocessor_1 = TextPreprocessor(
    text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_1 = text_preprocessor_1.fit_transform(df)
text_preprocessor_2 = TextPreprocessor(
    text_col="other_sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_2 = text_preprocessor_2.fit_transform(df)
rnn_1 = BasicRNN(
    vocab_size=len(text_preprocessor_1.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)
rnn_2 = BasicRNN(
    vocab_size=len(text_preprocessor_2.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)

# WideDeep
model = WideDeep(deeptabular=tab_mlp, deeptext=[rnn_1, rnn_2])

# Train
trainer = Trainer(model, objective="binary")

trainer.fit(
    X_tab=X_tab,
    X_text=[X_text_1, X_text_2],
    target=df["target"].values,
    n_epochs=1,
    batch_size=32,
)
5. Tabular data and multiple text columns that are fused via the library's ModelFuser class
from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep, ModelFuser
from pytorch_widedeep import Trainer

# Tabular
tab_preprocessor = TabPreprocessor(
    embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)

# Text
text_preprocessor_1 = TextPreprocessor(
    text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_1 = text_preprocessor_1.fit_transform(df)
text_preprocessor_2 = TextPreprocessor(
    text_col="other_sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_2 = text_preprocessor_2.fit_transform(df)

rnn_1 = BasicRNN(
    vocab_size=len(text_preprocessor_1.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)
rnn_2 = BasicRNN(
    vocab_size=len(text_preprocessor_2.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)

models_fuser = ModelFuser(models=[rnn_1, rnn_2], fusion_method="mult")

# WideDeep
model = WideDeep(deeptabular=tab_mlp, deeptext=models_fuser)

# Train
trainer = Trainer(model, objective="binary")

trainer.fit(
    X_tab=X_tab,
    X_text=[X_text_1, X_text_2],
    target=df["target"].values,
    n_epochs=1,
    batch_size=32,
)
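As a rough intuition for the "mult" fusion used above, the sketch below multiplies two output vectors element by element. It assumes that "mult" stands for an element-wise product of equally sized outputs, which is why both RNNs in the example share the same hidden_dim; the helper `fuse_mult` and the numbers are hypothetical and plain lists stand in for tensors.

```python
def fuse_mult(outputs):
    """Element-wise product of equally sized output vectors."""
    width = len(outputs[0])
    if any(len(out) != width for out in outputs):
        raise ValueError("all outputs must have the same width")
    fused = [1.0] * width
    for out in outputs:
        fused = [f * o for f, o in zip(fused, out)]
    return fused


# Illustrative outputs of two text models for a single row
rnn_1_out = [0.5, 2.0, -1.0]
rnn_2_out = [4.0, 0.5, 3.0]

fused = fuse_mult([rnn_1_out, rnn_2_out])
print(fused)  # [2.0, 1.0, -3.0]
```

The fused vector keeps the width of its inputs, so under this assumption the downstream component sees a single text representation of that size.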
6. Tabular and multiple text columns, with an image column. The text columns are fused via the library's ModelFuser and then all fused via the deephead parameter in WideDeep, which is a custom ModelFuser coded by the user
This is perhaps the least elegant solution as it involves a custom component by the user and slicing the 'incoming' tensor. In the future, we will include a TextAndImageModelFuser to make this process more straightforward. Still, it is not really complicated and it is a good example of how to use custom components in pytorch-widedeep.
Note that the only requirement for the custom component is that it has a property called output_dim that returns the size of the last layer of activations. In other words, it does not need to inherit from BaseWDModelComponent. This base class simply checks the existence of such a property and avoids some typing errors internally.
import torch

from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor, ImagePreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep, ModelFuser, Vision
from pytorch_widedeep.models._base_wd_model_component import BaseWDModelComponent
from pytorch_widedeep import Trainer

# Tabular
tab_preprocessor = TabPreprocessor(
    embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[16, 8],
)

# Text
text_preprocessor_1 = TextPreprocessor(
    text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_1 = text_preprocessor_1.fit_transform(df)
text_preprocessor_2 = TextPreprocessor(
    text_col="other_sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_2 = text_preprocessor_2.fit_transform(df)
rnn_1 = BasicRNN(
    vocab_size=len(text_preprocessor_1.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)
rnn_2 = BasicRNN(
    vocab_size=len(text_preprocessor_2.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)
models_fuser = ModelFuser(
    models=[rnn_1, rnn_2],
    fusion_method="mult",
)

# Image
image_preprocessor = ImagePreprocessor(img_col="image_name", img_path="images")
X_img = image_preprocessor.fit_transform(df)
vision = Vision(pretrained_model_setup="resnet18", head_hidden_dims=[16, 8])


# deephead (custom model fuser)
class MyModelFuser(BaseWDModelComponent):
    """
    Simply a Linear + Relu sequence on top of the text + images followed by a
    Linear -> Relu -> Linear for the concatenation of tabular slice of the
    tensor and the output of the text and image sequential model
    """

    def __init__(
        self,
        tab_incoming_dim: int,
        text_incoming_dim: int,
        image_incoming_dim: int,
        output_units: int,
    ):
        super(MyModelFuser, self).__init__()

        self.tab_incoming_dim = tab_incoming_dim
        self.text_incoming_dim = text_incoming_dim
        self.image_incoming_dim = image_incoming_dim
        self.output_units = output_units

        self.text_and_image_fuser = torch.nn.Sequential(
            torch.nn.Linear(text_incoming_dim + image_incoming_dim, output_units),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(output_units + tab_incoming_dim, output_units * 4),
            torch.nn.ReLU(),
            torch.nn.Linear(output_units * 4, output_units),
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        tab_slice = slice(0, self.tab_incoming_dim)
        text_slice = slice(
            self.tab_incoming_dim, self.tab_incoming_dim + self.text_incoming_dim
        )
        image_slice = slice(
            self.tab_incoming_dim + self.text_incoming_dim,
            self.tab_incoming_dim + self.text_incoming_dim + self.image_incoming_dim,
        )
        X_tab = X[:, tab_slice]
        X_text = X[:, text_slice]
        X_img = X[:, image_slice]
        X_text_and_image = self.text_and_image_fuser(torch.cat([X_text, X_img], dim=1))
        return self.out(torch.cat([X_tab, X_text_and_image], dim=1))

    @property
    def output_dim(self):
        return self.output_units


deephead = MyModelFuser(
    tab_incoming_dim=tab_mlp.output_dim,
    text_incoming_dim=models_fuser.output_dim,
    image_incoming_dim=vision.output_dim,
    output_units=8,
)

# WideDeep
model = WideDeep(
    deeptabular=tab_mlp,
    deeptext=models_fuser,
    deepimage=vision,
    deephead=deephead,
)

# Train
trainer = Trainer(model, objective="binary")

trainer.fit(
    X_tab=X_tab,
    X_text=[X_text_1, X_text_2],
    X_img=X_img,
    target=df["target"].values,
    n_epochs=1,
    batch_size=32,
)
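The slicing in the forward method above relies on the incoming tensor being laid out as tabular columns first, then text, then image, as the custom component assumes. The offset arithmetic can be checked in isolation with a framework-free sketch; the helper `component_slices` and the widths used here are hypothetical, and a plain list stands in for one row of the tensor.

```python
def component_slices(dims):
    """Map consecutive component widths to column slices of a concatenated row."""
    slices, start = [], 0
    for dim in dims:
        slices.append(slice(start, start + dim))
        start += dim
    return slices


# Illustrative widths: tabular=4, text=2, image=3
tab_slice, text_slice, image_slice = component_slices([4, 2, 3])

row = list(range(9))  # stands in for one row of the concatenated tensor
print(row[tab_slice])    # [0, 1, 2, 3]
print(row[text_slice])   # [4, 5]
print(row[image_slice])  # [6, 7, 8]
```

Each slice starts where the previous one ends, so the three pieces cover the row without gaps or overlap.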
7. A two-tower model
This is a popular model in the context of recommendation systems. Let's say we have a tabular dataset formed by triples (user features, item features, target). We can create a two-tower model where the user and item features are passed through two separate models and then "fused" via a dot product.
import numpy as np
import pandas as pd

from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.models import TabMlp, WideDeep, ModelFuser

# Let's create the interaction dataset
# user_features dataframe
np.random.seed(42)
user_ids = np.arange(1, 101)
ages = np.random.randint(18, 60, size=100)
genders = np.random.choice(["male", "female"], size=100)
locations = np.random.choice(["city_a", "city_b", "city_c", "city_d"], size=100)
user_features = pd.DataFrame(
    {"id": user_ids, "age": ages, "gender": genders, "location": locations}
)

# item_features dataframe
item_ids = np.arange(1, 101)
prices = np.random.uniform(10, 500, size=100).round(2)
colors = np.random.choice(["red", "blue", "green", "black"], size=100)
categories = np.random.choice(["electronics", "clothing", "home", "toys"], size=100)
item_features = pd.DataFrame(
    {"id": item_ids, "price": prices, "color": colors, "category": categories}
)

# Interactions dataframe
interaction_user_ids = np.random.choice(user_ids, size=1000)
interaction_item_ids = np.random.choice(item_ids, size=1000)
purchased = np.random.choice([0, 1], size=1000, p=[0.7, 0.3])
interactions = pd.DataFrame(
    {
        "user_id": interaction_user_ids,
        "item_id": interaction_item_ids,
        "purchased": purchased,
    }
)
user_item_purchased = interactions.merge(
    user_features, left_on="user_id", right_on="id"
).merge(item_features, left_on="item_id", right_on="id")

# Users
tab_preprocessor_user = TabPreprocessor(
    cat_embed_cols=["gender", "location"],
    continuous_cols=["age"],
)
X_user = tab_preprocessor_user.fit_transform(user_item_purchased)
tab_mlp_user = TabMlp(
    column_idx=tab_preprocessor_user.column_idx,
    cat_embed_input=tab_preprocessor_user.cat_embed_input,
    continuous_cols=["age"],
    mlp_hidden_dims=[16, 8],
    mlp_dropout=[0.2, 0.2],
)

# Items
tab_preprocessor_item = TabPreprocessor(
    cat_embed_cols=["color", "category"],
    continuous_cols=["price"],
)
X_item = tab_preprocessor_item.fit_transform(user_item_purchased)
tab_mlp_item = TabMlp(
    column_idx=tab_preprocessor_item.column_idx,
    cat_embed_input=tab_preprocessor_item.cat_embed_input,
    continuous_cols=["price"],
    mlp_hidden_dims=[16, 8],
    mlp_dropout=[0.2, 0.2],
)

two_tower_model = ModelFuser([tab_mlp_user, tab_mlp_item], fusion_method="dot")

model = WideDeep(deeptabular=two_tower_model)

trainer = Trainer(model, objective="binary")

trainer.fit(
    X_tab=[X_user, X_item],
    # take the target from the merged frame so its rows line up with X_user and X_item
    target=user_item_purchased["purchased"].values,
    n_epochs=1,
    batch_size=32,
)
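To make the "fused via a dot product" step more concrete, here is a small framework-free sketch. It assumes that the dot fusion reduces the two tower outputs of a row to a single score, which a sigmoid then maps to a purchase probability; the function names and example vectors are illustrative and not part of the library.

```python
import math


def dot(u, v):
    """Dot product of two equally sized vectors."""
    if len(u) != len(v):
        raise ValueError("both towers must produce vectors of the same size")
    return sum(a * b for a, b in zip(u, v))


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Illustrative tower outputs for one (user, item) pair
user_vec = [0.5, -1.0, 2.0]
item_vec = [1.0, 0.5, 0.25]

score = dot(user_vec, item_vec)  # 0.5 - 0.5 + 0.5
prob = sigmoid(score)
print(score, round(prob, 4))
```

The size check mirrors a constraint visible in the example above: both towers end in the same dimension (mlp_hidden_dims=[16, 8]), which is what makes a dot product between them possible.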
8. Tabular with a multi-target loss
This one is "a bonus" to illustrate the use of multi-target losses, more than actually a different architecture.
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.models import TabMlp, WideDeep
from pytorch_widedeep.losses_multitarget import MultiTargetClassificationLoss
from pytorch_widedeep import Trainer

# let's add a second target to the dataframe
df["target2"] = [random.choice([0, 1]) for _ in range(100)]

# Tabular
tab_preprocessor = TabPreprocessor(
    embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)

# 'pred_dim=2' because we have two binary targets. For other types of targets,
# please, see the documentation
model = WideDeep(deeptabular=tab_mlp, pred_dim=2)

loss = MultiTargetClassificationLoss(binary_config=[0, 1], reduction="mean")

# When a multi-target loss is used, 'custom_loss_function' must not be None.
# See the docs
trainer = Trainer(model, objective="multitarget", custom_loss_function=loss)

trainer.fit(
    X_tab=X_tab,
    target=df[["target", "target2"]].values,
    n_epochs=1,
    batch_size=32,
)
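As a rough illustration of what a loss over two binary targets can compute, the sketch below averages the binary cross-entropy of each target column. This shows the general technique under that assumption and is not a description of the internals of MultiTargetClassificationLoss; the function names and toy values are hypothetical.

```python
import math


def bce(p, y, eps=1e-12):
    """Binary cross-entropy for one predicted probability p and one label y."""
    p = min(max(p, eps), 1 - eps)  # clip to avoid log(0)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def multi_target_bce(probs, targets):
    """Mean over targets of the mean per-row BCE.

    probs and targets are lists of rows with one column per binary target.
    """
    n_rows, n_targets = len(probs), len(probs[0])
    per_target = [
        sum(bce(probs[i][j], targets[i][j]) for i in range(n_rows)) / n_rows
        for j in range(n_targets)
    ]
    return sum(per_target) / n_targets


# Two rows, two binary targets (illustrative values)
probs = [[0.9, 0.2], [0.3, 0.7]]
targets = [[1, 0], [0, 1]]

loss_value = multi_target_bce(probs, targets)
print(round(loss_value, 4))
```

The two columns play the role of "target" and "target2" in the example above, which is why the model there is built with pred_dim=2.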
The deeptabular component
It is important to emphasize again that each individual component, wide, deeptabular, deeptext and deepimage, can be used independently and in isolation. For example, one could use only wide, which is simply a linear model. In fact, one of the most interesting functionalities in pytorch-widedeep would be the use of the deeptabular component on its own, i.e. what one might normally refer to as Deep Learning for Tabular Data. Currently, pytorch-widedeep offers the following different models for that component:
Two simpler attention based models that we call:
The Tabformer
family, i.e. Transformers for Tabular data:
And probabilistic DL models for tabular data based on Weight Uncertainty in Neural Networks:
Probabilistic adaptations of the Wide model and of the TabMlp model.
Note that while there are scientific publications for the TabTransformer, SAINT and FT-Transformer, the TabFastFormer and TabPerceiver are our own adaptations of those algorithms for tabular data.
In addition, Self-Supervised pre-training can be used for all deeptabular models, with the exception of the TabPerceiver. Self-Supervised pre-training can be used via two methods or routines which we refer to as: encoder-decoder method and contrastive-denoising method. Please, see the documentation and the examples for details on this functionality, and all other options in the library.
The rec module
This module was introduced as an extension to the existing components in the library, addressing questions and issues related to recommendation systems. While still under active development, it currently includes a select number of powerful recommendation models.
It's worth noting that this library already supported the implementation of various recommendation algorithms using existing components. For example, models like Wide and Deep, Two-Tower, or Neural Collaborative Filtering could be constructed using the library's core functionalities.
The recommendation algorithms in the rec module are:
See the examples for details on how to use these models.
For the text component, deeptext, the library offers the following models:
For the image component, deepimage, the library supports models from the following families: 'resnet', 'shufflenet', 'resnext', 'wide_resnet', 'regnet', 'densenet', 'mobilenetv3', 'mobilenetv2', 'mnasnet', 'efficientnet' and 'squeezenet'. These are offered via torchvision and wrapped up in the Vision class.
Install using pip:
pip install pytorch-widedeep
Or install directly from GitHub:
pip install git+https://github.com/jrzaurin/pytorch-widedeep.git
# Clone the repository
git clone https://github.com/jrzaurin/pytorch-widedeep
cd pytorch-widedeep
# Install in dev mode
pip install -e .
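After any of the install routes above, one way to confirm which version ended up in the environment is to query the package metadata from Python. This is a small sketch using only the standard library; the helper name `installed_version` is made up for the illustration, and it assumes the distribution name matches the one passed to pip.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package: str):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Prints the version string, or None if the package is not installed
print(installed_version("pytorch-widedeep"))
```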
Here is an end-to-end example of a binary classification with the adult dataset using Wide and TabMlp and default settings.
Building a wide (linear) and deep model with pytorch-widedeep:
import numpy as np
import torch
from sklearn.model_selection import train_test_split

from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.datasets import load_adult

df = load_adult(as_frame=True)
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
df_train, df_test = train_test_split(df, test_size=0.2, stratify=df.income_label)

# Define the 'column set up'
wide_cols = [
    "education",
    "relationship",
    "workclass",
    "occupation",
    "native-country",
    "gender",
]
crossed_cols = [("education", "occupation"), ("native-country", "occupation")]

cat_embed_cols = [
    "workclass",
    "education",
    "marital-status",
    "occupation",
    "relationship",
    "race",
    "gender",
    "capital-gain",
    "capital-loss",
    "native-country",
]
continuous_cols = ["age", "hours-per-week"]
target = "income_label"
target = df_train[target].values

# prepare the data
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
X_wide = wide_preprocessor.fit_transform(df_train)

tab_preprocessor = TabPreprocessor(
    cat_embed_cols=cat_embed_cols, continuous_cols=continuous_cols
)
X_tab = tab_preprocessor.fit_transform(df_train)

# build the model
wide = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=continuous_cols,
)
model = WideDeep(wide=wide, deeptabular=tab_mlp)

# train and validate
trainer = Trainer(model, objective="binary", metrics=[Accuracy])
trainer.fit(
    X_wide=X_wide,
    X_tab=X_tab,
    target=target,
    n_epochs=5,
    batch_size=256,
)

# predict on test
X_wide_te = wide_preprocessor.transform(df_test)
X_tab_te = tab_preprocessor.transform(df_test)
preds = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)

# Save and load

# Option 1: this will also save training history and lr history if the
# LRHistory callback is used
trainer.save(path="model_weights", save_state_dict=True)

# Option 2: save as any other torch model
torch.save(model.state_dict(), "model_weights/wd_model.pt")

# From here onwards, Options 1 and 2 are equivalent. We assume the user has
# prepared the data and defined the new model components:
# 1. Build the model
model_new = WideDeep(wide=wide, deeptabular=tab_mlp)
model_new.load_state_dict(torch.load("model_weights/wd_model.pt"))

# 2. Instantiate the trainer
trainer_new = Trainer(model_new, objective="binary")

# 3. Either start the fit or directly predict
preds = trainer_new.predict(X_wide=X_wide, X_tab=X_tab, batch_size=32)
Of course, one can do much more. See the Examples folder, the documentation or the companion posts for a better understanding of the content of the package and its functionalities.
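To judge the test predictions from the example above, one simple option is a plain accuracy computation. The sketch assumes that the predictions are hard 0/1 class labels aligned row by row with the true labels; the `accuracy` helper and the toy values are illustrative and not part of the library.

```python
def accuracy(y_true, y_pred):
    """Fraction of positions where the predicted label equals the true label."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if len(y_true) == 0:
        raise ValueError("cannot compute accuracy on empty inputs")
    hits = sum(int(t == p) for t, p in zip(y_true, y_pred))
    return hits / len(y_true)


# Toy labels and predictions (illustrative values)
y_true = [1, 0, 0, 1, 1]
y_pred = [1, 0, 1, 1, 0]

acc = accuracy(y_true, y_pred)
print(acc)  # 0.6
```

With the objects from the example, the call would look like accuracy(df_test["income_label"].values, preds), using the predictions obtained on X_wide_te and X_tab_te rather than the later ones computed on the training arrays.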
pytest tests
Check CONTRIBUTING page.
This library takes from a series of other libraries, so I think it is only fair to mention them here in the README (specific mentions are also included in the code).
The Callbacks and Initializers structure and code is inspired by the torchsample library, which is itself partially inspired by Keras.
The TextPreprocessor class in this library uses fastai's Tokenizer and Vocab. The code at utils.fastai_transforms is a minor adaptation of their code so it functions within this library. In my experience their Tokenizer is the best in class.
The ImagePreprocessor class in this library uses code from the fantastic Deep Learning for Computer Vision (DL4CV) book by Adrian Rosebrock.
This work is dual-licensed under Apache 2.0 and MIT (or any later version). You can choose between one of them if you use this work.
SPDX-License-Identifier: Apache-2.0 AND MIT
@article{Zaurin_pytorch-widedeep_A_flexible_2023,
author = {Zaurin, Javier Rodriguez and Mulinka, Pavol},
doi = {10.21105/joss.05027},
journal = {Journal of Open Source Software},
month = jun,
number = {86},
pages = {5027},
title = {{pytorch-widedeep: A flexible package for multimodal deep learning}},
url = {https://joss.theoj.org/papers/10.21105/joss.05027},
volume = {8},
year = {2023}
}
Zaurin, J. R., & Mulinka, P. (2023). pytorch-widedeep: A flexible package for
multimodal deep learning. Journal of Open Source Software, 8(86), 5027.
https://doi.org/10.21105/joss.05027