Tensionteng closed this 9 months ago.
I compared the source code of RecBole-GNN and RecBole. All you need to do is copy the `load_data_and_model` function into `recbole_gnn/quick_start.py`:
```python
from logging import getLogger

import torch
from recbole.utils import init_logger, init_seed

# Import these from recbole_gnn (as recbole_gnn/quick_start.py does), not from recbole,
# so that model names defined in RecBole-GNN can be resolved.
from recbole_gnn.utils import create_dataset, data_preparation, get_model


def load_data_and_model(model_file):
    r"""Load filtered dataset, split dataloaders and saved model.

    Args:
        model_file (str): The path of saved model file.

    Returns:
        tuple:
            - config (Config): An instance object of Config, which records parameter information in :attr:`model_file`.
            - model (AbstractRecommender): The model loaded from :attr:`model_file`.
            - dataset (Dataset): The filtered dataset.
            - train_data (AbstractDataLoader): The dataloader for training.
            - valid_data (AbstractDataLoader): The dataloader for validation.
            - test_data (AbstractDataLoader): The dataloader for testing.
    """
    checkpoint = torch.load(model_file)
    config = checkpoint["config"]
    init_seed(config["seed"], config["reproducibility"])
    init_logger(config)
    logger = getLogger()
    logger.info(config)

    # Rebuild the dataset and dataloaders with the saved config
    dataset = create_dataset(config)
    logger.info(dataset)
    train_data, valid_data, test_data = data_preparation(config, dataset)

    # Re-instantiate the model class by name, then restore the trained weights
    init_seed(config["seed"], config["reproducibility"])
    model = get_model(config["model"])(config, train_data._dataset).to(config["device"])
    model.load_state_dict(checkpoint["state_dict"])
    model.load_other_parameter(checkpoint.get("other_parameter"))
    return config, model, dataset, train_data, valid_data, test_data
```
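A likely reason the copy fixes the "model does not exist" error: RecBole's own `load_data_and_model` resolves the model class with RecBole's `get_model`, which only knows RecBole's built-in models, whereas RecBole-GNN's `get_model` checks RecBole-GNN's models first and then falls back to RecBole. The toy sketch below mimics that lookup order; the dict registries and model names (`SomeGNNModel`, the path strings) are hypothetical stand-ins for the real module-based lookup.

```python
# Hypothetical registries standing in for the model modules each package ships.
RECBOLE_MODELS = {"BPR": "recbole.model.general_recommender.bpr.BPR"}
RECBOLE_GNN_MODELS = {"SomeGNNModel": "recbole_gnn.model.general_recommender.somegnnmodel.SomeGNNModel"}


def recbole_get_model(name: str) -> str:
    """Toy version of RecBole's lookup: only RecBole's own models are visible."""
    if name not in RECBOLE_MODELS:
        raise ValueError(f"model [{name}] not found")
    return RECBOLE_MODELS[name]


def recbole_gnn_get_model(name: str) -> str:
    """Toy version of RecBole-GNN's lookup: its own models first, then RecBole's."""
    if name in RECBOLE_GNN_MODELS:
        return RECBOLE_GNN_MODELS[name]
    return recbole_get_model(name)


if __name__ == "__main__":
    print(recbole_gnn_get_model("SomeGNNModel"))  # found via RecBole-GNN
    print(recbole_gnn_get_model("BPR"))           # found via the RecBole fallback
    try:
        recbole_get_model("SomeGNNModel")         # RecBole alone cannot see it
    except ValueError as err:
        print("RecBole lookup failed:", err)
```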
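I want to use a trained model for prediction, but the error message says the model does not exist.
My code is below. It is copied verbatim from `recbole/run_example/case_study_example.py`, and the model is one I trained myself.
The error message: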
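Experimental environment:
In the `recbole` repository there are two issues about this, issue1 and issue2, and in both cases the reply was that the installed version was not the latest. I installed RecBole with `conda install -c aibox recbole`, which gave the following version:
`recbole 1.2.0 py39_0 aibox`
Is my error also caused by "the code not being the latest version"?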