NVIDIA / DeepLearningExamples


inference.py in DeepLearningExamples/PyTorch/Forecasting/TFT/ #1393

Status: Open. Opened by tdktrang 4 months ago.


Related to DeepLearningExamples/PyTorch/Forecasting/TFT/

Describe the bug
When I run "inference.py", it fails because "unscaled_predictions" is a numpy.ndarray, and the script then calls the tensor method new_full on it. The code therefore needs to convert "unscaled_predictions" to a torch tensor before that call.
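A minimal sketch of the conversion the report asks for, assuming the fix is simply to wrap the array with torch.as_tensor before any tensor method such as new_full is called. The helper name ensure_tensor, the array values, and the new_full arguments below are hypothetical; the actual call site in inference.py is not reproduced here, and the guess that the array comes from the target scalers is only an assumption.

```python
import numpy as np
import torch


def ensure_tensor(predictions, device=None):
    """Return predictions as a torch.Tensor.

    Hypothetical helper, not code from inference.py. NumPy arrays are
    wrapped with torch.as_tensor (sharing memory when no dtype/device
    change is needed); anything else is returned unchanged.
    """
    if isinstance(predictions, np.ndarray):
        return torch.as_tensor(predictions, device=device)
    return predictions


# Simulate the failing case (assumption: inverse scaling returned NumPy data).
unscaled_predictions = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

try:
    # ndarray has no new_full method, so this reproduces the reported AttributeError.
    unscaled_predictions.new_full((2, 2), float("nan"))
except AttributeError as err:
    print(err)

# After conversion, tensor methods such as new_full work as expected.
unscaled_predictions = ensure_tensor(unscaled_predictions)
padded = unscaled_predictions.new_full((2, 2), float("nan"))
print(padded.dtype, padded.shape)  # torch.float32 torch.Size([2, 2])
```

torch.as_tensor avoids a copy when the array's dtype already matches and the target is the CPU, so the conversion is cheap; passing a device argument would move the result to the same device as the model outputs if that is needed.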

To Reproduce
Steps to reproduce the behavior:

python inference.py \
--checkpoint /results/TFT_electricity_bs8x1024_lr1e-3/seed_1/checkpoint.pt \
--data /data/processed/electricity_bin/test.csv \
--tgt_scalers /data/processed/electricity_bin/tgt_scalers.bin \
--cat_encodings /data/processed/electricity_bin/cat_encodings.bin \
--visualize \
--save_predictions

Expected behavior
The script should run to completion. Instead, it fails with:
AttributeError: 'numpy.ndarray' object has no attribute 'new_full'

Environment