GestaltCogTeam / BasicTS

A Fair and Scalable Time Series Forecasting Benchmark and Toolkit.
Apache License 2.0
626 stars, 109 forks

How do I use the inference script, and which parts do I need to change? #103

Closed: KDP-wayofdata closed this issue 9 months ago

KDP-wayofdata commented 10 months ago

How do I use the inference script? Which parts do I need to change?

KDP-wayofdata commented 10 months ago

The error is as follows:

```
2024-01-02 10:20:32,580 - easytorch-launcher - INFO - Launching EasyTorch runner.
Traceback (most recent call last):
  File "D:\BasicTS-master\experiments\inference.py", line 44, in <module>
    launch_runner(cfg_path, inference, (ckpt_path, args.batch_size), devices=args.gpus)
  File "D:\BasicTS-master\basicts\launcher.py", line 10, in launch_runner
    easytorch.launch_runner(cfg=cfg, fn=fn, args=args, device_type=device_type, devices=devices)
  File "D:\Anaconda\envs\BasicTS\lib\site-packages\easytorch\launcher\launcher.py", line 105, in launch_runner
    cfg = init_cfg(cfg, True)
  File "D:\Anaconda\envs\BasicTS\lib\site-packages\easytorch\config\utils.py", line 210, in init_cfg
    cfg = import_config(cfg, verbose=save)
  File "D:\Anaconda\envs\BasicTS\lib\site-packages\easytorch\config\utils.py", line 173, in import_config
    cfg = __import__(path, fromlist=[cfg_name]).CFG
ModuleNotFoundError: No module named 'checkpoints.D2STGNN_100.0b961d1ea97a431f05300b81816c58f4.cfg'

Process finished with exit code 1
```
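The module name in the error is the cfg path with the `.py` suffix dropped and each `/` turned into `.`, which suggests EasyTorch's `import_config` loads the config file as a Python module through `__import__`. A minimal sketch of that mapping; `cfg_path_to_module_name` is a hypothetical helper reconstructed from the traceback, not EasyTorch's actual code:

```python
import os


def cfg_path_to_module_name(cfg_path: str) -> str:
    """Turn a config file path into a dotted module name (hypothetical helper).

    Mirrors the pattern visible in the traceback: the '.py' suffix is dropped
    and path separators become dots.
    """
    stem, ext = os.path.splitext(cfg_path)
    if ext != '.py':
        raise ValueError(f'expected a .py config file, got {cfg_path!r}')
    # Normalize Windows and POSIX separators, skip empty and '.' segments.
    parts = [p for p in stem.replace('\\', '/').split('/') if p not in ('', '.')]
    return '.'.join(parts)


if __name__ == '__main__':
    print(cfg_path_to_module_name(
        'checkpoints/D2STGNN_100/0b961d1ea97a431f05300b81816c58f4/cfg.py'))
    # -> checkpoints.D2STGNN_100.0b961d1ea97a431f05300b81816c58f4.cfg
```

Because the config is imported as a module rather than read as a file, it must be importable from wherever the process is running, which is why the choice of cfg file and working directory both matter in the rest of this thread.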

zezhishao commented 10 months ago

Please share your code.

KDP-wayofdata commented 10 months ago

> Please share your code.

```python
import os
import sys
import time
sys.path.append(os.path.abspath(__file__ + '/../..'))
from argparse import ArgumentParser

from basicts import launch_runner, BaseRunner


def inference(cfg: dict, runner: BaseRunner, ckpt: str = None, batch_size: int = 1):
    # init logger
    runner.init_logger(logger_name='easytorch-inference', log_file_name='validate_result')
    # init model
    cfg.TEST.DATA.BATCH_SIZE = batch_size
    runner.model.eval()
    runner.setup_graph(cfg=cfg, train=False)
    # load model checkpoint
    runner.load_model(ckpt_path=ckpt)
    # inference & speed
    t0 = time.perf_counter()
    runner.test_process(cfg)
    elapsed = time.perf_counter() - t0

    print('##############################')
    runner.logger.info('%s: %0.8fs' % ('Speed', elapsed))
    runner.logger.info('# Param: {0}'.format(sum(p.numel() for p in runner.model.parameters() if p.requires_grad)))


if __name__ == '__main__':
    MODEL_NAME = 'D2STGNN'
    DATASET_NAME = 'METR-LA'
    BATCH_SIZE = 32
    GPUS = '2'

    parser = ArgumentParser(description='Welcome to EasyTorch!')
    parser.add_argument('-m', '--model', default=MODEL_NAME, help='model name')
    parser.add_argument('-d', '--dataset', default=DATASET_NAME, help='dataset name')
    parser.add_argument('-g', '--gpus', default=GPUS, help='visible gpus')
    parser.add_argument('-b', '--batch_size', default=BATCH_SIZE, type=int, help='batch size')
    args = parser.parse_args()

    cfg_path = 'checkpoints/D2STGNN_100/0b961d1ea97a431f05300b81816c58f4/cfg.py'.format(args.model, args.dataset)
    ckpt_path = 'checkpoints/D2STGNN_100/0b961d1ea97a431f05300b81816c58f4/D2STGNN_best_val_MAE.pt'.format(args.model, args.dataset)

    launch_runner(cfg_path, inference, (ckpt_path, args.batch_size), devices=args.gpus)
```

zezhishao commented 10 months ago

`cfg_path` needs to point to a cfg under the `baselines` folder. The cfg in the `checkpoints` folder is only an archived copy.

KDP-wayofdata commented 10 months ago

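> `cfg_path` needs to point to a cfg under the `baselines` folder. The cfg in the `checkpoints` folder is only an archived copy.

Do I need to copy the cfg from the `checkpoints` folder into the `baselines` folder?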

zezhishao commented 10 months ago

No, just use the cfg under `baselines` directly, for example `cfg_path = 'baselines/{0}/{1}.py'.format("AGCRN", "PEMS04")`. The cfg under `checkpoints` is only an archived copy: it is not a package and cannot correctly import other packages.
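In the original script, `.format(args.model, args.dataset)` is called on a literal path with no `{}` placeholders, so the `-m`/`-d` arguments are silently ignored. A minimal sketch that builds `cfg_path` from the `baselines/{0}/{1}.py` template in the reply and fails early if the file is missing (`baseline_cfg_path` is a hypothetical helper, not part of BasicTS):

```python
import os


def baseline_cfg_path(model: str, dataset: str, root: str = '.') -> str:
    """Build cfg_path as 'baselines/{model}/{dataset}.py' and check that it exists.

    Hypothetical helper; `root` should be the BasicTS root directory.
    """
    # The template must contain {0} and {1}; a literal path ignores .format() args.
    cfg_path = 'baselines/{0}/{1}.py'.format(model, dataset)
    if not os.path.isfile(os.path.join(root, cfg_path)):
        raise FileNotFoundError(
            f'{cfg_path} not found under {os.path.abspath(root)}; '
            'run from the BasicTS root or pick an existing config')
    return cfg_path
```

Usage, from the BasicTS root: `cfg_path = baseline_cfg_path(args.model, args.dataset)`, which keeps the command-line flags and the config path in sync.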

KDP-wayofdata commented 10 months ago

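> No, just use the cfg under `baselines` directly, for example `cfg_path = 'baselines/{0}/{1}.py'.format("AGCRN", "PEMS04")`. The cfg under `checkpoints` is only an archived copy: it is not a package and cannot correctly import other packages.

Why is there no cfg under my `baselines` folder?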

zezhishao commented 10 months ago

That shouldn't be the case. For example, `baselines/AGCRN/PEMS07.py`.
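To check which dataset configs actually exist for a model, you can list the `.py` files in its folder. A minimal sketch, assuming the layout implied by the reply above (one `<dataset>.py` per dataset under `baselines/<model>/`); `list_baseline_configs` is a hypothetical helper:

```python
from pathlib import Path


def list_baseline_configs(model: str, root: str = '.') -> list:
    """Return the names of the .py files directly under baselines/<model>/, sorted.

    Hypothetical helper; any other top-level .py file in that folder is listed too.
    """
    model_dir = Path(root) / 'baselines' / model
    if not model_dir.is_dir():
        return []
    # Subfolders (e.g. architecture code) are not matched by '*.py' at this level.
    return sorted(p.stem for p in model_dir.glob('*.py') if p.name != '__init__.py')
```

Run from the BasicTS root, e.g. `list_baseline_configs('AGCRN')`. An empty result usually means the current directory is not the root or the model name is misspelled.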

KDP-wayofdata commented 10 months ago

Sorry to bother you: is there something in the code below that I didn't change correctly? I've checked it several times, but it keeps reporting `No such file or directory: 'datasets/METR-LA/adj_mx.pkl'`, even though the file really is at that path. Thanks!

```python
import os
import sys
import time
sys.path.append(os.path.abspath(__file__ + '/../..'))
from argparse import ArgumentParser

from basicts import launch_runner, BaseRunner


def inference(cfg: dict, runner: BaseRunner, ckpt: str = None, batch_size: int = 1):
    # init logger
    runner.init_logger(logger_name='easytorch-inference', log_file_name='validate_result')
    # init model
    cfg.TEST.DATA.BATCH_SIZE = batch_size
    runner.model.eval()
    runner.setup_graph(cfg=cfg, train=False)
    # load model checkpoint
    runner.load_model(ckpt_path=ckpt)
    # inference & speed
    t0 = time.perf_counter()
    runner.test_process(cfg)
    elapsed = time.perf_counter() - t0

    print('##############################')
    runner.logger.info('%s: %0.8fs' % ('Speed', elapsed))
    runner.logger.info('# Param: {0}'.format(sum(p.numel() for p in runner.model.parameters() if p.requires_grad)))


if __name__ == '__main__':
    MODEL_NAME = 'D2STGNN'
    DATASET_NAME = 'METR-LA'
    BATCH_SIZE = 32
    GPUS = '2'

    parser = ArgumentParser(description='Welcome to EasyTorch!')
    parser.add_argument('-m', '--model', default=MODEL_NAME, help='model name')
    parser.add_argument('-d', '--dataset', default=DATASET_NAME, help='dataset name')
    parser.add_argument('-g', '--gpus', default=GPUS, help='visible gpus')
    parser.add_argument('-b', '--batch_size', default=BATCH_SIZE, type=int, help='batch size')
    args = parser.parse_args()

    cfg_path = 'baselines/D2STGNN/D2STGNN_METR-LA.py'.format("D2STGNN", "METR-LA")
    ckpt_path = 'checkpoints/D2STGNN_100/0b961d1ea97a431f05300b81816c58f4/D2STGNN_best_val_MAE.pt'.format("D2STGNN", "METR-LA")

    launch_runner(cfg_path, inference, (ckpt_path, args.batch_size), devices=args.gpus)
```

zezhishao commented 10 months ago

Which folder did you run the code from, and how did you run it?

KDP-wayofdata commented 10 months ago

I ran it from the `experiments` folder, by right-clicking `inference.py` and choosing Run. Is there a dedicated terminal command for it?

zezhishao commented 10 months ago

Don't run it from inside `experiments`. All commands should be run from the project directory, i.e. the parent folder of `experiments`.
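Why the working directory matters: the error shows a relative path, and relative paths are resolved against the process's current working directory, so launching from `experiments/` makes Python look for `experiments/datasets/METR-LA/adj_mx.pkl`. A small diagnostic sketch (the helper name is made up for illustration) that prints where such a path would be looked up:

```python
import os


def where_would_it_look(rel_path: str) -> str:
    """Return the absolute path that open(rel_path) would use right now."""
    # Relative paths are resolved against the current working directory,
    # not against the folder that contains the calling script.
    return os.path.abspath(rel_path)


if __name__ == '__main__':
    rel = 'datasets/METR-LA/adj_mx.pkl'  # the relative path from the error message
    print('cwd:   ', os.getcwd())
    print('target:', where_would_it_look(rel))
    print('exists:', os.path.exists(rel))
```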

KDP-wayofdata commented 10 months ago

Is there a command for running it, or steps I should follow? Thanks!

zezhishao commented 10 months ago

For reference:

```shell
cd /path/to/BasicTS  # the parent folder of experiments, i.e. the folder you cloned
python experiments/inference.py
```

Every command in this project must be run from the BasicTS root directory. I usually use VSCode; PyCharm may have path issues.
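If you prefer to keep launching from PyCharm, the equivalent of the `cd` above is setting the run configuration's "Working directory" to the BasicTS root. Alternatively, a script can switch directories itself. Below is a minimal sketch of that workaround; `enter_project_root` is a hypothetical helper, not part of BasicTS, and the maintainers' recommended approach remains running from the root:

```python
import os
import sys


def enter_project_root(script_file: str) -> str:
    """Switch to the folder above the script's folder and make it importable.

    For experiments/inference.py this is the BasicTS root, i.e. the parent of
    experiments/, which is where the maintainers run every command from.
    """
    root = os.path.dirname(os.path.dirname(os.path.abspath(script_file)))
    os.chdir(root)  # relative paths such as datasets/... now resolve from the root
    if root not in sys.path:
        sys.path.insert(0, root)  # so `import basicts` works regardless of launch dir
    return root


# Hypothetical usage, as the first lines of experiments/inference.py:
#     enter_project_root(__file__)
#     from basicts import launch_runner, BaseRunner
```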

KDP-wayofdata commented 10 months ago

It works now, thanks!