baaivision / EVA

EVA Series: Visual Representation Fantasies from BAAI
MIT License

[EVA-CLIP] How to load weights? #54

Closed SimJeg closed 1 year ago

SimJeg commented 1 year ago

Hello,

Is there an easy way to load the model with the CLIP weights (I am interested in the visual part only)? For instance:

import eva_clip
model = eva_clip.create_model('EVA-ViT-B-16-X', pretrained='EVA02_CLIP_B_psz16_s8B')

I think you would need to update pretrained.py and factory.py to do so. Here is what I have to do for now:

import torch
import eva_clip
import huggingface_hub

# Download the full CLIP checkpoint and keep only the visual-tower weights
path = huggingface_hub.hf_hub_download('QuanSun/EVA-CLIP', 'EVA02_CLIP_B_psz16_s8B.pt')
state_dict = torch.load(path, map_location='cpu')
state_dict = {k.replace('visual.', ''): v for k, v in state_dict.items() if k.startswith('visual.')}

# Build the model without pretrained weights, then load the visual tower
model = eva_clip.create_model('EVA-ViT-B-16-X')
model.visual.load_state_dict(state_dict)
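
For completeness, here is a rough sketch of running the loaded visual tower on an image. The preprocessing below is an assumption (a CLIP-style 224-px resize, crop, and normalization) and should be replaced by the transforms returned by create_model_and_transforms for this config; "example.jpg" is a placeholder path:

import torch
from PIL import Image
from torchvision import transforms

# Assumed CLIP-style preprocessing; ideally use the transforms returned by
# eva_clip.create_model_and_transforms for this model config instead.
preprocess = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)),
])

image = preprocess(Image.open("example.jpg")).unsqueeze(0)  # placeholder image path
model.eval()
with torch.no_grad():
    image_features = model.visual(image)  # embedding from the visual tower only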

Also note that the EVA weights will soon be integrated into timm (see this issue), which may simplify their use (no need to install apex, xformers, or deepspeed anymore).

Best, Simon

Quan-Sun commented 1 year ago

Hi @SimJeg, thank you for your interest in EVA-CLIP. We have updated the code to make it easier to load weights. Here is an example training script that loads the pretrained visual and text weights:

MODEL=EVA02-CLIP-B-16
PRETRAINED_IMAGE=eva
PRETRAINED_TEXT=openai
PRETRAINED_VISUAL_MODEL=EVA02-B-16
PRETRAINED_TEXT_MODEL=OpenaiCLIP-B-16

# Following OpenCLIP, we preprocess data with webdataset. We concatenate the paths of LAION-2B and COYO-700M with `;`.

MERGE_2B_DATA_PATH="/path/to/laion2b_en_data/img_data/{000000..164090}.tar;/path/to/coyo700m_en_data/img_data/{000000..047435}.tar"
# LAION_2B_DATA_PATH="/path/to/laion2b_en_data/img_data/{000000..164090}.tar"
VAL_DATA_PATH=/path/to/IN-1K/val

cd rei

python -m torch.distributed.launch --nproc_per_node=8 \
        --nnodes=$WORLD_SIZE --node_rank=$RANK \
        --master_addr=$MASTER_ADDR --master_port=12355 --use_env \
        training/main.py \
        --save-frequency 1 \
        --zeroshot-frequency 1 \
        --report-to="wandb, tensorboard" \
        --wandb-project-name="eva-clip" \
        --wandb-notes="eva02_clip_B_16" \
        --train-num-samples 40000000 \
        --dataset-resampled \
        --train-data-list=${MERGE_2B_DATA_PATH} \
        --dataset-type-list="webdataset;webdataset" \
        --imagenet-val=${VAL_DATA_PATH} \
        --warmup 2000 \
        --batch-size=2048 \
        --epochs=200 \
        --lr=5e-4 \
        --visual-lr=2e-4 \
        --text-lr=2e-5 \
        --wd=0.05 \
        --visual-wd=0.05 \
        --text-wd=0.05 \
        --ld=1.0 \
        --visual-ld=0.75 \
        --text-ld=0.75 \
        --grad-clip-norm=5.0 \
        --smoothing=0. \
        --workers=8 \
        --model=${MODEL} \
        --pretrained-image=${PRETRAINED_IMAGE} \
        --pretrained-text=${PRETRAINED_TEXT} \
        --pretrained-visual-model=${PRETRAINED_VISUAL_MODEL} \
        --pretrained-text-model=${PRETRAINED_TEXT_MODEL} \
        --skip-list head.weight head.bias lm_head.weight lm_head.bias mask_token text_projection logit_scale \
        --seed 4096 \
        --gather-with-grad \
        --grad-checkpointing \
        --local-loss \
        --force-custom-clip \
        --force-patch-dropout=0 \
        --optimizer="lamb" \
        --zero-stage=1 \
        --enable-deepspeed

SimJeg commented 1 year ago

Thanks @Quan-Sun, it works perfectly! It would be helpful to update the README to show how to use the updated create_model{_and_transforms} for inference only.
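
For reference, a minimal inference-only sketch along those lines, assuming the updated eva_clip API mirrors OpenCLIP's create_model_and_transforms / get_tokenizer / encode_image / encode_text; the model name follows the config in the script above, while the checkpoint path, image path, and captions are placeholders:

import torch
from PIL import Image
from eva_clip import create_model_and_transforms, get_tokenizer

model_name = "EVA02-CLIP-B-16"                     # same config name as in the script above
pretrained = "/path/to/EVA02_CLIP_B_psz16_s8B.pt"  # placeholder checkpoint path

# Build the full CLIP model plus its preprocessing and tokenizer
model, _, preprocess = create_model_and_transforms(model_name, pretrained, force_custom_clip=True)
tokenizer = get_tokenizer(model_name)
model.eval()

image = preprocess(Image.open("example.jpg")).unsqueeze(0)  # placeholder image
text = tokenizer(["a diagram", "a dog", "a cat"])           # placeholder captions

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)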

Quan-Sun commented 1 year ago

@SimJeg Thank you for your suggestion; we will update the README accordingly.

SimJeg commented 1 year ago

Thanks! Feel free to close the issue once that's done.

JeavanCode commented 1 year ago

Sorry to bother you, but the usage guidance for EVA-CLIP contains the line "from eva_clip import create_model_and_transforms, get_tokenizer". However, I could not find such a module in the folder. I understand it might be obvious, but I just can't figure it out. I found the vision tower alone in timm, but I would like to use the language tower as well.
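
In case it helps, a sketch of making that import resolvable, assuming the eva_clip package sits under EVA-CLIP/rei/ in this repository (the training script above also cd's into rei/); the clone path below is a placeholder and the layout may have changed:

import sys

# Point Python at the directory that contains the eva_clip package
# (assumed to be EVA-CLIP/rei/ in a local clone; adjust to your path,
# or install the package from that directory instead).
sys.path.append("/path/to/EVA/EVA-CLIP/rei")

from eva_clip import create_model_and_transforms, get_tokenizer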