zhu-xlab / DeCUR

Decoupling common and unique representations for multimodal self-supervised learning
Apache License 2.0

Decoupled common and unique representations across two modalities visualized by t-SNE. #5

StevenAZy closed this issue 3 days ago.

StevenAZy commented 4 days ago

Hello, impressive work! Could you explain how you generated Figure 1 to illustrate shared and distinct representations? Would you be able to share the code?

wangyi111 commented 3 days ago

Hi, thanks for your interest! Figure 1 was generated by running t-SNE on the embeddings of a single batch of data. Each embedding dimension, rather than each sample, is treated as one data point in the plot. A simple version of the code is as follows:

"""
model: DeCUR backbone + projector
data[0]: (B,C0,H,W), data[1]: (B,C1,H,W)
D: total embedding dim
Dc: common dim
"""
import torch
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# get embeddings (no gradients are needed for visualization)
model.eval()
with torch.no_grad():
    z1, z2 = model(data[0], data[1])  # (B,D), (B,D)
# reorder dims as [common M1 | common M2 | unique M1 | unique M2]
feature = torch.cat((z1[:, :Dc], z2[:, :Dc], z1[:, Dc:], z2[:, Dc:]), -1)  # (B,2D)
# L2-normalize each dimension (column) across the batch
feature = F.normalize(feature, dim=0)
# transpose embedding such that each dimension becomes one data point for tsne
tsne_feature = feature.permute(1, 0).cpu().numpy()  # (2D,B)

# define tsne categories
tsne_target = torch.zeros(2 * D, dtype=torch.int)
tsne_target[0:Dc] = 0           # common M1
tsne_target[Dc:2 * Dc] = 1      # common M2
tsne_target[2 * Dc:Dc + D] = 2  # unique M1
tsne_target[Dc + D:2 * D] = 3   # unique M2

# optionally reduce dim further by PCA (n_components must not exceed min(2D, B))
# pca = PCA(n_components=50)
# pca_result = pca.fit_transform(tsne_feature)
pca_result = tsne_feature

# run tsne (max_iter was called n_iter before scikit-learn 1.5)
tsne = TSNE(n_components=2, verbose=2, perplexity=40, max_iter=2000, n_iter_without_progress=300)
tsne_results = tsne.fit_transform(pca_result)

# plot: blue = modality 1, red = modality 2; squares/crosses = common, circles = unique
df = pd.DataFrame({
    'y': tsne_target.numpy(),
    'tsne-2d-one': tsne_results[:, 0],
    'tsne-2d-two': tsne_results[:, 1],
})
plt.figure(figsize=(12, 7.5))
tsne_plot = sns.scatterplot(
    x="tsne-2d-one", y="tsne-2d-two",
    hue="y",
    palette=sns.color_palette(palette=['b', 'r', 'b', 'r'], n_colors=4),
    style='y',
    markers={0: 's', 1: 'X', 2: 'o', 3: 'o'},
    data=df,
    legend=False,
    alpha=0.8
)
plt.show()
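The core trick in the snippet above is the reordering, per-dimension normalization, and transpose that turn a (B, D) pair of embeddings into 2D labelled "points". The sketch below isolates that step in plain NumPy so it can be checked without a trained model. It is a minimal illustration, not code from the DeCUR repository: `build_tsne_inputs` is a hypothetical helper name, the random arrays stand in for real `z1`/`z2`, and the sizes (B=64, D=128, Dc=96) are arbitrary assumptions.

```python
import numpy as np


def build_tsne_inputs(z1: np.ndarray, z2: np.ndarray, dc: int, eps: float = 1e-12):
    """Arrange two (B, D) embeddings so that each *dimension* is one t-SNE point.

    Hypothetical helper: assumes the first `dc` dims of each embedding are the
    common part and the remaining D - dc dims are the unique part.
    Returns points of shape (2D, B) and integer labels of shape (2D,):
    0 = common M1, 1 = common M2, 2 = unique M1, 3 = unique M2.
    """
    assert z1.shape == z2.shape and z1.ndim == 2
    d = z1.shape[1]
    # reorder columns as [common M1 | common M2 | unique M1 | unique M2] -> (B, 2D)
    feature = np.concatenate((z1[:, :dc], z2[:, :dc], z1[:, dc:], z2[:, dc:]), axis=1)
    # L2-normalize each column over the batch, mirroring F.normalize(feature, dim=0)
    norms = np.linalg.norm(feature, axis=0, keepdims=True)
    feature = feature / np.maximum(norms, eps)
    # transpose: rows are now embedding dimensions, columns are samples
    points = feature.T  # (2D, B)
    # same group boundaries as tsne_target: [0:Dc], [Dc:2Dc], [2Dc:Dc+D], [Dc+D:2D]
    labels = np.repeat([0, 1, 2, 3], [dc, dc, d - dc, d - dc])
    return points, labels


# usage with random stand-in embeddings (B=64 samples, D=128 dims, Dc=96 common dims)
rng = np.random.default_rng(0)
z1 = rng.standard_normal((64, 128))
z2 = rng.standard_normal((64, 128))
points, labels = build_tsne_inputs(z1, z2, dc=96)
print(points.shape, np.bincount(labels))  # prints: (256, 64) [96 96 32 32]
```

The resulting `points` and `labels` can be passed to `sklearn.manifold.TSNE(...).fit_transform(points)` and plotted in the same way as in the snippet above. Because the t-SNE samples are dimensions, the number of points is 2D, so `perplexity` has to stay below 2D.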
StevenAZy commented 3 days ago

Thanks a lot.