RUCAIBox / NCL

[WWW'22] Official PyTorch implementation for "Improving Graph Collaborative Filtering with Neighborhood-enriched Contrastive Learning".

about the distribution of item embeddings #8

Closed FinchNie closed 2 years ago

FinchNie commented 2 years ago

Hi, thanks for your great work. I am confused about Figure 6 when reading this paper.

"We plot item embedding distributions with Gaussian kernel density estimation (KDE) in two-dimensional space." [image]

Is the code for this figure available? Thank you.

hyp1231 commented 2 years ago

Hi,

We first train the models with --embedding_size=2, then run the following script:

import torch.nn.functional as F
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from recbole.quick_start import load_data_and_model

filepath = 'path/to/your/model'  # replace with the path to your saved model

config, model, dataset, train_data, valid_data, test_data = load_data_and_model(
    model_file=filepath,
)

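# Move the learned item embeddings to the CPU and detach them from the graph.
item_emb = model.item_embedding.weight.cpu().detach()
# L2-normalize each row so every item lies on the unit circle.
item_emb = F.normalize(item_emb, dim=1).numpy()
print(item_emb.shape)  # (n_items, 2) when trained with --embedding_size=2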

plt.figure(figsize=(3, 3))

df = pd.DataFrame({
    'x': item_emb.T[0],
    'y': item_emb.T[1]
})

ax = sns.kdeplot(
    data=df, x='x', y='y',
    thresh=0, levels=300, cmap=sns.color_palette('light:b', as_cmap=True)
)

plt.xlabel('')
plt.ylabel('')

plt.tight_layout()
plt.savefig('your pdf file name', format='pdf', dpi=300)  # replace with your output file name
plt.show()
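
To make the normalization step in the script above concrete, here is a minimal, dependency-free sketch of what F.normalize(item_emb, dim=1) does to each row. The helper names l2_normalize_rows and angles and the toy values are hypothetical and not part of the NCL code. The angle conversion is only an illustration that each normalized 2-D point reduces to a direction on the unit circle; it is not something the script above computes.

```python
import math


def l2_normalize_rows(rows, eps=1e-12):
    """Scale each row to unit L2 norm: v / max(||v||, eps)."""
    normalized = []
    for row in rows:
        norm = math.sqrt(sum(v * v for v in row))
        scale = max(norm, eps)  # guard against division by zero
        normalized.append([v / scale for v in row])
    return normalized


def angles(rows):
    """Angle in radians of each 2-D point, in the range [-pi, pi]."""
    return [math.atan2(y, x) for x, y in rows]


# Toy 2-D "embeddings" (illustrative values only).
toy_emb = [[3.0, 4.0], [0.0, -2.0], [-1.0, 0.0]]
unit_emb = l2_normalize_rows(toy_emb)
toy_angles = angles(unit_emb)

for point, theta in zip(unit_emb, toy_angles):
    print(point, round(theta, 4))
```

Because every normalized point has norm 1, the 2-D KDE plot shows how density varies around the unit circle rather than how far items sit from the origin.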
hyp1231 commented 2 years ago

Closing due to inactivity. Please comment if you're still having issues.

hyp1231 commented 2 years ago

You can try replacing the load_data_and_model function imported from RecBole with the following function:

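from logging import getLogger

import torch
from recbole.data import create_dataset, data_preparation
from recbole.utils import init_logger, init_seed

from ncl import NCL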

def load_data_and_model(model_file):
    checkpoint = torch.load(model_file)
    config = checkpoint['config']
    init_seed(config['seed'], config['reproducibility'])
    init_logger(config)
    logger = getLogger()
    logger.info(config)

    dataset = create_dataset(config)
    logger.info(dataset)
    train_data, valid_data, test_data = data_preparation(config, dataset)

    init_seed(config['seed'], config['reproducibility'])
    model = NCL(config, train_data.dataset).to(config['device'])
    model.load_state_dict(checkpoint['state_dict'])
    model.load_other_parameter(checkpoint.get('other_parameter'))

    return config, model, dataset, train_data, valid_data, test_data