RUCAIBox / RecBole-GNN

Efficient and extensible GNNs enhanced recommender library based on RecBole.
MIT License

`model_name` [xxx] is not the name of an existing model. #79

Closed: Tensionteng closed this issue 7 months ago

Tensionteng commented 7 months ago

I want to use a trained model for prediction, but the error message says the model does not exist.

My code is below. It is copied directly from recbole/run_example/case_study_example.py, and the model is one I trained myself.

import torch
from recbole.utils.case_study import full_sort_topk, full_sort_scores
from recbole.quick_start import load_data_and_model

config, model, dataset, train_data, valid_data, test_data = load_data_and_model(
    model_file="saved/XSimGCL-Jan-21-2024_02-59-55.pth",
)  # Here you can replace it by your model path.

# uid_series = np.array([1, 2])  # internal user id series
# or you can use dataset.token2id to transfer external user token to internal user id
uid_series = dataset.token2id(dataset.uid_field, ["1"])

topk_score, topk_iid_list = full_sort_topk(
    uid_series, model, test_data, k=10, device=config["device"]
)
print(topk_score)  # scores of top 10 items
print('top@10_item_ids',topk_iid_list)  # internal id of top 10 items
external_item_list = dataset.id2token(dataset.iid_field, topk_iid_list.cpu())
print(external_item_list)  # external tokens of top 10 items
print()

score = full_sort_scores(uid_series, model, test_data, device=config["device"])
print(score)  # score of all items
print(
    score[0, dataset.token2id(dataset.iid_field, ["242", "302"])]
)  # scores of items ['242', '302'] for user '1'.
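For readers unfamiliar with RecBole's case-study helpers: the script converts an external user token to an internal id, ranks the items for that user, and maps the top-k internal item ids back to external tokens. The stdlib-only sketch below imitates that round trip on toy data. The mappings, the scores, and the helper name `toy_full_sort_topk` are hypothetical, and this is not RecBole's implementation.

```python
import heapq

# Hypothetical token <-> internal-id mappings; id 0 is treated as a padding slot.
user_token2id = {"[PAD]": 0, "1": 1, "2": 2}
item_id2token = ["[PAD]", "242", "302", "377", "51"]

# Hypothetical score matrix: one row per internal user id, one column per internal item id.
scores = [
    [0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.9, 0.2, 0.7, 0.4],
    [0.0, 0.1, 0.8, 0.3, 0.6],
]


def toy_full_sort_topk(uids, k):
    """Return a list of (top-k scores, top-k internal item ids), one pair per internal user id."""
    results = []
    for uid in uids:
        # Skip column 0 so the padding slot can never be recommended.
        candidates = [(score, iid) for iid, score in enumerate(scores[uid]) if iid != 0]
        top = heapq.nlargest(k, candidates)
        results.append(([s for s, _ in top], [iid for _, iid in top]))
    return results


uids = [user_token2id["1"]]  # external token "1" -> internal id 1
top_scores, top_iids = toy_full_sort_topk(uids, k=2)[0]
top_tokens = [item_id2token[i] for i in top_iids]  # internal ids -> external tokens
print(top_scores)  # [0.9, 0.7]
print(top_tokens)  # ['242', '377']
```

The real helpers work on tensors and a trained model's predictions; the toy version only shows the id bookkeeping around the ranking step.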

The error is:

ValueError: `model_name` [XSimGCL] is not the name of an existing model.

Environment:

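There are two issues in the RecBole repository, issue1 and issue2, and the replies to both said the installed version was not the latest. I installed RecBole with conda install -c aibox recbole, and the installed version is:

recbole 1.2.0 py39_0 aibox

Is this error also caused by "the code not being the latest version"?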

Tensionteng commented 7 months ago

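I compared the source code of RecBole-GNN and RecBole. The error is not a version problem: RecBole's `load_data_and_model` looks the model up with RecBole's own `get_model`, which only knows RecBole's built-in models, so XSimGCL (a RecBole-GNN model) is not found. The fix is simply to copy the `load_data_and_model` function into recbole_gnn/quick_start.py, where it uses RecBole-GNN's `create_dataset`, `data_preparation` and `get_model`, and then import `load_data_and_model` from `recbole_gnn.quick_start` instead of `recbole.quick_start` in the prediction script: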

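from logging import getLogger

from recbole.utils import init_logger, init_seed
# In recbole_gnn/quick_start.py these resolve to RecBole-GNN's own helpers,
# which is what lets get_model find RecBole-GNN models such as XSimGCL.
from recbole_gnn.utils import create_dataset, data_preparation, get_model


def load_data_and_model(model_file):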
    r"""Load filtered dataset, split dataloaders and saved model.

    Args:
        model_file (str): The path of saved model file.

    Returns:
        tuple:
            - config (Config): An instance object of Config, which records the parameter information stored in :attr:`model_file`.
            - model (AbstractRecommender): The model loaded from :attr:`model_file`.
            - dataset (Dataset): The filtered dataset.
            - train_data (AbstractDataLoader): The dataloader for training.
            - valid_data (AbstractDataLoader): The dataloader for validation.
            - test_data (AbstractDataLoader): The dataloader for testing.
    """
    import torch

    checkpoint = torch.load(model_file)
    config = checkpoint["config"]
    init_seed(config["seed"], config["reproducibility"])
    init_logger(config)
    logger = getLogger()
    logger.info(config)

    dataset = create_dataset(config)
    logger.info(dataset)
    train_data, valid_data, test_data = data_preparation(config, dataset)

    init_seed(config["seed"], config["reproducibility"])
    model = get_model(config["model"])(config, train_data._dataset).to(config["device"])
    model.load_state_dict(checkpoint["state_dict"])
    model.load_other_parameter(checkpoint.get("other_parameter"))

    return config, model, dataset, train_data, valid_data, test_data
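The ValueError in this issue comes from a name lookup. As far as I can tell from RecBole 1.2.0, its `get_model` only searches RecBole's own model packages, while XSimGCL is defined in RecBole-GNN. The toy sketch below imitates that behavior, with plain dictionaries standing in for the model packages; the registry contents and the name `toy_get_model` are hypothetical, and this is not the real RecBole or RecBole-GNN code.

```python
class BPR:  # placeholder for a model shipped with RecBole
    pass


class XSimGCL:  # placeholder for a model shipped with RecBole-GNN
    pass


# Hypothetical registries standing in for each library's model packages.
RECBOLE_REGISTRY = {"BPR": BPR}
RECBOLE_GNN_REGISTRY = {"XSimGCL": XSimGCL}


def toy_get_model(model_name, registries):
    """Return the first class registered under model_name, searching registries in order."""
    for registry in registries:
        if model_name in registry:
            return registry[model_name]
    # Same message format as the error reported in this issue.
    raise ValueError(f"`model_name` [{model_name}] is not the name of an existing model.")


# Searching only RecBole's registry reproduces the error from the issue.
try:
    toy_get_model("XSimGCL", [RECBOLE_REGISTRY])
except ValueError as err:
    print(err)  # `model_name` [XSimGCL] is not the name of an existing model.

# Searching RecBole-GNN's registry as well lets the lookup succeed.
print(toy_get_model("XSimGCL", [RECBOLE_GNN_REGISTRY, RECBOLE_REGISTRY]).__name__)  # XSimGCL
```

Upgrading RecBole alone does not change which packages are searched, which is why loading through RecBole-GNN's own helpers is what resolves the lookup.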