NVlabs / MambaVision

Official PyTorch Implementation of MambaVision: A Hybrid Mamba-Transformer Vision Backbone
https://arxiv.org/abs/2407.08083

Clarification on drop_path_rate Conflict Between Python Script and Shell Script in MambaVision Model #29

Closed ghangminyun closed 3 weeks ago

ghangminyun commented 3 weeks ago
# mamba_vision.py
@register_pip_model
@register_model
def mamba_vision_T(pretrained=False, **kwargs):
    model_path = kwargs.pop("model_path", "/tmp/mamba_vision_T.pth.tar")
    pretrained_cfg = resolve_pretrained_cfg('mamba_vision_T').to_dict()
    update_args(pretrained_cfg, kwargs, kwargs_filter=None)
    model = MambaVision(depths=[1, 3, 8, 4],
                        num_heads=[2, 4, 8, 16],
                        window_size=[8, 8, 14, 7],
                        dim=80,
                        in_dim=32,
                        mlp_ratio=4,
                        resolution=224,
                        drop_path_rate=0.2,  # This line conflicts
                        **kwargs)  # with this argument.
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg
    if pretrained:
        if not Path(model_path).is_file():
            url = model.default_cfg['url']
            torch.hub.download_url_to_file(url=url, dst=model_path)
        model._load_state_dict(model_path)
    return model
# train.sh
DATA_PATH="/ImageNet/train"
MODEL=mamba_vision_T
BS=2
EXP=Test
LR=8e-4
WD=0.05
WR_LR=1e-6
DR=0.38  # This line conflicts
MESA=0.25

In the mamba_vision.py file, the MambaVision model is already constructed with drop_path_rate set to 0.2. However, the train.sh script also provides a DR variable set to 0.38. Since DR corresponds to drop_path_rate, which value will be used during training? Should I follow the drop_path_rate in the Python script, or the DR value specified in the shell script?

This causes an error in my environment (python=3.9 torch=2.2.0 cuda=12.1).
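
For context, this kind of conflict can be reproduced in isolation. The sketch below uses a hypothetical stand-in function, `build_model`, rather than the real MambaVision class. It assumes that the training script forwards DR to the model factory as a `drop_path_rate` keyword argument, which the thread implies but does not show. Under that assumption, Python receives the same keyword twice and raises a TypeError.

```python
def build_model(dim, drop_path_rate=0.0, **kwargs):
    """Hypothetical stand-in for the MambaVision constructor."""
    return {"dim": dim, "drop_path_rate": drop_path_rate, **kwargs}


def factory_with_hardcoded_rate(**kwargs):
    # Mirrors the pattern in the snippet above: a literal drop_path_rate
    # is passed alongside **kwargs.
    return build_model(dim=80, drop_path_rate=0.2, **kwargs)


def try_conflicting_call():
    # Simulates the caller also supplying drop_path_rate (e.g. DR=0.38).
    try:
        factory_with_hardcoded_rate(drop_path_rate=0.38)
    except TypeError as exc:
        return str(exc)
    return None


# Without an override the factory works and uses the hard-coded 0.2.
model_without_override = factory_with_hardcoded_rate()

# With an override the same keyword arrives twice and the call fails.
error_message = try_conflicting_call()
print(error_message)
```

So with this pattern neither value "wins": the call fails before the model is built whenever the caller supplies its own `drop_path_rate`.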

ahatamiz commented 3 weeks ago

Hi @ghangminyun, I made some changes. You should now be able to pass hyper-parameters with different values to the model.
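
The exact changes are not shown in this thread. One common way to let a caller override a default while avoiding a duplicate-keyword error is to pop the value out of `kwargs` with a fallback. The sketch below illustrates that pattern with a hypothetical `TinyModel` stand-in and a hypothetical `tiny_model_factory`; it should not be read as the actual fix in the repository.

```python
class TinyModel:
    """Hypothetical stand-in for MambaVision; it only stores its arguments."""

    def __init__(self, dim, drop_path_rate=0.0, **kwargs):
        self.dim = dim
        self.drop_path_rate = drop_path_rate
        self.extra = kwargs


def tiny_model_factory(**kwargs):
    # Take the caller's value if present, otherwise fall back to 0.2.
    # Popping removes the key from kwargs, so it is passed only once.
    drop_path_rate = kwargs.pop("drop_path_rate", 0.2)
    return TinyModel(dim=80, drop_path_rate=drop_path_rate, **kwargs)


default_model = tiny_model_factory()
overridden_model = tiny_model_factory(drop_path_rate=0.38)
print(default_model.drop_path_rate, overridden_model.drop_path_rate)
```

With this pattern, a value supplied by the caller (such as DR=0.38 from a training script) would take precedence, and 0.2 would only apply when nothing is passed.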