google-research / scenic

Scenic: A Jax Library for Computer Vision Research and Beyond
Apache License 2.0

[MBT] Inference only #552

Open BDHU opened 1 year ago

BDHU commented 1 year ago

Is there a way to bypass the data-preprocessing step for MBT ("Attention Bottlenecks for Multimodal Fusion") if I only want to run inference, without passing in actual AudioSet data? I notice the main file requires a config file from AudioSet. Is there any way to just perform inference with all-zero inputs? Inference on the actual dataset doesn't matter to me, since setting that up is a pain. Thanks!

WangzcBruce commented 1 year ago

Excuse me, have you worked it out?

BDHU commented 1 year ago

> Excuse me, have you worked it out?

Unfortunately no:(

YingjieYin commented 1 year ago

Bad inference experience

BigJoon commented 1 year ago

Keep failing... just like.. "Fake" papers and code.

dipodaimary commented 1 year ago

Same here:

I made some progress, but I was not able to complete it. Hopefully this helps you find a path; please share if you manage to complete it, and reach out to me if you need me to explain how I got this far.

from typing import Any, Callable
from absl import flags
from clu import metric_writers
import jax
import jax.numpy as jnp
import ml_collections
from scenic import app
from scenic.projects.mbt import model
from scenic.projects.mbt import trainer
from scenic.train_lib import train_utils

import sys
# Local, user-specific helper module (not part of Scenic).
sys.path.append("/generate_from_file/")
from generate_from_file import *
from pprint import pprint
# Note: this rebinds `train_utils` imported from scenic.train_lib above.
from scenic.train_lib_deprecated import train_utils
from scenic.projects.mbt.configs.audioset import balanced_audioset_base

AUDIOSET_TRAIN_SIZE = 20361

config = ml_collections.ConfigDict()
config.dataset_name = 'video_tfrecord_dataset'
config.experiment_name = 'mbt_balanced_audioset_classification'
config.dataset_configs = ml_collections.ConfigDict()
config.dataset_configs.base_dir = '/path/to/dataset'
config.dataset_configs.tables = {
  'train': 'balanced_train.se.melspec.tfrecord.sst@1024',
  'validation': 'eval.se.melspec.tfrecord.sst@1024',
  'test': 'temp',
}
config.dataset_configs.examples_per_subset = {
  'train': 20361,
  'validation': 18589,
  'test': 18589
}
config.dataset_configs.num_classes = 527
config.data_dtype_str = 'float32'
config.dataset_configs.modalities = ('spectrogram', 'rgb')
config.dataset_configs.return_as_dict = True
config.dataset_configs.num_frames = 8
config.dataset_configs.stride = 2
config.dataset_configs.num_spec_frames = 1
config.dataset_configs.spec_stride = 1

config.dataset_configs.spec_mean = 1.102
config.dataset_configs.spec_stddev = 2.762
config.dataset_configs.min_resize = 256
config.dataset_configs.crop_size = 224
config.dataset_configs.spec_shape = (100, 128)
config.dataset_configs.one_hot_labels = True
config.dataset_configs.zero_centering = True

config.dataset_configs.do_multicrop_test = True
config.dataset_configs.log_test_epochs = 4
# The effective batch size per host when testing is
# num_test_clips * test_batch_size
config.dataset_configs.num_test_clips = 4
config.dataset_configs.test_batch_size = 8  # Needs to be num_local_devices
config.multicrop_clips_per_device = 2

config.dataset_configs.augmentation_params = ml_collections.ConfigDict()
config.dataset_configs.augmentation_params.do_jitter_scale = True
config.dataset_configs.augmentation_params.scale_min_factor = 0.9
config.dataset_configs.augmentation_params.scale_max_factor = 1.33
config.dataset_configs.augmentation_params.prob_scale_jitter = 1.0
config.dataset_configs.augmentation_params.do_color_augment = True
config.dataset_configs.augmentation_params.prob_color_augment = 0.8
config.dataset_configs.augmentation_params.prob_color_drop = 0.1

config.dataset_configs.prefetch_to_device = 2

# SpecAugment hyperparameters
config.dataset_configs.spec_augment = True
config.dataset_configs.spec_augment_params = ml_collections.ConfigDict()
config.dataset_configs.spec_augment_params.freq_mask_max_bins = 48
config.dataset_configs.spec_augment_params.freq_mask_count = 1
config.dataset_configs.spec_augment_params.time_mask_max_frames = 48
config.dataset_configs.spec_augment_params.time_mask_count = 4
config.dataset_configs.spec_augment_params.time_warp_max_frames = 1.0
config.dataset_configs.spec_augment_params.time_warp_max_ratio = 0
config.dataset_configs.spec_augment_params.time_mask_max_ratio = 0

# Model: MBT-base
config.model_name = 'mbt_classification'
config.model = ml_collections.ConfigDict()
# Supports 'rgb' and 'spectrogram'
config.model.modality_fusion = ('spectrogram', 'rgb')
config.model.use_bottleneck = False
config.model.test_with_bottlenecks = True
config.model.share_encoder = False
config.model.n_bottlenecks = 4
# Layer at which to fuse. '0' means early fusion; if fusion_layer equals
# model.num_layers, there is no cross-modal attention in the transformer and
# the CLS tokens for each modality are simply averaged at the end.
config.model.fusion_layer = 8
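# With fusion_layer = 8 and num_layers = 12, the first 8 encoder layers
# process each modality separately and cross-modal fusion only happens in the
# last 4 layers (mid fusion, as in the MBT paper).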
config.model.hidden_size = 768
config.model.patches = ml_collections.ConfigDict()
config.model.attention_config = ml_collections.ConfigDict()
config.model.attention_config.type = 'spacetime'
config.model.num_heads = 12
config.model.mlp_dim = 3072
config.model.num_layers = 12
config.model.representation_size = None
config.model.classifier = 'gap'
# config.model.classifier = 'token'
config.model.attention_dropout_rate = 0.
config.model.dropout_rate = 0.
config.model_dtype_str = 'float32'

config.model.temporal_encoding_config = ml_collections.ConfigDict()
# 3d_conv is only used for RGB inputs.
config.model.temporal_encoding_config.method = '3d_conv'
# 8 RGB frames (num_frames above) with a temporal patch size of 2, giving 4 temporal positions at the input.
config.model.patches.size = [16, 16, 2]
config.model.temporal_encoding_config.kernel_init_method = 'central_frame_initializer'
config.model.temporal_encoding_config.n_sampled_frames = 4  # Unused here.

# Training.
config.trainer_name = 'mbt_trainer'
config.optimizer = 'momentum'
config.optimizer_configs = ml_collections.ConfigDict()
config.l2_decay_factor = 0
config.max_grad_norm = 1
config.label_smoothing = 0.3
config.num_training_epochs = 50
config.batch_size = 64
config.rng_seed = 0
# This does Mixup in the train loop. This is fast. But make sure that device
# batch size is more than 1. On a 4x4 TPU, this means that your batch size
# needs to be at least 64.
config.mixup = ml_collections.ConfigDict()
config.mixup.alpha = 0.5
config.mixmod = False
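# Sanity check for the Mixup note above (assuming a 4x4 slice exposes 32 local
# devices): a global batch size of 64 gives 64 // 32 = 2 examples per device,
# which satisfies the per-device batch > 1 requirement.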
# Additional regularization
config.model.stochastic_droplayer_rate = 0.3

# Use ImageNet-21k-initialised model from big_vision checkpoint
config.init_from = ml_collections.ConfigDict()
config.init_from.model_config = None
# Download pretrained ImageNet checkpoints from here:
# https://github.com/google-research/scenic/tree/main/scenic/projects/baselines (checkpoint_format = 'scenic')  pylint: disable=line-too-long
# https://github.com/google-research/vision_transformer (checkpoint_format = 'big_vision')  pylint: disable=line-too-long
config.init_from.checkpoint_path = '/path_to_checkpoint_of_vit_b_16/mbtb32_as-500k_rgb-spec'
config.init_from.checkpoint_format = 'scenic'
config.init_from.model_config = ml_collections.ConfigDict()
config.init_from.model_config.model = ml_collections.ConfigDict()
config.init_from.model_config.model.classifier = 'gap'  # Specify if this is 'token' or 'gap'.  pylint: disable=line-too-long
config.init_from.restore_positional_embedding = True
config.init_from.restore_input_embedding = True
config.init_from.positional_embed_size_change = 'resize_tile'

# Learning rate.
steps_per_epoch = AUDIOSET_TRAIN_SIZE // config.batch_size
total_steps = config.num_training_epochs * steps_per_epoch
config.lr_configs = ml_collections.ConfigDict()
config.lr_configs.learning_rate_schedule = 'compound'
config.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
config.lr_configs.warmup_steps = 2.5 * steps_per_epoch
config.lr_configs.steps_per_cycle = total_steps
config.lr_configs.base_learning_rate = 5e-1

# Logging.
config.write_summary = True
config.checkpoint = True  # Do checkpointing.
config.debug_train = False  # Debug mode during training.
config.debug_eval = False  # Debug mode during eval.
config.checkpoint_steps = 500  # Checkpoint more frequently than a val epoch.
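# Alternative sketch (assumes the usual Scenic convention that config modules
# expose a get_config() function): instead of rebuilding the config by hand,
# start from the imported reference config and override only the local paths:
#   config = balanced_audioset_base.get_config()
#   config.dataset_configs.base_dir = '/path/to/dataset'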

def get_path(d, path=[]):
  # base case: d is not a dictionary
  if not isinstance(d, dict):
    return [path]
  # recursive case: d is a dictionary
  else:
    paths = []
    for k, v in d.items():
      # append k to path and recurse on v
      paths.extend(get_path(v, path + [k]))
    return paths
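# Example: get_path({'a': {'b': 1}, 'c': 2}) returns [['a', 'b'], ['c']].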

from flax import jax_utils
from scenic.train_lib import pretrain_utils

trainstate = pretrain_utils.restore_pretrained_checkpoint(checkpoint_path="/path_to_checkpoint_of_vit_b_16/mbtb32_as-500k_rgb-spec")
train_state = jax_utils.replicate(trainstate)
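# Note: jax_utils.replicate adds a leading device axis to every array (for
# pmap-style training). The un-pmapped apply() call further below uses the
# unreplicated `params` from initialize_model; if model_state were non-empty,
# the unreplicated `trainstate.model_state` would be the safer choice there.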

from scenic.projects.mbt import model as mbt_model
model_cls = mbt_model.MBTClassificationModel

dataset_meta = {'input_dtype': jnp.float32,  # use jnp.float32, not the private jax._src path
 'input_shape': {'rgb': (1, 32, 224, 224, 3), 'spectrogram': (1, 800, 128, 3)},
 'num_classes': 527,
 'num_eval_examples': 18589,
 'num_test_examples': 74356,
 'num_train_examples': 20361,
 'target_is_onehot': True}

model = model_cls(config, dataset_meta)
is_multilabel_model = (config.model_name == 'mbt_multilabel_classification')
input_shapes = dataset_meta['input_shape']
input_dtype = jnp.float32
input_spec = {modality: (input_shapes[modality], input_dtype)
              for modality in input_shapes
             }
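# For the dataset_meta above, input_spec works out to:
# {'rgb': ((1, 32, 224, 224, 3), jnp.float32),
#  'spectrogram': ((1, 800, 128, 3), jnp.float32)}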

import time
start = time.time()

rng = jax.random.PRNGKey(42)
# Only init_rng is needed here; there is no data pipeline to seed.
rng, init_rng = jax.random.split(rng)

from scenic.projects.mbt import train_utils as mbt_train_utils
(params, model_state, num_trainable_params,
   gflops) = mbt_train_utils.initialize_model(
       model_def=model.flax_model,
       input_spec=input_spec,
       config=config,
       rngs=init_rng)
end = time.time()
print(f"Took {end-start} seconds to initialize the model.")

import flax
from collections import Counter
from functools import reduce
from operator import getitem
from pprint import pprint

# params = flax.core.unfreeze(train_state.params)
paths = get_path(flax.core.unfreeze(params))
# Walk each key path with reduce/getitem instead of building eval() strings.
pprint(Counter(type(reduce(getitem, path, params)) for path in paths))
pprint(Counter(reduce(getitem, path, params).shape for path in paths))

variables = {
    'params': params,
    **train_state.model_state,
}

key = jax.random.PRNGKey(758493)  # Random seed is explicit in JAX (unused below).
model_input = {'rgb': jnp.ones((1, 32, 224, 224, 3), jnp.float32),
               'spectrogram': jnp.ones((1, 800, 128, 3), jnp.float32)}

logits = model.flax_model.apply(
    variables,
    model_input,
    train=False, mutable=False, debug=False)
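# Hedged post-processing sketch (not from the original code): AudioSet is
# multi-label, so per-class scores are typically obtained with a sigmoid
# rather than a softmax. This assumes `logits` is a single
# (batch, num_classes) array.
probs = jax.nn.sigmoid(logits)
top5 = jnp.argsort(probs[0])[-5:][::-1]
print('top-5 class indices:', top5, 'scores:', probs[0][top5])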

# Even the model footprint (parameter shapes) does not match the provided pretrained weights (?)

paths = get_path(flax.core.unfreeze(train_state.params))
pprint(Counter(type(reduce(getitem, path, train_state.params)) for path in paths))
pprint(Counter(reduce(getitem, path, train_state.params).shape for path in paths))

paths = get_path(flax.core.unfreeze(params))
pprint(Counter(type(reduce(getitem, path, params)) for path in paths))
pprint(Counter(reduce(getitem, path, params).shape for path in paths))
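# Hedged sketch for locating the mismatch between the freshly initialised
# parameters and the restored checkpoint. The unreplicated `trainstate` is
# used so shapes are not offset by the device axis that replicate() adds.
init_shapes = {tuple(p): reduce(getitem, p, params).shape
               for p in get_path(flax.core.unfreeze(params))}
ckpt_shapes = {tuple(p): reduce(getitem, p, trainstate.params).shape
               for p in get_path(flax.core.unfreeze(trainstate.params))}
mismatches = {k: (init_shapes[k], ckpt_shapes[k])
              for k in init_shapes.keys() & ckpt_shapes.keys()
              if init_shapes[k] != ckpt_shapes[k]}
pprint({'only_in_init': init_shapes.keys() - ckpt_shapes.keys(),
        'only_in_checkpoint': ckpt_shapes.keys() - init_shapes.keys(),
        'shape_mismatches': mismatches})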
dipodaimary commented 1 year ago

[UPDATE] I have made some progress, but in a hacky way. Specifically, I had to update model.py (the add_positional_embed function) to duplicate a layer so that the model definition matches the provided pretrained checkpoint. However, I have not been able to prepare the AudioSet data exactly as required, so I have not yet reproduced the results reported in the paper. I have attached the notebook and the updated model.py: mbt_github.zip