Open vk-rrr opened 4 weeks ago
Help! Help!
解决了吗? 救命!
# Commenter's note (translated): "This is how I modified it; not sure if it's right."
# Review: the three branch features of an image must be joined along the
# feature axis (axis=1), not stacked along the sample axis (axis=0); otherwise
# the output has 3 * nquery rows and no longer lines up with the query labels.
def extract_query_feat(query_loader):
    """Extract pooled and FC query features, one row per query image.

    The network is called as ``net(images, images, label, test_mode[1])``.
    Per the shape note in the original paste ("[192, 2048]" for a 64-image
    batch), it is assumed to return two tensors whose first dimension is
    ``3 * batch_num``: three stacked feature groups per image. TODO confirm
    against the model. The three groups of each image are placed side by side
    along the feature axis, matching the commented-out ``pool_dim * 3``
    layout, so row ``i`` of each result describes query image ``i``.

    Reads the file-level names ``net``, ``nquery``, ``pool_dim`` and
    ``test_mode``, and moves each batch to the GPU with ``.cuda()``.

    Args:
        query_loader: iterable of ``(images, label)`` batches covering the
            ``nquery`` query images in order.

    Returns:
        ``(query_feat_pool, query_feat_fc)``: numpy arrays of shape
        ``(nquery, pool_dim * 3)``.

    Raises:
        ValueError: if a network output does not have ``3 * batch_num`` rows.
    """
    net.eval()
    print('Extracting Query Feature...')
    start = time.time()
    ptr = 0
    query_feat_pool = np.zeros((nquery, pool_dim * 3))
    query_feat_fc = np.zeros((nquery, pool_dim * 3))
    with torch.no_grad():
        for images, label in query_loader:
            batch_num = images.size(0)
            images = images.cuda()
            feat_pool, feat_fc = net(images, images, label, test_mode[1])
            # Guard the 3-way split: hand-written slices would silently
            # produce wrongly sized pieces if the output were not 3 * B rows.
            if feat_pool.size(0) != 3 * batch_num or feat_fc.size(0) != 3 * batch_num:
                raise ValueError(
                    'expected {} rows from the network, got {} (pool) / {} (fc)'.format(
                        3 * batch_num, feat_pool.size(0), feat_fc.size(0)))
            # [3B, D] -> three [B, D] groups -> [B, 3D]: keep one row per image.
            feat_pool = torch.cat(torch.chunk(feat_pool, 3, dim=0), dim=1)
            feat_fc = torch.cat(torch.chunk(feat_fc, 3, dim=0), dim=1)
            query_feat_pool[ptr:ptr + batch_num, :] = feat_pool.cpu().numpy()
            query_feat_fc[ptr:ptr + batch_num, :] = feat_fc.cpu().numpy()
            ptr += batch_num
    print('Extracting Time:\t {:.3f}'.format(time.time() - start))
    return query_feat_pool, query_feat_fc
tsne.py 可视化里的代码该怎么改呀？救救我！
运行 extract.py 报错：
Traceback (most recent call last):
  File "/home/hzh/Github/LLCM-main/DEEN/extract.py", line 268, in <module>
query_feat_pool, query_feat_fc = extract_query_feat(query_loader)
File "/home/hzh/Github/LLCM-main/DEEN/extract.py", line 138, in extract_query_feat
query_feat_pool[ptr:ptr + batch_num, :] = feat_pool.detach().cpu().numpy()
ValueError: could not broadcast input array from shape (192,2048) into shape (64,2048)
把通道维度连接成一个特征,但是可视化的图不对,请问可以给出符合DEEN的extract.py详细代码吗?
def extract_query_feat(query_loader): net.eval() print('Extracting Query Feature...') start = time.time() ptr = 0 query_feat_pool = np.zeros((nquery, pool_dim 3)) query_feat_fc = np.zeros((nquery, pool_dim 3)) with torch.no_grad(): for batch_idx, (input, label) in enumerate(query_loader): batch_num = input.size(0) input = Variable(input.cuda()) feat_pool, feat_fc = net(input, input, label, test_mode[1]) # [192, 2048], [192, 2048]