JCruan519 / GIST

(ACM MM24) This is the official repository of GIST: Improving Parameter Efficient Fine Tuning via Knowledge Interaction.
Apache License 2.0

Visualization #1

Closed: edwintenagyei367 closed this issue 3 days ago

edwintenagyei367 commented 4 days ago

Thank you for the wonderful work. I would like to ask if there is any code implementation for the t-SNE visualization on the dataset.

JCruan519 commented 4 days ago

Hello, thank you for your interest in our work. The t-SNE visualization code is as follows:

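import torch
from timm.models import create_model

# Modules from the GIST repository. Importing vision_transformer registers the
# repo's ViT variants (which accept `tuning_mode`) so timm's create_model can find them.
from models import vision_transformer
from data import create_loader, create_dataset

num_class = 10  # SVHN has 10 digit classes

# Replace this path (and the dataset/output paths below) with your own
model_our_path = '/path/to/adaptformer_svnh_gist_model_best.pth.tar'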

# Initialize the model
model = create_model(
    'vit_base_patch16_224_in21k',
    pretrained=False,
    num_classes=num_class,
    scriptable=True,
    checkpoint_path=model_our_path,
    tuning_mode='gist_adapter',
)

# Move model to GPU and switch to evaluation mode
model = model.to('cuda')
model.eval()

# Load the dataset to visualize (here, the VTAB-1k SVHN train split)
dataset_eval = create_dataset(
    'svhn',
    root='/path/to/vtab-1k/svhn',
    split='train',
    is_training=True,
    class_map='',
    download=False,
    batch_size=1
)

# Create data loader
loader_eval = create_loader(
    dataset_eval,
    input_size=(3, 224, 224),
    batch_size=1,
    is_training=False,
    use_prefetcher=True,
    direct_resize=True,
    interpolation='bicubic',
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    num_workers=8,
    distributed=False,
    crop_pct=0.875,
    pin_memory=True,
)

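# Evaluation loop: collect the model outputs and labels for every sample
out_list = []
label_list = []

with torch.no_grad():
    for images, target in loader_eval:
        # With use_prefetcher=True the loader should already place batches on the GPU.
        # The GIST model returns (output, gist, att_loss); only output is kept.
        output, gist, att_loss = model(images)
        out_list.append(output.cpu().numpy())
        label_list.append(target.cpu().numpy())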

# t-SNE visualization
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Stack the per-batch arrays: labels -> shape (N,), outputs -> shape (N, D)
label_list_ = np.concatenate(label_list)
cls_tokens_reshaped = np.concatenate(out_list, axis=0)

# Project the outputs to 2-D with t-SNE
tsne_model = TSNE(perplexity=30, n_components=2, init='pca', random_state=42)
new_values = tsne_model.fit_transform(cls_tokens_reshaped)

plt.figure(figsize=(8, 8))
for label in np.unique(label_list_):
    mask = label_list_ == label  # boolean mask selecting the points of this class
    plt.scatter(new_values[mask, 0], new_values[mask, 1], label=label, alpha=0.5, s=88)
plt.xticks([])
plt.yticks([])
plt.axis('off')

# Replace with your desired save path
plt.savefig('/path/to/output.png', dpi=350, bbox_inches='tight')
plt.show()
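
The t-SNE and plotting step above does not depend on the GIST repository, so it can be tested on its own before running the full pipeline. Below is a minimal sketch, assuming only NumPy, scikit-learn and matplotlib; the function name `plot_tsne` and the synthetic data are hypothetical and not part of the repository. It mirrors the settings used above (PCA initialization, perplexity 30, seed 42, dpi 350).

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # render to a file without needing a display
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE


def plot_tsne(features, labels, save_path, perplexity=30, seed=42, dpi=350):
    """Hypothetical helper: project (N, D) features to 2-D with t-SNE,
    scatter-plot them colored by class, save the figure, and return
    the (N, 2) embedding so it can be reused without recomputing."""
    features = np.asarray(features, dtype=np.float32)
    labels = np.asarray(labels)
    if features.shape[0] != labels.shape[0]:
        raise ValueError("features and labels must have the same number of rows")

    # scikit-learn requires perplexity to be smaller than the number of samples
    perplexity = min(perplexity, features.shape[0] - 1)
    embedding = TSNE(n_components=2, perplexity=perplexity, init="pca",
                     random_state=seed).fit_transform(features)

    fig, ax = plt.subplots(figsize=(8, 8))
    for label in np.unique(labels):
        mask = labels == label  # points belonging to this class
        ax.scatter(embedding[mask, 0], embedding[mask, 1],
                   label=str(label), alpha=0.5, s=88)
    ax.axis("off")
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)
    return embedding


# Synthetic stand-in for the collected outputs: three well-separated Gaussian blobs
rng = np.random.default_rng(0)
feats = np.concatenate([rng.normal(loc=c * 5.0, size=(50, 16)) for c in range(3)])
labs = np.repeat(np.arange(3), 50)
emb = plot_tsne(feats, labs, "tsne_demo.png")
print(emb.shape)  # (150, 2)
```

With the arrays collected by the loop above, the same helper could be called as `plot_tsne(cls_tokens_reshaped, label_list_, '/path/to/output.png')`. Closing the figure inside the function keeps memory bounded if it is called for several checkpoints in a row.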

edwintenagyei367 commented 3 days ago

Thank you so much. I appreciate the reply.