farewellthree / STAN

Official PyTorch implementation of the paper "Revisiting Temporal Modeling for CLIP-based Image-to-Video Knowledge Transferring"
Apache License 2.0

help: new train config for msrvtt and some lower metric #21

Closed · Lucky-Light-Sun closed this issue 4 months ago

Lucky-Light-Sun commented 4 months ago

Hi, can you share the new training configs for MSRVTT? I ran the reimplemented code without modifying anything except setting batch=32 on each of 4 machines (so total batch size = 128). The best retrieval/R1 of 45.6000 was achieved at epoch 8, which is a little lower than the number reported in the paper.

So I wonder whether the best result reported in the paper is still achievable with the newly implemented code, or whether something is wrong in my MSRVTT config. Here is the config from my local project. Looking forward to your reply.

Best wishes!

_base_ = '../../_base_/default_runtime.py'

import xyg_utils.xyg_config as xyg_config

pretrained_model = "openai/clip-vit-base-patch32"
clip_weight = xyg_config.clip_weight_path  # xyg added
model = dict(
    type='CLIPSimilarity_split',
    visual_encoder=dict(type='VITCLIPPretrained_STAN', pretrained_model=pretrained_model, clip_weight=clip_weight),
    text_encoder=dict(type='CLIPTextPretrained', pretrained_model=pretrained_model, clip_weight=clip_weight),
    to_float32=True,
    frozen_layers=False,
    data_preprocessor=dict(
        type='MultiModalDataPreprocessor',
        preprocessors=dict(
            imgs=dict(
                type='ActionDataPreprocessor',
                mean=[122.771, 116.746, 104.093],
                std=[68.500, 66.632, 70.323],
                format_shape='NCHW'),
            text=dict(type='ActionDataPreprocessor', to_float32=False))),
    tau=0.01,
    adapter=None)

load_from = None  # path to the post-pretrained ckpt

dataset_type = 'MsrvttDataset'
# data_root = '/Path/to/your/msrvtt/dataset'
data_root = xyg_config.msrvtt_data_root
file_client_args = dict(io_backend='disk')
train_pipeline = [
    dict(type='DecordInit', **file_client_args),
    dict(type='UniformSample', clip_len=12, num_clips=1),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='RandomResizedCrop'),
    dict(type='Resize', scale=(224, 224), keep_ratio=False),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='CLIPTokenize', length=32),
    dict(type='PackActionInputs', collect_keys=('imgs', 'text'))
]
val_pipeline = [
    dict(type='DecordInit', **file_client_args),
    dict(type='UniformSample', clip_len=12, num_clips=1, test_mode=True),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 224)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='CLIPTokenize', length=32),
    dict(type='PackActionInputs', collect_keys=('imgs', 'text'))
]
test_pipeline = val_pipeline

train_dataloader = dict(
    batch_size=xyg_config.train_batch_size,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        ann_file='train_9k.json',
        data_root=data_root,
        data_prefix=dict(video='videos'),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=xyg_config.valid_batch_size,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        ann_file='test_JSFUSION.json',
        data_root=data_root,
        data_prefix=dict(video='videos'),
        pipeline=val_pipeline,
        test_mode=True))
test_dataloader = dict(
    batch_size=16,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        ann_file='test_JSFUSION.json',
        data_root=data_root,
        data_prefix=dict(video='videos'),
        pipeline=test_pipeline,
        test_mode=True))

val_evaluator = dict(type='RetrievalMetric')
test_evaluator = val_evaluator

train_cfg = dict(
    type='EpochBasedTrainLoop', max_epochs=10, val_begin=1, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=0.05,
        by_epoch=True,
        begin=0,
        end=10,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=4.5,
        eta_min=0,
        by_epoch=True,
        begin=10,
        end=100,
        convert_to_iter_based=True)
]

optim_wrapper = dict(
    type='AmpOptimWrapper',
    optimizer=dict(
        type='AdamW',
        lr=2e-06,
        betas=(0.9, 0.98),
        eps=1e-08,
        weight_decay=0.05),
    paramwise_cfg=dict(
        norm_decay_mult=0., bias_decay_mult=0.,
        custom_keys={
            'STAN': dict(lr_mult=10.),
        }),
    clip_grad=dict(max_norm=5, norm_type=2)
)

default_hooks = dict(
    checkpoint=dict(
        type='printBest_CheckpointHook', interval=-1, save_best='auto',
        rule='greater'))

# Default setting for scaling LR automatically
#   - `enable` means enable scaling LR automatically
#       or not by default.
#   - `base_batch_size` = (8 GPUs) x (16 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=128)
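
For context, auto_scale_lr follows MMEngine's linear scaling rule: when enabled, the runner multiplies the configured lr by the ratio of the actual total batch size to base_batch_size. A minimal sketch of that rule (a hypothetical standalone helper, not the MMEngine implementation):

def scale_lr(base_lr, total_batch_size, base_batch_size=128):
    """Linear scaling rule: lr grows proportionally with the total batch size."""
    return base_lr * total_batch_size / base_batch_size

# With batch=32 on each of 4 machines, the total batch size is 4 * 32 = 128,
# so the configured lr=2e-06 would be unchanged here even with enable=True.
print(scale_lr(2e-06, 4 * 32))  # 2e-06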
Lucky-Light-Sun commented 4 months ago

Hi, after modifying the mean and std of data_preprocessor, the max_epochs of train_cfg, and the weight_decay of optim_wrapper according to the issue "Can you share the training configs?", the best retrieval/R1 of 46.0000 is achieved at epoch 8. But this is still a little lower than the number reported in the paper.

So could you share the new train config for the latest code?

Best wishes!

model = dict(
    type='CLIPSimilarity_split',
    visual_encoder=dict(type='VITCLIPPretrained_STAN', pretrained_model=pretrained_model, clip_weight=clip_weight),
    text_encoder=dict(type='CLIPTextPretrained', pretrained_model=pretrained_model, clip_weight=clip_weight),
    to_float32=True,
    frozen_layers=False,
    data_preprocessor=dict(
        type='MultiModalDataPreprocessor',
        preprocessors=dict(
            imgs=dict(
                type='ActionDataPreprocessor',
                # mean=[122.771, 116.746, 104.093],
                # std=[68.500, 66.632, 70.323],
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                format_shape='NCHW'),
            text=dict(type='ActionDataPreprocessor', to_float32=False))),
    tau=0.01,
    adapter=None)

train_cfg = dict(
    type='EpochBasedTrainLoop', max_epochs=20, val_begin=1, val_interval=1)

optim_wrapper = dict(
    type='AmpOptimWrapper',
    optimizer=dict(
        type='AdamW',
        lr=2e-06,
        betas=(0.9, 0.98),
        eps=1e-08,
        # weight_decay=0.05),
        weight_decay=0.02),
    paramwise_cfg=dict(
        norm_decay_mult=0., bias_decay_mult=0.,
        custom_keys={
            'STAN': dict(lr_mult=10.),
        }),
    clip_grad=dict(max_norm=5, norm_type=2)
)
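
For reference, the commented-out mean/std are OpenAI CLIP's preprocessing statistics and the replacements are the standard ImageNet statistics, both scaled from the 0-1 range to the 0-255 pixel range that ActionDataPreprocessor expects. A quick sanity check:

# CLIP's normalization constants (defined on 0-1 pixels), scaled to 0-255;
# they match the commented-out values up to rounding.
clip_mean = [m * 255 for m in (0.48145466, 0.4578275, 0.40821073)]
clip_std = [s * 255 for s in (0.26862954, 0.26130258, 0.27577711)]
print(clip_mean)  # ~ [122.771, 116.746, 104.094]
print(clip_std)   # ~ [68.500, 66.632, 70.323]

# ImageNet statistics, likewise scaled; these match the new config values.
imagenet_mean = [m * 255 for m in (0.485, 0.456, 0.406)]
imagenet_std = [s * 255 for s in (0.229, 0.224, 0.225)]
print(imagenet_mean)  # [123.675, 116.28, 103.53]
print(imagenet_std)   # [58.395, 57.12, 57.375]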
farewellthree commented 4 months ago

The fluctuation in retrieval results is relatively large; deviations within 1% of the reported results are considered normal. For instance, with this code, our replication result for STAN is 0.5% lower than reported, while our replication of Mug-STAN is 1% higher than reported.
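
For anyone comparing runs, here is a minimal sketch of how recall@K text-to-video retrieval metrics are commonly computed from a similarity matrix (hypothetical standalone code, not the RetrievalMetric implementation in this repo):

import numpy as np

def recall_at_k(sim, ks=(1, 5, 10)):
    """sim[i, j]: similarity of text query i to video j; ground truth is the diagonal."""
    order = np.argsort(-sim, axis=1)  # videos ranked by descending similarity
    # Rank position of the correct video for each query (0 = retrieved first).
    ranks = np.array([int(np.where(order[i] == i)[0][0]) for i in range(sim.shape[0])])
    return {f'R{k}': 100.0 * float(np.mean(ranks < k)) for k in ks}

# Toy example: a perfect similarity matrix gives R1 = 100.0; small run-to-run
# perturbations of sim can shift R1 by a few tenths, matching the fluctuation above.
print(recall_at_k(np.eye(3)))  # {'R1': 100.0, 'R5': 100.0, 'R10': 100.0}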