Open allenlittlestar opened 2 years ago
Bert的词向量维度是768,Glove的词向量维度是300,因此需要修改 M3TR.py中的self.embedding_dim,同时在网络中增加一层fc将维度从300转换到768。
首先感谢您的回复。 我修改M3TR.py中的self.embedding_dim=300,然后仿照class M3TR中def init(self, vit, model, num_classes):中的self.fc_transform = nn.Linear(196, self.embedding_dim)增加了一个全连接层 self.fc_transform1 = nn.Linear(self.embedding_dim, 768),并将该层写入def get_semantic_token(self, x):中,在原有的x = self.fc_transform(mask)下加入了x = self.fc_transform1(x)。
但是还是报一开始的错误,请问我修改的有什么问题吗?如果方便的话,能不能请您上传一版Glove版本的文件,非常感谢。
Bert的词向量维度是768,Glove的词向量维度是300,因此需要修改 M3TR.py中的self.embedding_dim,同时在网络中增加一层fc将维度从300转换到768。
您好。我将Glove的文件读取修改为字典格式解决了上述问题,但报以下错误,应该跟您所说的在网络中增加一层fc将维度从300转换到768有关,请问您能不能详细说一下,这个FC应该写在哪个文件的哪个地方呢?可以直接写成 self.fc_transform1 = nn.Linear(self.embedding_dim, 768)这种格式吗?写好的语句还需要在别的地方调用吗?
Traceback (most recent call last):
File "main.py", line 73, in
Traceback (most recent call last): File "main.py", line 73, in
main(args)
File "main.py", line 59, in main
model = get_model(num_classes, args)
File "/media/omnisky/data/wr/M3TR-master/models/init.py", line 12, in get_model
model = model_dict[args.model_name](vit, res101, num_classes)
File "/media/omnisky/data/wr/M3TR-master/models/M3TR.py", line 55, in init
self.sem_embedding = self.get_word_embedding(self.num_classes).detach()
File "/media/omnisky/data/wr/M3TR-master/models/M3TR.py", line 81, in get_word_embedding
loaded = torch.load(embedding_path)
File "/home/omnisky/anaconda2/envs/faster/lib/python3.7/site-packages/torch/serialization.py", line 595, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "/home/omnisky/anaconda2/envs/faster/lib/python3.7/site-packages/torch/serialization.py", line 766, in _legacy_load
if magic_number != MAGIC_NUMBER:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
您好,我使用配好的环境能够成功复现您初始代码,但是将M3TR.py中#embedding_path = './Bert1/voc_embeddings.pt' 和 #embedding_path = './Bert1/coco_embeddings.pt'分别改为embedding_path = './Glove/voc_glove_word2vec.pkl'及embedding_path = './Glove/coco_glove_word2vec.pkl'之后,报上面的错误,请问可以指导一下如何修改吗?