# Per-batch t-SNE visualization of the two inputs yielded by `trainloader`,
# saving one figure per batch to save_tsne/tsne_<batch_idx>.jpg.
# NOTE(review): in the original text the indentation was lost and the loop
# header/prologue had been collapsed onto one line placed after the loop body;
# this restores the evident structure (header, prologue, body) — confirm.
import os

# savefig raises if the output directory does not exist.
os.makedirs('save_tsne', exist_ok=True)

for batch_idx, (input1, input2, label1, label2) in enumerate(trainloader):
    labels = torch.cat((label1, label2), 0)
    # Per-sample source indicator passed to plot_embedding:
    # 1 for samples from input1, 0 for samples from input2.
    z1 = torch.ones(label1.shape)
    z2 = torch.zeros(label2.shape)
    z = torch.cat((z1, z2), 0)
    print(batch_idx)
    # `Variable` is a no-op wrapper since PyTorch 0.4; use the tensors directly.
    # NOTE(review): in this loop the inputs are only concatenated and copied
    # back to the CPU, so the GPU round trip may be unnecessary — confirm.
    input1 = input1.cuda()
    input2 = input2.cuda()

    # Remap the label values in this batch to contiguous ids 0..K-1, assigned
    # in ascending order of the original value. This replaces an O(K*N)
    # in-place double loop that could merge classes when a label value was
    # smaller than its rank (e.g. negative labels: -1 -> 0, then all 0s -> 1).
    _, labels = torch.unique(labels, sorted=True, return_inverse=True)

    data_time.update(time.time() - end)
    out = torch.cat((input1, input2), 0)
    tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)
    X_tsne = tsne.fit_transform(out.detach().cpu().numpy())
    plot_embedding(X_tsne, labels, z)
    plt.savefig(osp.join('save_tsne', 'tsne_{}.jpg'.format(batch_idx)))
    # Close the current figure so figures do not accumulate across batches.
    # NOTE(review): assumes plot_embedding draws on the current pyplot figure.
    plt.close()
    # Reset the reference time; without this each data_time update reports the
    # cumulative time since `end` was first set rather than this iteration's.
    end = time.time()