NVlabs / MambaVision

Official PyTorch Implementation of MambaVision: A Hybrid Mamba-Transformer Vision Backbone
https://arxiv.org/abs/2407.08083

Questions about Train #37

Open wanan0414 opened 1 month ago

wanan0414 commented 1 month ago

I'm working on a radar signal classification task, and the ImageNet pre-trained weights aren't appropriate for it. I also don't have enough data for pre-training. Is it possible to train directly with supervised learning, using only a training set of 22,000 samples and a test set? Looking forward to your answer!

ahatamiz commented 1 month ago

Hi @wanan0414

Yes! The codebase supports training on any dataset for classification tasks. Here's a command, assuming you are running on 8 GPUs (it is also easy to extend to a multi-node setup). You need to specify DATA_PATH_TRAIN and DATA_PATH_VAL for your train and validation (test) sets, respectively.

That said, I recommend trying different learning rates (LR) and batch sizes (BS) for your setup.

Any of our MambaVision variants, such as mamba_vision_T, can be specified via MODEL.

#!/bin/bash

MODEL=mamba_vision_T
DATA_PATH_TRAIN="/my_dataset/train"
DATA_PATH_VAL="/my_dataset/val"
BS=256
EXP=my_experiment
LR=5e-4
WD=0.05
DR=0.2

torchrun --nproc_per_node=8 train.py --input-size 3 224 224 --crop-pct=0.875 \
--train-split=$DATA_PATH_TRAIN --val-split=$DATA_PATH_VAL --model $MODEL --amp --weight-decay ${WD} --drop-path ${DR} --batch-size $BS --tag $EXP --lr $LR
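For the LR and BS search suggested above, one common starting heuristic (not something this codebase prescribes) is the linear scaling rule: keep the ratio of learning rate to global batch size constant when the GPU count or batch size changes. The sketch below assumes that --batch-size is the per-process batch size, as is typical for timm-style training scripts, and the helper name `scaled_lr` is hypothetical. Treat the resulting values as the center of a sweep, not as tuned settings for a 22,000-sample dataset.

```python
def scaled_lr(base_lr: float, base_global_batch: int,
              per_gpu_batch: int, num_gpus: int) -> float:
    """Linear scaling rule: keep lr / global_batch constant."""
    global_batch = per_gpu_batch * num_gpus
    return base_lr * global_batch / base_global_batch


# Reference point taken from the script above: LR=5e-4, BS=256, 8 processes.
# Assumption: BS is per process, so the reference global batch is 256 * 8.
BASE_LR = 5e-4
BASE_GLOBAL_BATCH = 256 * 8

# Print candidate learning rates for a few smaller setups.
for num_gpus in (1, 2, 8):
    for per_gpu_batch in (64, 128, 256):
        lr = scaled_lr(BASE_LR, BASE_GLOBAL_BATCH, per_gpu_batch, num_gpus)
        print(f"gpus={num_gpus} bs={per_gpu_batch} -> lr={lr:.2e}")
```

Each printed value can be passed to the script through LR, with BS and --nproc_per_node changed to match.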

Hope this helps, but let me know if there's any issue.

ahatamiz commented 1 month ago

The above also assumes that the images of each class are placed in their own folder under both train and val, as shown below:

  ├── train
  │   ├── class1
  │   │   ├── img1.jpeg
  │   │   ├── img2.jpeg
  │   │   └── ...
  │   ├── class2
  │   │   ├── img3.jpeg
  │   │   └── ...
  │   └── ...
  └── val
      ├── class1
      │   ├── img4.jpeg
      │   ├── img5.jpeg
      │   └── ...
      ├── class2
      │   ├── img6.jpeg
      │   └── ...
      └── ...
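
Before launching a run, it can help to confirm that a dataset really follows the layout above. The following is a minimal sketch using only the Python standard library; the function names and the list of image extensions are assumptions made for illustration and are not part of the MambaVision codebase.

```python
from pathlib import Path

# Assumption: these are the image extensions in use; adjust for your data.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def class_counts(split_dir):
    """Map each class-folder name to the number of image files directly inside it."""
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTS
            )
    return counts


def check_layout(train_dir, val_dir):
    """Return a list of problems found; an empty list means the layout looks consistent."""
    train, val = class_counts(train_dir), class_counts(val_dir)
    problems = []
    # Classes that appear in only one of the two splits.
    for name in sorted(set(train) ^ set(val)):
        problems.append(f"class '{name}' is missing from one split")
    # Class folders that exist but hold no recognized images.
    for split, counts in (("train", train), ("val", val)):
        for name, n in counts.items():
            if n == 0:
                problems.append(f"{split}/{name} contains no images")
    return problems


# Example, using the paths from the script above:
# for problem in check_layout("/my_dataset/train", "/my_dataset/val"):
#     print(problem)
```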

wanan0414 commented 1 month ago

Hi @ahatamiz, thank you very much for your reply! My dataset is in the same train/val format you described, and thank you for providing the example script. A new question: in that case, do I skip validate.py and just use train.py?

ahatamiz commented 1 month ago

Hi @wanan0414

In this case, validate.py can be used to test the model's performance on your test set, assuming it's different from your validation set.
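
If only a training set and a test set exist, one way to keep the test set separate from validation is to hold out part of the training data as the val folder. The sketch below is one possible approach under stated assumptions, not a step the repository requires: the function name `carve_val_split`, the 10% default fraction, and the fixed seed are all hypothetical choices. It moves files, so run it on a copy of the data.

```python
import random
import shutil
from pathlib import Path


def carve_val_split(train_dir, val_dir, val_fraction=0.1, seed=0):
    """Move a random per-class fraction of files from train_dir into val_dir.

    Sampling is done class by class so that every class keeps roughly the
    same train/val proportion. Returns the number of files moved.
    """
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    moved = 0
    for class_dir in sorted(Path(train_dir).iterdir()):
        if not class_dir.is_dir():
            continue
        files = sorted(f for f in class_dir.iterdir() if f.is_file())
        rng.shuffle(files)
        n_val = int(len(files) * val_fraction)
        # Mirror the class folder under val_dir, matching the layout above.
        target = Path(val_dir) / class_dir.name
        target.mkdir(parents=True, exist_ok=True)
        for f in files[:n_val]:
            shutil.move(str(f), str(target / f.name))
            moved += 1
    return moved


# Example, using the paths from the script above:
# carve_val_split("/my_dataset/train", "/my_dataset/val", val_fraction=0.1)
```

With a split like this, DATA_PATH_VAL points at the held-out folder during training, and the original test set stays untouched until the final evaluation.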