tteepe / GaitGraph

Official repository for "GaitGraph: Graph Convolutional Network for Skeleton-Based Gait Recognition" (ICIP'21)
https://arxiv.org/abs/2101.11228
MIT License
91 stars 27 forks source link

Making an Inference Script #32

Open buckeye17 opened 1 year ago

buckeye17 commented 1 year ago

I'm attempting to adapt evaluate.py into my own infer.py. For now I'm just trying to pass a dummy tensor through the model, but I'm getting an error. Any help would be appreciated!

Here's my infer.py:

import numpy as np
import torch

from common import get_model_resgcn
from datasets.graph import Graph

# define configuration options with custom class instead of using CLI & argument parser
class opt():
    weights_path = "/app/Gait/GaitGraph/src/models/gaitgraph_resgcn-n39-r8_coco_seq_60.pth"
    network_name = "resgcn-n39-r8"
    embedding_layer_size = 128
    temporal_kernel_size = 9
    dropout = 0.4
    use_multi_branch = True  # NOTE(review): the traceback shows the checkpoint only has input_branches.0 with 3 input channels, so True builds a model that does not match it

# Config for dataset
graph = Graph("coco")

# Init model
model, model_args = get_model_resgcn(graph, opt)

if torch.cuda.is_available():
    model.cuda()

# Load weights
checkpoint = torch.load(opt.weights_path)  # NOTE(review): no map_location; a GPU-saved checkpoint may fail to load on a CPU-only machine
model.load_state_dict(checkpoint["model"])  # this is "line 27" in the traceback: strict loading raises on the missing keys / size mismatches

model.eval()

# Load data
# NOTE(review): this is a 3-D float64 array; the working version of this script
# later in the thread feeds a 5-D float32 tensor of shape (1, 1, 3, 60, 17) instead.
data_arr = np.ones((51,60,1))
data_ten = torch.from_numpy(data_arr)

# Calculate embeddings
with torch.no_grad():
    # Test-time augmentation: batch the input together with a copy flipped
    # along dim 1, run both through the model, then average the two halves.
    bsz = data_ten.shape[0]
    data_ten_flipped = torch.flip(data_ten, dims=[1])
    data_ten = torch.cat([data_ten, data_ten_flipped], dim=0)

    if torch.cuda.is_available():
        data_ten = data_ten.cuda(non_blocking=True)

    output = model(data_ten)
    # First bsz rows come from the original input, the next bsz from the flipped copy.
    f1, f2 = torch.split(output, [bsz, bsz], dim=0)
    output = torch.mean(torch.stack([f1, f2]), dim=0)

Here's the error message I get:

Traceback (most recent call last):
  File "/app/Gait/GaitGraph/src/infer.py", line 27, in <module>
    model.load_state_dict(checkpoint["model"])
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ResGCN:
        Missing key(s) in state_dict: "input_branches.1.A", "input_branches.1.bn.weight", "input_branches.1.bn.bias", "input_branches.1.bn.running_mean", "input_branches.1.bn.running_var", "input_branches.1.layers.0.edge", "input_branches.1.layers.0.scn.conv.gcn.weight", "input_branches.1.layers.0.scn.conv.gcn.bias", "input_branches.1.layers.0.scn.bn.weight", "input_branches.1.layers.0.scn.bn.bias", "input_branches.1.layers.0.scn.bn.running_mean", "input_branches.1.layers.0.scn.bn.running_var", "input_branches.1.layers.0.tcn.conv.weight", "input_branches.1.layers.0.tcn.conv.bias", "input_branches.1.layers.0.tcn.bn.weight", "input_branches.1.layers.0.tcn.bn.bias", "input_branches.1.layers.0.tcn.bn.running_mean", "input_branches.1.layers.0.tcn.bn.running_var", "input_branches.1.layers.1.edge", "input_branches.1.layers.1.scn.conv_down.weight", "input_branches.1.layers.1.scn.conv_down.bias", "input_branches.1.layers.1.scn.bn_down.weight", "input_branches.1.layers.1.scn.bn_down.bias", "input_branches.1.layers.1.scn.bn_down.running_mean", "input_branches.1.layers.1.scn.bn_down.running_var", "input_branches.1.layers.1.scn.conv.gcn.weight", "input_branches.1.layers.1.scn.conv.gcn.bias", "input_branches.1.layers.1.scn.bn.weight", "input_branches.1.layers.1.scn.bn.bias", "input_branches.1.layers.1.scn.bn.running_mean", "input_branches.1.layers.1.scn.bn.running_var", "input_branches.1.layers.1.scn.conv_up.weight", "input_branches.1.layers.1.scn.conv_up.bias", "input_branches.1.layers.1.scn.bn_up.weight", "input_branches.1.layers.1.scn.bn_up.bias", "input_branches.1.layers.1.scn.bn_up.running_mean", "input_branches.1.layers.1.scn.bn_up.running_var", "input_branches.1.layers.1.tcn.conv_down.weight", "input_branches.1.layers.1.tcn.conv_down.bias", "input_branches.1.layers.1.tcn.bn_down.weight", "input_branches.1.layers.1.tcn.bn_down.bias", "input_branches.1.layers.1.tcn.bn_down.running_mean", "input_branches.1.layers.1.tcn.bn_down.running_var", 
"input_branches.1.layers.1.tcn.conv.weight", "input_branches.1.layers.1.tcn.conv.bias", "input_branches.1.layers.1.tcn.bn.weight", "input_branches.1.layers.1.tcn.bn.bias", "input_branches.1.layers.1.tcn.bn.running_mean", "input_branches.1.layers.1.tcn.bn.running_var", "input_branches.1.layers.1.tcn.conv_up.weight", "input_branches.1.layers.1.tcn.conv_up.bias", "input_branches.1.layers.1.tcn.bn_up.weight", "input_branches.1.layers.1.tcn.bn_up.bias", "input_branches.1.layers.1.tcn.bn_up.running_mean", "input_branches.1.layers.1.tcn.bn_up.running_var", "input_branches.1.layers.2.edge", "input_branches.1.layers.2.scn.residual.0.weight", "input_branches.1.layers.2.scn.residual.0.bias", "input_branches.1.layers.2.scn.residual.1.weight", "input_branches.1.layers.2.scn.residual.1.bias", "input_branches.1.layers.2.scn.residual.1.running_mean", "input_branches.1.layers.2.scn.residual.1.running_var", "input_branches.1.layers.2.scn.conv_down.weight", "input_branches.1.layers.2.scn.conv_down.bias", "input_branches.1.layers.2.scn.bn_down.weight", "input_branches.1.layers.2.scn.bn_down.bias", "input_branches.1.layers.2.scn.bn_down.running_mean", "input_branches.1.layers.2.scn.bn_down.running_var", "input_branches.1.layers.2.scn.conv.gcn.weight", "input_branches.1.layers.2.scn.conv.gcn.bias", "input_branches.1.layers.2.scn.bn.weight", "input_branches.1.layers.2.scn.bn.bias", "input_branches.1.layers.2.scn.bn.running_mean", "input_branches.1.layers.2.scn.bn.running_var", "input_branches.1.layers.2.scn.conv_up.weight", "input_branches.1.layers.2.scn.conv_up.bias", "input_branches.1.layers.2.scn.bn_up.weight", "input_branches.1.layers.2.scn.bn_up.bias", "input_branches.1.layers.2.scn.bn_up.running_mean", "input_branches.1.layers.2.scn.bn_up.running_var", "input_branches.1.layers.2.tcn.conv_down.weight", "input_branches.1.layers.2.tcn.conv_down.bias", "input_branches.1.layers.2.tcn.bn_down.weight", "input_branches.1.layers.2.tcn.bn_down.bias", 
"input_branches.1.layers.2.tcn.bn_down.running_mean", "input_branches.1.layers.2.tcn.bn_down.running_var", "input_branches.1.layers.2.tcn.conv.weight", "input_branches.1.layers.2.tcn.conv.bias", "input_branches.1.layers.2.tcn.bn.weight", "input_branches.1.layers.2.tcn.bn.bias", "input_branches.1.layers.2.tcn.bn.running_mean", "input_branches.1.layers.2.tcn.bn.running_var", "input_branches.1.layers.2.tcn.conv_up.weight", "input_branches.1.layers.2.tcn.conv_up.bias", "input_branches.1.layers.2.tcn.bn_up.weight", "input_branches.1.layers.2.tcn.bn_up.bias", "input_branches.1.layers.2.tcn.bn_up.running_mean", "input_branches.1.layers.2.tcn.bn_up.running_var", "input_branches.2.A", "input_branches.2.bn.weight", "input_branches.2.bn.bias", "input_branches.2.bn.running_mean", "input_branches.2.bn.running_var", "input_branches.2.layers.0.edge", "input_branches.2.layers.0.scn.conv.gcn.weight", "input_branches.2.layers.0.scn.conv.gcn.bias", "input_branches.2.layers.0.scn.bn.weight", "input_branches.2.layers.0.scn.bn.bias", "input_branches.2.layers.0.scn.bn.running_mean", "input_branches.2.layers.0.scn.bn.running_var", "input_branches.2.layers.0.tcn.conv.weight", "input_branches.2.layers.0.tcn.conv.bias", "input_branches.2.layers.0.tcn.bn.weight", "input_branches.2.layers.0.tcn.bn.bias", "input_branches.2.layers.0.tcn.bn.running_mean", "input_branches.2.layers.0.tcn.bn.running_var", "input_branches.2.layers.1.edge", "input_branches.2.layers.1.scn.conv_down.weight", "input_branches.2.layers.1.scn.conv_down.bias", "input_branches.2.layers.1.scn.bn_down.weight", "input_branches.2.layers.1.scn.bn_down.bias", "input_branches.2.layers.1.scn.bn_down.running_mean", "input_branches.2.layers.1.scn.bn_down.running_var", "input_branches.2.layers.1.scn.conv.gcn.weight", "input_branches.2.layers.1.scn.conv.gcn.bias", "input_branches.2.layers.1.scn.bn.weight", "input_branches.2.layers.1.scn.bn.bias", "input_branches.2.layers.1.scn.bn.running_mean", 
"input_branches.2.layers.1.scn.bn.running_var", "input_branches.2.layers.1.scn.conv_up.weight", "input_branches.2.layers.1.scn.conv_up.bias", "input_branches.2.layers.1.scn.bn_up.weight", "input_branches.2.layers.1.scn.bn_up.bias", "input_branches.2.layers.1.scn.bn_up.running_mean", "input_branches.2.layers.1.scn.bn_up.running_var", "input_branches.2.layers.1.tcn.conv_down.weight", "input_branches.2.layers.1.tcn.conv_down.bias", "input_branches.2.layers.1.tcn.bn_down.weight", "input_branches.2.layers.1.tcn.bn_down.bias", "input_branches.2.layers.1.tcn.bn_down.running_mean", "input_branches.2.layers.1.tcn.bn_down.running_var", "input_branches.2.layers.1.tcn.conv.weight", "input_branches.2.layers.1.tcn.conv.bias", "input_branches.2.layers.1.tcn.bn.weight", "input_branches.2.layers.1.tcn.bn.bias", "input_branches.2.layers.1.tcn.bn.running_mean", "input_branches.2.layers.1.tcn.bn.running_var", "input_branches.2.layers.1.tcn.conv_up.weight", "input_branches.2.layers.1.tcn.conv_up.bias", "input_branches.2.layers.1.tcn.bn_up.weight", "input_branches.2.layers.1.tcn.bn_up.bias", "input_branches.2.layers.1.tcn.bn_up.running_mean", "input_branches.2.layers.1.tcn.bn_up.running_var", "input_branches.2.layers.2.edge", "input_branches.2.layers.2.scn.residual.0.weight", "input_branches.2.layers.2.scn.residual.0.bias", "input_branches.2.layers.2.scn.residual.1.weight", "input_branches.2.layers.2.scn.residual.1.bias", "input_branches.2.layers.2.scn.residual.1.running_mean", "input_branches.2.layers.2.scn.residual.1.running_var", "input_branches.2.layers.2.scn.conv_down.weight", "input_branches.2.layers.2.scn.conv_down.bias", "input_branches.2.layers.2.scn.bn_down.weight", "input_branches.2.layers.2.scn.bn_down.bias", "input_branches.2.layers.2.scn.bn_down.running_mean", "input_branches.2.layers.2.scn.bn_down.running_var", "input_branches.2.layers.2.scn.conv.gcn.weight", "input_branches.2.layers.2.scn.conv.gcn.bias", "input_branches.2.layers.2.scn.bn.weight", 
"input_branches.2.layers.2.scn.bn.bias", "input_branches.2.layers.2.scn.bn.running_mean", "input_branches.2.layers.2.scn.bn.running_var", "input_branches.2.layers.2.scn.conv_up.weight", "input_branches.2.layers.2.scn.conv_up.bias", "input_branches.2.layers.2.scn.bn_up.weight", "input_branches.2.layers.2.scn.bn_up.bias", "input_branches.2.layers.2.scn.bn_up.running_mean", "input_branches.2.layers.2.scn.bn_up.running_var", "input_branches.2.layers.2.tcn.conv_down.weight", "input_branches.2.layers.2.tcn.conv_down.bias", "input_branches.2.layers.2.tcn.bn_down.weight", "input_branches.2.layers.2.tcn.bn_down.bias", "input_branches.2.layers.2.tcn.bn_down.running_mean", "input_branches.2.layers.2.tcn.bn_down.running_var", "input_branches.2.layers.2.tcn.conv.weight", "input_branches.2.layers.2.tcn.conv.bias", "input_branches.2.layers.2.tcn.bn.weight", "input_branches.2.layers.2.tcn.bn.bias", "input_branches.2.layers.2.tcn.bn.running_mean", "input_branches.2.layers.2.tcn.bn.running_var", "input_branches.2.layers.2.tcn.conv_up.weight", "input_branches.2.layers.2.tcn.conv_up.bias", "input_branches.2.layers.2.tcn.bn_up.weight", "input_branches.2.layers.2.tcn.bn_up.bias", "input_branches.2.layers.2.tcn.bn_up.running_mean", "input_branches.2.layers.2.tcn.bn_up.running_var". 
        size mismatch for input_branches.0.bn.weight: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([6]).
        size mismatch for input_branches.0.bn.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([6]).
        size mismatch for input_branches.0.bn.running_mean: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([6]).
        size mismatch for input_branches.0.bn.running_var: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([6]).
        size mismatch for input_branches.0.layers.0.scn.conv.gcn.weight: copying a param with shape torch.Size([192, 3, 1, 1]) from checkpoint, the shape in current model is torch.Size([192, 6, 1, 1]).
        size mismatch for main_stream.0.scn.residual.0.weight: copying a param with shape torch.Size([128, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 96, 1, 1]).
        size mismatch for main_stream.0.scn.conv_down.weight: copying a param with shape torch.Size([16, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([16, 96, 1, 1]).
buckeye17 commented 1 year ago

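I eventually figured out my issues. The key changes were:

- setting `use_multi_branch = False`, since the pretrained checkpoint has only a single input branch with 3 input channels;
- feeding a 5-D float tensor of shape (1, 1, 3, 60, 17) instead of a 3-D array.

Here's my final script for passing a dummy tensor through GaitGraph for inference: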

import numpy as np
import torch

from common import get_model_resgcn
from datasets.graph import Graph

# Define configuration options with a plain class instead of the CLI argument parser.
class opt():
    weights_path = "/app/Gait/GaitGraph/models/gaitgraph_resgcn-n39-r8_coco_seq_60.pth"
    network_name = "resgcn-n39-r8"
    embedding_layer_size = 128
    temporal_kernel_size = 9
    dropout = 0.4
    # Must be False for this checkpoint: the multi-branch model fails to load it,
    # with missing input_branches.1/.2 keys and 3-vs-6 input-channel mismatches.
    use_multi_branch = False

# Pick the device once so every tensor and the model end up in the same place,
# and the script also runs on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Config for dataset
graph = Graph("coco")

# Init model
model, model_args = get_model_resgcn(graph, opt)
model.to(device)

# Load weights. map_location lets a checkpoint saved on GPU be loaded on CPU.
checkpoint = torch.load(opt.weights_path, map_location=device)
# strict=False tolerates missing/unexpected keys (shape mismatches still raise),
# so report whatever was skipped instead of silently running a partly-initialized model.
load_result = model.load_state_dict(checkpoint["model"], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"Warning: missing keys: {load_result.missing_keys}; "
          f"unexpected keys: {load_result.unexpected_keys}")

model.eval()

# Load data (a dummy all-ones input here)
# Dim 1: batch components
# Dim 2: presumably the number of input branches (1 with use_multi_branch=False) -- TODO confirm
# Dim 3: coordinates (x, y, conf)
# Dim 4: sequence of 60 frames
# Dim 5: coordinate list (nose, left_eye, right_eye, etc), 17 keypoints
data_arr = np.ones((1, 1, 3, 60, 17), dtype=np.float32)
# float32 matches the model's default parameter dtype; .to(device) works with or without CUDA.
data_ten = torch.from_numpy(data_arr).to(device)

# Calculate embeddings
with torch.no_grad():
    bsz = data_ten.shape[0]
    # Test-time augmentation: run the input and a flipped copy in one batch,
    # then average the two embeddings.
    # NOTE(review): dim 1 has size 1 for this input, so this flip is a no-op;
    # confirm which axis evaluate.py intends to flip before relying on it.
    data_ten_flipped = torch.flip(data_ten, dims=[1])
    data_ten = torch.cat([data_ten, data_ten_flipped], dim=0)

    output = model(data_ten)
    # First bsz rows come from the original input, the next bsz from the flipped copy.
    f1, f2 = torch.split(output, [bsz, bsz], dim=0)
    output = torch.mean(torch.stack([f1, f2]), dim=0)
weishiguan commented 8 months ago

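请问一下，你最终是否顺利运行出了代码？我目前也遇到了一些问题，想知道代码是否可行。希望能得到回答，不胜感激。(English: Did you eventually get the code to run successfully? I'm also running into some problems and would like to know whether this code works. I'd really appreciate an answer.)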