RUCAIBox / RecBole-GNN

Efficient and extensible GNNs enhanced recommender library based on RecBole.
MIT License
171 stars 38 forks source link

有关Recbole_GNN框架中自定义模型的问题 #41

Closed DX12321 closed 2 years ago

DX12321 commented 2 years ago

您好,我在Recbole_GNN框架下运行自定义模型时报错【interaction中没有 x 这个key】 数据集和配置文件是CORE论文提供的。model 类中的myGCEGNN模型就是原GCEGNN模型,只是把代码复制出来命名model。 我调试了一下,发现main文件没有调用到 recbole_gnn\data\dataset.py 这个文件,这个文件是处理interaction的。

=== main文件代码 === from logging import getLogger from recbole.utils import init_logger, init_seed, set_color from recbole_gnn.config import Config from recbole_gnn.utils import create_dataset, data_preparation,get_model, get_trainer from model import myGCEGNN

if name == 'main':

configurations initialization

config = Config(
    model=myGCEGNN,
    dataset='diginetica',
    config_file_list=['config.yaml', 'config_model.yaml'],
)
init_seed(config['seed'], config['reproducibility'])
# logger initialization
init_logger(config)
logger = getLogger()
logger.info(config)
# dataset filtering
dataset = create_dataset(config)
logger.info(dataset)
# dataset splitting
train_data, valid_data, test_data = data_preparation(config, dataset)
# model loading and initialization
model = myGCEGNN(config, train_data.dataset).to(config['device'])
logger.info(model)
# trainer loading and initialization
trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)
# model training
best_valid_score, best_valid_result = trainer.fit(
    train_data, valid_data, saved=True, show_progress=config['show_progress']
)
# model evaluation
test_result = trainer.evaluate(test_data, load_best_model=True, show_progress=config['show_progress'])
logger.info(set_color('best valid ', 'yellow') + f': {best_valid_result}')
logger.info(set_color('test result', 'yellow') + f': {test_result}')
hyp1231 commented 2 years ago

感谢反馈!确实存在这个问题,因为目前 Dataset 和 DataLoader 的定位都不是特别灵活,推荐两种方法,选其一即可。我们后续会考虑如何优化自定义模型的使用,现在只能先这样凑合一下了。我在本地测试了一下方案二,是可以成功运行的。

方案一、直接开发新模型

  1. 把模型文件重命名为 mygcegnn.py 并放到 recbole_gnn/model/sequential_recommender/ 文件夹下;

  2. recbole_gnn/model/sequential_recommender/__init__.py 中加入

from recbole_gnn.model.sequential_recommender.mygcegnn import myGCEGNN
  1. recbole_gnn/data/dataset.py 中新建一个 class myGCEGNNDataset,直接继承 GCEGNNDataset

方案二、修改 main.py

  1. 把现在 main.py 中的
dataset = create_dataset(config)

替换成

from recbole_gnn.data.dataset import GCEGNNDataset
dataset = GCEGNNDataset(config)
train_data, valid_data, test_data = data_preparation(config, dataset)

替换成

from recbole_gnn.utils import _get_customized_dataloader
from recbole.data.utils import create_samplers

built_datasets = dataset.build()
train_dataset, valid_dataset, test_dataset = built_datasets
train_sampler, valid_sampler, test_sampler = create_samplers(config, dataset, built_datasets)

train_data = _get_customized_dataloader(config, 'train')(config, train_dataset, train_sampler, shuffle=True)
valid_data = _get_customized_dataloader(config, 'evaluation')(config, valid_dataset, valid_sampler, shuffle=False)
test_data = _get_customized_dataloader(config, 'evaluation')(config, test_dataset, test_sampler, shuffle=False)
DX12321 commented 2 years ago

您好,我这边按照您提供的解决方案试了一下,还是有点问题。 方案一和方案二都可以对数据集进行构图,但是在模型训练 forward 的时候还是会出现 interaction 中没有 'x' 这个key。 调试了一下,发现 main 处理完数据集后没有调用到 RecBole-GNN-main\recbole_gnn\data\transform.py文件,而这个文件是处理interaction的。在最后一次Constructing session graphs 后代码会出现一个 WARNING :

02 Jun 09:35 INFO Reversing sessions. 100%|██████████| 76589/76589 [00:01<00:00, 61923.70it/s] 02 Jun 09:36 INFO Constructing session graphs. 100%|██████████| 76589/76589 [00:22<00:00, 3331.92it/s] 02 Jun 09:36 WARNING Equal transform 02 Jun 09:36 WARNING Equal transform 02 Jun 09:36 WARNING Equal transform 02 Jun 09:36 INFO [Training]: train_batch_size = [100] negative sampling: [None]

应该是这个 WARNING 中断了对 interaction 的处理。

hyp1231 commented 2 years ago

您好,transform 的意思是在 DataLoader 里,把 Dataset 提供的 graph 建立 batch 的操作。

RecBole 的每个模型都会有自己的配置文件,GCE-GNN 对应的配置文件位置在 recbole_gnn/properties/model/GCEGNN.yaml,您把这个文件里对应的配置也输进去应该就没问题了。

我复现时看到您的 main.py 里加载了 config_model.yaml,就直接把 recbole_gnn/properties/model/GCEGNN.yaml 复制了一份变成了 config_model.yaml,所以就没遇到这个问题,您可以把里面的参数加上试试。

DX12321 commented 2 years ago

嗯嗯,找到问题了,我的config_model.yaml文件中缺少配置参数transform: sess_graph这个配置参数。 感谢感谢!!!