NVIDIA-AI-IOT / trt_pose

Real-time pose estimation accelerated with NVIDIA TensorRT
MIT License

How did you generate other model weights? #156

Open agrija9 opened 2 years ago

agrija9 commented 2 years ago

Hello @jaybdub, I have been able to run the live_demo.ipynb on a Jetson Xavier NX with the two provided models (resnet18 and densenet121).

However, I need better accuracy for my application. When I try to run the notebook with a different backbone, e.g. resnet50, load_state_dict raises a RuntimeError (the traceback also mentions _IncompatibleKeys(missing_keys, unexpected_keys)).

This is how I download the model and try to load the weights:

import json
import trt_pose.coco
import torch
import trt_pose.models
import torch2trt
from torch2trt import TRTModule
import time
import cv2
import torchvision.transforms as transforms
import PIL.Image

with open('human_pose.json', 'r') as f:
    human_pose = json.load(f)

topology = trt_pose.coco.coco_category_to_topology(human_pose)

num_parts = len(human_pose['keypoints'])
num_links = len(human_pose['skeleton'])

# Downloads model into /home/jetson3/.cache/torch/hub/checkpoints
model = trt_pose.models.resnet50_baseline_att(num_parts, 2 * num_links).cuda().eval()

# Load model weights
MODEL_WEIGHTS = "/home/jetson3/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth"
model.load_state_dict(torch.load(MODEL_WEIGHTS))

The load_state_dict call then fails with:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-1687e7d09b26> in <module>
      7 MODEL_WEIGHTS = "/home/jetson3/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth"
----> 8 model.load_state_dict(torch.load(MODEL_WEIGHTS))

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1481         if len(error_msgs) > 0:
   1482             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1483                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1484         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1485 

RuntimeError: Error(s) in loading state_dict for Sequential:
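Since the detailed key list is cut off above, one way to narrow this down is to compare the parameter names the model expects with the names stored in the checkpoint. The sketch below is only an illustration: `diff_state_dict_keys` is a hypothetical helper (not part of trt_pose or PyTorch), and the key names in the example are made up to show the pattern where a checkpoint holds bare backbone weights while the model wraps the backbone inside a Sequential.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys in the checkpoint that the model does not have
    """
    model_keys = list(model_keys)
    checkpoint_keys = list(checkpoint_keys)
    model_set = set(model_keys)
    checkpoint_set = set(checkpoint_keys)
    missing = [k for k in model_keys if k not in checkpoint_set]
    unexpected = [k for k in checkpoint_keys if k not in model_set]
    return missing, unexpected


# Hypothetical key names, for illustration only (not taken from the real files).
model_keys = ["0.resnet.conv1.weight", "0.resnet.bn1.weight", "1.cmap_conv.weight"]
checkpoint_keys = ["conv1.weight", "bn1.weight", "fc.weight"]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)

# With the real objects from the snippet above, the call would look like:
#   diff_state_dict_keys(model.state_dict().keys(),
#                        torch.load(MODEL_WEIGHTS).keys())
```

If every key ends up in one of the two lists, the checkpoint was probably saved from a different module layout than the one being loaded, which would suggest the file in the hub cache contains only the pretrained backbone rather than trained pose weights.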

Appreciate your help!