clabrugere / multitask-learning

TensorFlow implementation of three architectures for multi-task learning, a paradigm in which one model learns several prediction tasks jointly
MIT License

Multi-task learning

This repository contains the implementation of three architectures for multi-task learning: shared bottom (a), mixture of experts (b), and multi-gate mixture of experts (c). Multi-task learning is a paradigm where one model learns several tasks jointly by sharing some of its parameters across tasks. It saves resources (compute time, memory), reduces the engineering complexity and the number of failure points in prediction pipelines, and can even improve prediction performance on correlated tasks where information sharing is beneficial. It can, however, suffer from negative transfer between tasks that are too different or have contradictory objectives.
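As a reference point for the three variants, here is a minimal sketch of the simplest one, shared bottom (a): a single trunk shared across tasks followed by one head per task. The layer sizes and sigmoid heads are illustrative assumptions, not the repository's defaults.

import tensorflow as tf

def shared_bottom(num_tasks=2, dim_input=64, dim_hidden=32):
    # single trunk whose parameters are shared by every task
    inputs = tf.keras.Input(shape=(dim_input,))
    shared = tf.keras.layers.Dense(dim_hidden, activation="relu")(inputs)
    # one small head per task, here binary classification heads
    outputs = [
        tf.keras.layers.Dense(1, activation="sigmoid", name=f"task_{i}")(shared)
        for i in range(num_tasks)
    ]
    return tf.keras.Model(inputs=inputs, outputs=outputs)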

One industry application of this paradigm is modeling the conversion funnel in advertising. The inputs are usually the same for the CTR and CVR tasks (user and item characteristics), but the feedback and the sample space differ. Special properties of the funnel can be encoded directly into the architecture and the loss to improve the overall performance of the system, as in the sketch below.
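One well-known instance of this idea (not implemented in this repository) is the ESMM-style coupling, where the conversion probability over the full sample space is modeled as pCTCVR = pCTR * pCVR, so the CVR head is trained through the click signal. A minimal sketch, assuming one sigmoid output per task:

import tensorflow as tf

def funnel_loss(y_click, y_conversion, p_ctr, p_cvr):
    # illustrative sketch: tie the two heads through pCTCVR = pCTR * pCVR
    bce = tf.keras.losses.BinaryCrossentropy()
    p_ctcvr = p_ctr * p_cvr  # probability of click AND conversion
    return bce(y_click, p_ctr) + bce(y_conversion, p_ctcvr)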

Models

Architecture

In this repository's implementation, gates are simple linear projections with a softmax activation. A temperature scaling in the softmax could be added to control the relative influence of the experts. In addition, all task encoders share the same architecture for simplicity's sake, but this can easily be adapted to more complex applications. Finally, a learned linear projection is applied to the continuous input features before they are concatenated with the learned embeddings of the discrete modalities, so that both live in the same latent space.
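For concreteness, a minimal sketch of such a gate, including the optional temperature mentioned above (the temperature argument is an assumption, not something the repository implements):

import tensorflow as tf

class Gate(tf.keras.layers.Layer):
    def __init__(self, num_experts, temperature=1.0):
        super().__init__()
        # linear projection from the input to one logit per expert
        self.projection = tf.keras.layers.Dense(num_experts)
        self.temperature = temperature

    def call(self, inputs):
        logits = self.projection(inputs)
        # lower temperature sharpens the expert weighting, higher flattens it
        return tf.nn.softmax(logits / self.temperature, axis=-1)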

An example of a simple multi-task loss, for modeling multiple binary classification tasks, is implemented in models/loss.py.
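A minimal sketch of what such a loss can look like, assuming y_true and y_pred both have shape (batch_size, num_tasks); see models/loss.py for the actual implementation:

import tensorflow as tf

class MultiTaskBCE(tf.keras.losses.Loss):
    def __init__(self, num_tasks, name="multi_task_bce"):
        super().__init__(name=name)
        self.num_tasks = num_tasks

    def call(self, y_true, y_pred):
        # binary cross-entropy averaged over the task axis; the Loss base
        # class then reduces the per-sample values over the batch
        return tf.keras.losses.binary_crossentropy(y_true, y_pred)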

Dependencies

This repository has the following dependencies:

tensorflow

Getting Started

git clone https://github.com/clabrugere/multitask-learning.git

Usage

import tensorflow as tf

# adjust the import paths to the repository layout
from models.loss import MultiTaskBCE

# load your dataset
train_sparse_data = ...
train_dense_data = ...
train_labels = ...

# num_tasks and num_embeddings come from your own setup
model = MultiGateMixtureOfExperts(
    num_tasks=num_tasks,
    num_emb=num_embeddings,
    ...
)

# train the model with one binary cross-entropy term per task
loss = MultiTaskBCE(num_tasks=num_tasks)
optimizer = tf.keras.optimizers.Adam()

model.compile(optimizer=optimizer, loss=loss)
model.fit(
    x=[train_sparse_data, train_dense_data],
    y=train_labels,
    epochs=20,
)

# make predictions on test inputs in the same [sparse, dense] format
y_pred = model.predict(X_test)

References