Closed: edwintenagyei367 closed this issue 3 days ago.
Hello, thank you for your interest in our work. The t-SNE visualization code is as follows:
import torch
# NOTE(review): `vision_transformer` is never referenced below; it is presumably
# imported only for its side effect of registering the ViT variants (and the
# `tuning_mode` option) with timm's model registry — confirm before removing.
from models import vision_transformer
from timm.models import create_model
from data import create_loader, create_dataset

# Size of the classification head handed to create_model.
num_class = 10
# Replace these paths with your own (fine-tuned checkpoint loaded below).
model_our_path = '/path/to/adaptformer_svnh_gist_model_best.pth.tar'

# Build the ViT-B/16 backbone with the GIST adapter, restore the fine-tuned
# weights from `model_our_path`, then place it on the GPU in inference mode.
# nn.Module.to() and nn.Module.eval() both return the module itself, so the
# calls can be chained onto the constructor.
model = create_model(
    'vit_base_patch16_224_in21k',
    pretrained=False,
    num_classes=num_class,
    scriptable=True,
    checkpoint_path=model_our_path,
    tuning_mode='gist_adapter',
).to('cuda').eval()
# Load dataset.
# NOTE(review): despite the `dataset_eval` / `loader_eval` names, this reads
# split='train' with is_training=True, so the t-SNE is drawn on the training
# split. Confirm that this is intended. How `is_training` and `batch_size` are
# used here depends on the project's `data.create_dataset`; TODO verify.
dataset_eval = create_dataset(
    'svhn',
    root='/path/to/vtab-1k/svhn',  # replace with your own dataset directory
    split='train',
    is_training=True,
    class_map='',
    download=False,
    batch_size=1,
)

# Loader settings, gathered in one place so they are easy to adjust.
# mean/std are the standard ImageNet normalisation statistics.
_loader_options = dict(
    input_size=(3, 224, 224),
    batch_size=1,           # one image per step
    is_training=False,      # presumably selects eval-time (non-augmenting) transforms; confirm
    use_prefetcher=True,    # presumably moves batches to the GPU and normalises there; confirm
    direct_resize=True,
    interpolation='bicubic',
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    crop_pct=0.875,
    num_workers=8,
    pin_memory=True,
    distributed=False,
)
# Create data loader
loader_eval = create_loader(dataset_eval, **_loader_options)
# Evaluation loop: run every batch through the model once, and collect the
# model's first output together with the ground-truth labels as NumPy arrays.
out_list = []
label_list = []
# Idempotent with the model setup above. It is kept so that this section can be
# re-run on its own (e.g. as a notebook cell) with the model guaranteed to be
# on the GPU and in eval mode.
model.cuda().eval()
with torch.no_grad():  # inference only: no autograd graph is recorded
    for images, target in loader_eval:
        # With use_prefetcher=True the loader presumably already yields CUDA
        # tensors, in which case this is a no-op. It keeps the loop working if
        # the prefetcher is disabled and batches arrive on the CPU.
        images = images.cuda(non_blocking=True)
        # The model returns a 3-tuple here; only the first element is used.
        # NOTE(review): this is presumably the classification output (width
        # num_class), not a CLS-token embedding. Confirm which features the
        # t-SNE is meant to show.
        output, _gist, _att_loss = model(images)
        out_list.append(output.cpu().numpy())
        label_list.append(target.cpu().numpy())
# t-SNE visualization: embed the collected model outputs in 2-D and plot
# them coloured by ground-truth class.
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Join the per-batch arrays: labels end to end, outputs stacked row-wise, so
# that row i of the feature matrix belongs to label i.
# NOTE(review): despite its name, `cls_tokens_reshaped` holds the model's first
# output collected in the loop above, not necessarily CLS tokens; confirm.
label_list_ = np.hstack(label_list)
cls_tokens_reshaped = np.vstack(out_list)

# sklearn rejects perplexity >= n_samples, so clamp it to keep small subsets
# working. With more than 30 samples this is the original value of 30.
n_samples = cls_tokens_reshaped.shape[0]
tsne_model = TSNE(
    perplexity=min(30, n_samples - 1),
    n_components=2,
    init='pca',
    random_state=42,  # fixed seed so the embedding is reproducible
)
new_values = tsne_model.fit_transform(cls_tokens_reshaped)

plt.figure(figsize=(8, 8))
for label in np.unique(label_list_):
    mask = label_list_ == label  # boolean row selector for this class
    plt.scatter(new_values[mask, 0], new_values[mask, 1],
                label=label, alpha=0.5, s=88)
# Display the per-class labels passed to scatter above. Without a legend the
# colours cannot be mapped back to class ids.
plt.legend(title='class', loc='best', markerscale=0.5)
plt.axis('off')  # also hides the x/y ticks
# Replace with your desired save path
plt.savefig('/path/to/output.png', dpi=350, bbox_inches='tight')
plt.show()
Thank you so much. I appreciate the reply.
Thank you for the wonderful work. Is there a code implementation available for the t-SNE visualization on this dataset?