Closed EverOasis closed 11 months ago
您好,
可以检查一下pytorch_lightning和pytorch的版本,可能是pytorch版本过高,尝试降低pytorch版本到1.12.1
按照建议进行了降级,运行后还是出现错误,以下是错误信息:
Traceback (most recent call last):
File "main.py", line 121, in
以下是环境: Package Version
absl-py 1.4.0 aiohttp 3.8.4 aiosignal 1.3.1 alabaster 0.7.13 annotated-types 0.5.0 appdirs 1.4.4 asttokens 2.2.1 async-timeout 4.0.2 attrs 23.1.0 autopep8 2.0.2 Babel 2.12.1 backcall 0.2.0 Brotli 1.0.9 cachetools 5.3.1 certifi 2023.5.7 charset-normalizer 3.2.0 click 8.1.5 colorama 0.4.6 decorator 5.1.1 dgl-cu116 0.9.1 dglgo 0.0.2 docker-pycreds 0.4.0 docutils 0.20.1 executing 1.2.0 frozenlist 1.4.0 fsspec 2023.6.0 future 0.18.3 gitdb 4.0.10 GitPython 3.1.32 google-auth 2.22.0 google-auth-oauthlib 1.0.0 grpcio 1.56.0 idna 3.4 imagesize 1.4.1 importlib-metadata 6.8.0 ipython 8.12.2 isort 5.12.0 jedi 0.18.2 Jinja2 3.1.2 joblib 1.3.1 lightning-utilities 0.9.0 littleutils 0.2.2 Markdown 3.4.3 MarkupSafe 2.1.3 matplotlib-inline 0.1.6 multidict 6.0.4 networkx 3.1 neuralkg 1.0.21 numpy 1.21.0 numpydoc 1.5.0 oauthlib 3.2.2 ogb 1.3.6 outdated 0.2.2 packaging 23.1 pandas 2.0.3 parso 0.8.3 pathtools 0.1.2 pickleshare 0.7.5 Pillow 9.2.0 pip 23.1.2 prompt-toolkit 3.0.39 protobuf 4.23.4 psutil 5.9.5 pure-eval 0.2.2 pyasn1 0.5.0 pyasn1-modules 0.3.0 pycodestyle 2.10.0 pydantic 2.0.3 pydantic_core 2.3.0 pyDeprecate 0.3.1 Pygments 2.15.1 PySocks 1.7.1 python-dateutil 2.8.2 pytorch-lightning 1.5.10 pytz 2023.3 PyYAML 6.0.1 rdkit-pypi 2022.9.5 requests 2.31.0 requests-oauthlib 1.3.1 rsa 4.9 ruamel.yaml 0.17.32 ruamel.yaml.clib 0.2.7 scikit-learn 1.3.0 scipy 1.10.1 sentry-sdk 1.28.1 setproctitle 1.3.2 setuptools 59.5.0 six 1.16.0 smmap 5.0.0 snowballstemmer 2.2.0 Sphinx 7.0.1 sphinxcontrib-applehelp 1.0.4 sphinxcontrib-devhelp 1.0.2 sphinxcontrib-htmlhelp 2.0.1 sphinxcontrib-jsmath 1.0.1 sphinxcontrib-qthelp 1.0.3 sphinxcontrib-serializinghtml 1.1.5 stack-data 0.6.2 tensorboard 2.13.0 tensorboard-data-server 0.7.1 threadpoolctl 3.2.0 tomli 2.0.1 torch 1.12.1 torchaudio 0.12.1 torchmetrics 1.0.1 torchvision 0.13.1 tqdm 4.65.0 traitlets 5.9.0 typer 0.9.0 typing_extensions 4.7.1 tzdata 2023.3 urllib3 1.26.16 wandb 0.15.5 wcwidth 0.2.6 Werkzeug 2.3.6 wheel 0.38.4 win-inet-pton 1.1.0 yarl 
1.9.2 zipp 3.16.2
能看一下运行的脚本吗?这个报错像是模型的参数没有给到
yaml: accelerator: null accumulate_grad_batches: null adv_temp: 1.0 amp_backend: native amp_level: null auto_lr_find: false auto_scale_batch_size: false auto_select_gpus: false axiom_types: 10 axiom_weight: 1.0 benchmark: false bern_flag: false calc_hits:
main.py :
import pytorch_lightning as pl from pytorch_lightning import seed_everything from IPython import embed import wandb from neuralkg.utils import setup_parser from neuralkg.utils.tools import from neuralkg.data.Sampler import from neuralkg.data.Grounding import GroundAllRules
def main(): parser = setup_parser() #设置参数 args = parser.parse_args() if args.load_config: args = load_config(args, args.config_path) seed_everything(args.seed) """set up sampler to datapreprocess""" #设置数据处理的采样过程 train_sampler_class = import_class(f"neuralkg.data.{args.train_sampler_class}") train_sampler = train_sampler_class(args) # 这个sampler是可选择的
test_sampler_class = import_class(f"neuralkg.data.{args.test_sampler_class}")
test_sampler = test_sampler_class(train_sampler) # test_sampler是一定要的
"""set up datamodule""" #设置数据模块
data_class = import_class(f"neuralkg.data.{args.data_class}") #定义数据类 DataClass
kgdata = data_class(args, train_sampler, test_sampler)
"""set up model"""
model_class = import_class(f"neuralkg.model.{args.model_name}")
if args.model_name == "RugE":
ground = GroundAllRules(args)
ground.PropositionalizeRule()
if args.model_name == "ComplEx_NNE_AER":
model = model_class(args, train_sampler.rel2id)
elif args.model_name == "IterE":
print(f"data.{args.train_sampler_class}")
model = model_class(args, train_sampler, test_sampler)
else:
model = model_class(args)
if args.model_name == 'SEGNN':
src_list = train_sampler.get_train_1.src_list
dst_list = train_sampler.get_train_1.dst_list
rel_list = train_sampler.get_train_1.rel_list
"""set up lit_model"""
litmodel_class = import_class(f"neuralkg.lit_model.{args.litmodel_name}")
if args.model_name =='SEGNN':
lit_model = litmodel_class(model, args, src_list, dst_list, rel_list)
else:
lit_model = litmodel_class(model, args)
"""set up logger"""
logger = pl.loggers.TensorBoardLogger("training/logs")
if args.use_wandb:
log_name = "_".join([args.model_name, args.dataset_name, str(args.lr)])
logger = pl.loggers.WandbLogger(name=log_name, project="NeuralKG")
logger.log_hyperparams(vars(args))
"""early stopping"""
early_callback = pl.callbacks.EarlyStopping(
monitor="Eval|mrr",
mode="max",
patience=args.early_stop_patience,
# verbose=True,
check_on_train_epoch_end=False,
)
"""set up model save method"""
# 目前是保存在验证集上mrr结果最好的模型
# 模型保存的路径
dirpath = "/".join(["output", args.eval_task, args.dataset_name, args.model_name])
model_checkpoint = pl.callbacks.ModelCheckpoint(
monitor="Eval|mrr",
mode="max",
filename="{epoch}-{Eval|mrr:.3f}",
dirpath=dirpath,
save_weights_only=True,
save_top_k=1,
)
callbacks = [early_callback, model_checkpoint]
# initialize trainer
if args.model_name == "IterE":
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=callbacks,
logger=logger,
default_root_dir="training/logs",
gpus=0,
check_val_every_n_epoch=args.check_per_epoch,
reload_dataloaders_every_n_epochs=1 # IterE
)
else:
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=callbacks,
logger=logger,
default_root_dir="training/logs",
gpus=0,
check_val_every_n_epoch=args.check_per_epoch,
)
'''保存参数到config'''
if args.save_config:
save_config(args)
if args.use_wandb:
logger.watch(lit_model)
if not args.test_only:
# train&valid
trainer.fit(lit_model, datamodule=kgdata)
# 加载本次实验中dev上表现最好的模型,进行test
path = model_checkpoint.best_model_path
else:
path = args.checkpoint_dir
lit_model.load_state_dict(torch.load(path)["state_dict"])
lit_model.eval()
trainer.test(lit_model, datamodule=kgdata)
if __name__ == "__main__":
    main()

万分感谢!!!
看起来超参设置也没什么问题,我也不太清楚为什么会报这个错,建议可以先用gpu看能不能跑起来,如果gpu爆显存了可以把batch_size调小一点。
感谢回复,我去试试
换了一张3090之后这个问题解决了,但是出现新的问题:Traceback (most recent call last):
File "main.py", line 121, in
这个的train_bs,eval_bs是多大呢?
train_bs: 128 eval_bs: 16
要不再把train_bs调小试一试?这个报错很奇怪,源码是用len(data)
初始化的,可是顺序的index
却溢出了。
感谢回复,查了下发现如果显存不够也可能出现这个错误,我再试试
你好,你解决了吗?我运行demo.py时候也出现了这个问题
错误信息如题。跑的图谱有点大,手上仅有的一块3080爆显存了。于是禁用了gpu,然后出现了这个错误。这个错误是和禁用gpu有关吗还是什么其他的原因,请问需要如何解决。