Reza-Zhu / ACMMM23-Solution-MBEG

Workshop on UAVs in Multimedia: Capturing the World from a New Perspective. Reza Zhu's Solution: MBEG
MIT License

Downloading trained models #2

Open gmberton opened 1 year ago

gmberton commented 1 year ago

Thank you for the repo! I don't have a Baidu account. Is there any other way to download the model's weights? If I send you a Google Drive account, could you upload the weights there?

Reza-Zhu commented 1 year ago

Hi gmberton, thank you for your interest in our model, and I apologize for the inconvenience of requiring a Baidu account. If you can provide a Google Drive account, I would be happy to upload the weights there for you instead.

email: m025120503@sues.edu.cn

gmberton commented 1 year ago

That would be perfect: could you upload it here? https://drive.google.com/drive/folders/1w58oK6D3SQrKYYtlYPjFxZpNymFvZEoh?usp=sharing

Reza-Zhu commented 1 year ago

Upload completed; please check the drive.

gmberton commented 1 year ago

Thanks a lot! I was able to download the weights. How do I load them into the model? For example, for a checkpoint like MBEG-B2-1652.pth, which model from model_.py should I initialize? My guess is that something like this should work:

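```python
import torch
import model_

weights_path = "MBEG-B2-1652.pth"
# map_location="cpu" lets a checkpoint saved on a GPU load on any machine
weights = torch.load(weights_path, map_location="cpu")
model = ...  # TODO: which class from model_.py should be initialized here?
model.load_state_dict(weights)
```
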
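One common snag when calling load_state_dict on a checkpoint like this: if the model was wrapped in torch.nn.DataParallel during training, every key in the saved state dict starts with "module.", and a strict load into an unwrapped model fails with missing/unexpected keys. Whether the MBEG checkpoints were saved this way is an assumption, not something confirmed in this thread. The helper below, strip_module_prefix (a hypothetical name, not part of the repo), is a minimal sketch that removes the prefix only when every key carries it.

```python
from typing import Dict, TypeVar

V = TypeVar("V")

PREFIX = "module."


def strip_module_prefix(state_dict: Dict[str, V]) -> Dict[str, V]:
    """Return a copy of state_dict with a leading 'module.' removed from each key.

    Hypothetical helper: assumes the prefix, if present, was added by
    nn.DataParallel and therefore appears on *all* keys. If any key lacks it,
    the mapping is returned unchanged (as a plain dict copy).
    """
    if state_dict and all(key.startswith(PREFIX) for key in state_dict):
        # Dict comprehensions preserve insertion order, like the original OrderedDict.
        return {key[len(PREFIX):]: value for key, value in state_dict.items()}
    return dict(state_dict)
```

With it, the last line of the snippet above would become model.load_state_dict(strip_module_prefix(weights)); if the keys carry no prefix, the call is a harmless no-op.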
Reza-Zhu commented 1 year ago

```python
# model_.py Line 42
# Excerpt: assumes `import timm` and `import torch.nn as nn` at the top of the file;
# ClassBlock is defined elsewhere in the repository.

class EVA(nn.Module):
    def __init__(self, classes, drop_rate, share_weight=True):
        super(EVA, self).__init__()
        self.model_1 = timm.create_model('eva02_large_patch14_448.mim_m38m_ft_in22k_in1k', pretrained=True,
                                         num_classes=0)
        if share_weight:
            self.model_2 = self.model_1
        else:
            self.model_2 = timm.create_model('eva02_large_patch14_448.mim_m38m_ft_in22k_in1k', pretrained=True,
                                             num_classes=0)
        self.classifier = ClassBlock(1024, classes, drop_rate)
```

https://huggingface.co/timm/eva02_base_patch14_448.mim_in22k_ft_in22k_in1k (see the Model Comparison section)

In model_.py, replace the string 'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k' in self.model_1 = timm.create_model('eva02_large_patch14_448.mim_m38m_ft_in22k_in1k', pretrained=True, num_classes=0) with the name of another EVA series model (and do the same for self.model_2 in the else branch if share_weight=False). For the correspondence between MBEG and EVA, please see Table 1 in the paper. The timm names of the EVA weights can be found via the Hugging Face link above.
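Because ClassBlock(1024, classes, drop_rate) hard-codes the feature width of the EVA-02 Large backbone, swapping in another EVA-02 size also means changing that first argument: the EVA-02 Tiny, Small, Base, and Large encoders produce 192-, 384-, 768-, and 1024-dimensional features respectively. Table 1 of the paper remains the authoritative mapping from MBEG variant to backbone. As a rough cross-check, the sketch below guesses the size from a checkpoint's tensor shapes. It assumes (not confirmed in this thread) that the .pth file is a plain state dict of the EVA module, so the backbone's class token is stored under a key ending in "model_1.cls_token" with shape (1, 1, embed_dim), as in timm's EVA implementation. guess_eva02_size is a hypothetical helper, not part of the repo.

```python
from typing import Mapping, Optional, Sequence

# Embedding width -> EVA-02 model size (Tiny/Small/Base/Large).
EVA02_WIDTH_TO_SIZE = {192: "tiny", 384: "small", 768: "base", 1024: "large"}


def guess_eva02_size(shapes: Mapping[str, Sequence[int]]) -> Optional[str]:
    """Guess the EVA-02 backbone size from a {parameter name: shape} mapping.

    Assumes the class token lives under a key ending in 'model_1.cls_token'
    (endswith also tolerates a 'module.' prefix) with shape (1, 1, embed_dim).
    Returns None if no such key exists or its width is not a known EVA-02 width.
    """
    for name, shape in shapes.items():
        if name.endswith("model_1.cls_token") and len(shape) == 3:
            return EVA02_WIDTH_TO_SIZE.get(shape[-1])
    return None
```

Taking plain shapes rather than tensors keeps the helper independent of PyTorch; with a loaded checkpoint it could be called as guess_eva02_size({k: tuple(v.shape) for k, v in weights.items()}). The returned size only narrows the choice of timm model name and the ClassBlock input width; the exact pretrained tag still has to come from Table 1.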