google / flaxformer

Apache License 2.0

How to run a simple inference on Switch base #5

Open younesbelkada opened 1 year ago

younesbelkada commented 1 year ago

Hi there!

First of all, awesome work on Switch Transformers 🔥 I was wondering whether there is a simple example script (or a set of commands) to run inference with the switch_base model. Thanks!

younesbelkada commented 1 year ago

I finally managed to get a working script. For those who are interested, you need to:

1- Set up the working environment:

git clone --branch=main https://github.com/google-research/t5x
cd t5x
python3 -m pip install -e '.[tpu]' -f \
  https://storage.googleapis.com/jax-releases/libtpu_releases.html

git clone https://github.com/google/flaxformer.git
cd flaxformer
pip3 install '.[testing]'
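Before moving on, it can help to confirm that both packages ended up in the same Python environment you will later run `python -m t5x.infer` from. The sketch below uses only the standard library; `check_packages` is a hypothetical helper name, not part of t5x or flaxformer.

```python
import importlib.util
import sys


def check_packages(names=("t5x", "flaxformer")):
    """Return {package_name: bool} telling whether each package is importable.

    importlib.util.find_spec only locates a top-level package without
    importing it, so this does not trigger any JAX/TPU initialization.
    """
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, ok in check_packages().items():
        # sys.executable shows which interpreter was checked.
        print(f"{name}: {'found' if ok else 'MISSING'} ({sys.executable})")
```

If either package shows as missing, the `pip install` step most likely targeted a different interpreter than the one printed.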

2- Get the model checkpoints:

export PATH_CHECKPOINTS=...
gcloud storage cp -r gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e8/checkpoint_500100 $PATH_CHECKPOINTS
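Depending on whether `$PATH_CHECKPOINTS` already existed when you ran the copy, the files may land directly in `$PATH_CHECKPOINTS` or in a `checkpoint_500100` subdirectory of it. The sketch below looks in both places and returns the directory to use as `CHECKPOINT_PATH`. It assumes the legacy T5X layout, in which each `checkpoint_<step>` directory contains a metadata file named `checkpoint`; `resolve_checkpoint_dir` is a hypothetical helper.

```python
import os
from pathlib import Path
from typing import Optional


def resolve_checkpoint_dir(root: str, step_dir: str = "checkpoint_500100") -> Optional[Path]:
    """Return the directory to pass as CHECKPOINT_PATH, or None if not found.

    Assumption: a T5X checkpoint directory holds a file named "checkpoint".
    Both `root` itself and `root/step_dir` are checked, in that order.
    """
    root_path = Path(root).expanduser()
    for candidate in (root_path, root_path / step_dir):
        if (candidate / "checkpoint").is_file():
            return candidate
    return None


if __name__ == "__main__":
    found = resolve_checkpoint_dir(os.environ.get("PATH_CHECKPOINTS", "."))
    print(found if found else "No T5X checkpoint found; re-check the gcloud copy step.")
```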

3- Create and save a gin file for Switch Transformer Base, similar to the ones in flaxformer (example below):

# Switch Transformer Base model.
#
# Based on the original Switch Transformer (https://arxiv.org/abs/2101.03961).
#
# Note that unlike the original Switch Transformer, this T5X version does not
# use any jitter noise in the router.
#
# Provides MODEL and NUM_EXPERTS.

from __gin__ import dynamic_registration

from flaxformer.architectures.moe import moe_architecture
from flaxformer.architectures.moe import moe_layers
from flaxformer.architectures.moe import routing
from flaxformer.components import dense
import seqio
from t5x import adafactor

ARCHITECTURE = %gin.REQUIRED

include 'flaxformer/t5x/configs/moe/models/tokens_choose_base.gin'

# Architecture overrides
MLP_DIM = 3072

# MoE overrides
NUM_EXPERTS = 128
# Every other MLP sublayer is replaced by an MoE sublayer.
NUM_ENCODER_SPARSE_LAYERS = 6
NUM_DECODER_SPARSE_LAYERS = 6
TRAIN_EXPERT_CAPACITY_FACTOR = 1.25
EVAL_EXPERT_CAPACITY_FACTOR = 2.
NUM_SELECTED_EXPERTS = 1   # Switch routing
AUX_LOSS_FACTOR = 0.01
ROUTER_Z_LOSS_FACTOR = 0.0
GROUP_SIZE = 8192

# Switch Transformer Base uses relu activations.
dense.MlpBlock.activations = ('relu',)
expert/dense.MlpBlock.activations = ('relu',)

# Switch Transformer Base re-uses the token embedder to compute output logits.
moe_architecture.SparseDecoder.output_logits_factory = None

# Switch Transformer doesn't use BPR in encoder (although most sparse encoders
# generally see a boost from it).
sparse_encoder/routing.TokensChooseMaskedRouter.batch_prioritized_routing = False

Name it, for example, switch_base.gin and save it wherever you like (we'll refer to its path as PATH_GIN_BASE below).
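To make TRAIN_EXPERT_CAPACITY_FACTOR, EVAL_EXPERT_CAPACITY_FACTOR and GROUP_SIZE more concrete: the Switch Transformer paper defines expert capacity as (tokens per batch / number of experts) × capacity factor. The sketch below applies that formula per routing group. It is an illustration under assumptions, not flaxformer's code: rounding up is my choice, and flaxformer may round differently or use groups smaller than GROUP_SIZE when a batch has fewer tokens. `expert_capacity` is a hypothetical helper.

```python
import math


def expert_capacity(tokens_per_group: int, num_experts: int, capacity_factor: float) -> int:
    """Max tokens one expert accepts per routing group (Switch paper formula, top-1 routing).

    capacity = (tokens_per_group / num_experts) * capacity_factor, rounded up
    (the rounding is an assumption).
    """
    return math.ceil(tokens_per_group * capacity_factor / num_experts)


if __name__ == "__main__":
    GROUP_SIZE = 8192        # from switch_base.gin
    BATCH_TOKENS = 32 * 64   # batch_size * input length from the infer gin (one group: an assumption)
    for tokens in (GROUP_SIZE, BATCH_TOKENS):
        for num_experts in (8, 128):
            print(f"{tokens} tokens, {num_experts} experts: "
                  f"train {expert_capacity(tokens, num_experts, 1.25)}, "
                  f"eval {expert_capacity(tokens, num_experts, 2.0)}")
```

With 8 experts and a full 8192-token group this gives 1280 tokens per expert at training time and 2048 at evaluation; with 128 experts the same group gives only 80 and 128.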

4- Create an inference gin file for Switch Transformers:

from __gin__ import dynamic_registration

import __main__ as infer_script
from t5.data import mixtures
from t5x import partitioning
from t5x import utils

include "t5x/configs/runs/infer.gin"
# Replace the path below with the actual path of your switch_base.gin (PATH_GIN_BASE)
include "t5x/examples/moe/switch-c/switch_base.gin"

DROPOUT_RATE = 0.0  # unused but needs to be specified
MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003"
TASK_FEATURE_LENGTHS = {"inputs": 64, "targets": 64}

partitioning.PjitPartitioner.num_partitions = 1

utils.DatasetConfig:
  split = "test"
  batch_size = 32

Save it as, for example, switch_base_infer.gin, wherever you like (we'll refer to its path as PATH_INFER_GIN below).
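The second `include` line in the inference gin file needs the real path of switch_base.gin, not the literal text `$PATH_GIN_BASE`. One way to avoid editing it by hand is to render the file from a template, as in the sketch below. `render_infer_gin` is a hypothetical helper; the template body is the gin file from step 4 with only that include path made configurable.

```python
import os
from string import Template

# Step 4's gin file; $base is substituted with the path to switch_base.gin.
INFER_GIN_TEMPLATE = Template('''from __gin__ import dynamic_registration

import __main__ as infer_script
from t5.data import mixtures
from t5x import partitioning
from t5x import utils

include "t5x/configs/runs/infer.gin"
include "$base"

DROPOUT_RATE = 0.0  # unused but needs to be specified
MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003"
TASK_FEATURE_LENGTHS = {"inputs": 64, "targets": 64}

partitioning.PjitPartitioner.num_partitions = 1

utils.DatasetConfig:
  split = "test"
  batch_size = 32
''')


def render_infer_gin(path_gin_base: str) -> str:
    """Return the inference gin text with a literal include path for the base config."""
    if '"' in path_gin_base:
        raise ValueError("path must not contain double quotes")
    return INFER_GIN_TEMPLATE.substitute(base=path_gin_base)


if __name__ == "__main__":
    print(render_infer_gin(os.environ.get("PATH_GIN_BASE", "switch_base.gin")))
```

To save it, something like `Path("switch_base_infer.gin").write_text(render_infer_gin(os.environ["PATH_GIN_BASE"]))` (with `from pathlib import Path`) would do; `string.Template` is used so the `{...}` braces in TASK_FEATURE_LENGTHS need no escaping.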

5- Run the inference script!

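Finally, run the command:

python -m t5x.infer \
  --gin_file=$PATH_INFER_GIN \
  --logtostderr \
  --gin.MODEL_DIR=\"~/disk\" \
  --gin.CHECKPOINT_PATH=\"$PATH_CHECKPOINTS\" \
  --gin.INFER_OUTPUT_DIR=\"./\" \
  --gin.NUM_MODEL_PARTITIONS=1 \
  --gin.NUM_EXPERTS=8

Here --gin.NUM_EXPERTS=8 overrides the NUM_EXPERTS = 128 default from switch_base.gin so that it matches the e8 (8-expert) checkpoint downloaded in step 2. CHECKPOINT_PATH must point at the checkpoint_500100 directory itself, so use $PATH_CHECKPOINTS/checkpoint_500100 if the copy created that subdirectory. Make sure the paths in the include statements are correct, otherwise you'll get errors.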
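The `\"...\"` escaping in the command exists because gin string values must reach t5x wrapped in double quotes. If you prefer launching from Python, a sketch like the one below builds the same argument list, quoting gin strings and expanding `~` itself (the shell does not expand a tilde that follows an escaped quote). `gin_str` and `build_infer_argv` are hypothetical helpers; the flags are exactly those from the command above.

```python
import os
import shlex
import sys


def gin_str(value: str) -> str:
    """Quote a path as a gin string literal, expanding a leading ~ first."""
    escaped = os.path.expanduser(value).replace("\\", "\\\\").replace('"', '\\"')
    return f'"{escaped}"'


def build_infer_argv(infer_gin, checkpoint_path, model_dir="~/disk",
                     output_dir="./", num_experts=8, num_partitions=1):
    """Argument list for `python -m t5x.infer`, mirroring the shell command."""
    return [
        sys.executable, "-m", "t5x.infer",
        f"--gin_file={infer_gin}",
        "--logtostderr",
        f"--gin.MODEL_DIR={gin_str(model_dir)}",
        f"--gin.CHECKPOINT_PATH={gin_str(checkpoint_path)}",
        f"--gin.INFER_OUTPUT_DIR={gin_str(output_dir)}",
        f"--gin.NUM_MODEL_PARTITIONS={num_partitions}",
        f"--gin.NUM_EXPERTS={num_experts}",
    ]


if __name__ == "__main__":
    argv = build_infer_argv(
        os.environ.get("PATH_INFER_GIN", "switch_base_infer.gin"),
        os.environ.get("PATH_CHECKPOINTS", "checkpoint_500100"),
    )
    # Print a copy-pasteable command; subprocess.run(argv, check=True) would launch it.
    print(shlex.join(argv))
```

Passing the list straight to `subprocess.run` skips shell parsing entirely, so no backslash escaping is needed around the quotes.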