[Open] zhoujialos opened this issue 3 months ago.
Hi, could you describe your issue in more detail?
I expected the training result to resemble the cPPG waveform published in your article (shown below), but after 3000 epochs of training I got the waveform shown in the upper part of the figure.
Could you share your trained weights with me?
Hi, I will upload the weights in the next few days. Your first figure seems to show a single periodic segment. If you got the results from `joint_inference.py`, please note that the saved results are `results = {'rppg_list': rppg_list, 'bvp_list': bvp_list, 'bvp_cyc_list':bvp_cyc_list, 'cyc_list': cyc_list, 'pred_list': pred_list}`, where `rppg_list` contains the rPPG signals, while `cyc_list` contains the periodic segments of the rPPG signals.
The training data, weights, and results have been uploaded. Please check the README for more details.
import sys import numpy as np import h5py import torch import torch.nn as nn import matplotlib.pyplot as plt from rppg_model import rppg_model from biometric_models import from cycle_cut import cycle_cut from utils_data import from utils_sig import * from sacred import Experiment from sacred.observers import FileStorageObserver import json import os
# Which training run (and which checkpoint of it) to evaluate.
e = 2900  # epoch of the checkpoint to load
train_exp_name = 'default'
train_exp_num = 2  # training experiment number
# Directory holding the training run's artifacts.
train_exp_dir = f'/home/project/rppg_biometrics/joint_results/{train_exp_name}/{train_exp_num}'
# Prefer the GPU when one is visible; fall back to the CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Let cuDNN auto-tune convolution algorithms for the fixed input sizes.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
# Test-split recordings saved by the training run.
test_list = list(np.load(train_exp_dir + '/test_list.npy'))
# Prediction output shares the training run's directory.
pred_exp_dir = train_exp_dir
os.makedirs(pred_exp_dir, exist_ok=True)

# Configuration saved by the training run.
with open(train_exp_dir + '/config.json') as f:
    config_train = json.load(f)


def _restore(module, tag):
    """Load ``epoch<e>_<tag>.pt`` from the training directory into ``module``.

    Returns ``module`` itself so construction and loading can be chained.
    """
    # NOTE(review): torch.load unpickles the checkpoint file; only load
    # checkpoints from a trusted source (torch >= 1.13 accepts
    # weights_only=True for plain state dicts).
    ckpt_path = train_exp_dir + '/epoch%d_%s.pt' % (e, tag)
    module.load_state_dict(torch.load(ckpt_path, map_location=device))
    return module


# Networks in eval mode, with the weights saved at epoch `e`.
model = _restore(rppg_model(config_train['fs']).to(device).eval(), 'model')
ppg_model = _restore(
    ppg_transformer(config_train['num_classes_old']).to(device).eval(),
    'ppg_model',
)
cls_head = _restore(
    nn.Linear(64, config_train['num_classes']).to(device).eval(),
    'cls_head',
)
@torch.no_grad()
def dl_model(imgs_clip, fs):
    """Predict the rPPG signal of one clip and compute a feature per cycle.

    Uses the module-level ``model``, ``ppg_model``, ``device`` and
    ``config_train``.

    Args:
        imgs_clip: NumPy frame array. ``transpose((3, 0, 1, 2))`` moves its
            last axis to the front, so it is presumably (T, H, W, C) with
            C == 3. TODO confirm against the HDF5 writer.
        fs: Forwarded to ``cycle_cut``; presumably the sampling rate.

    Returns:
        Tuple of NumPy arrays ``(rppg, cycles, cycle_f)``:
        - ``rppg``: the first entry along axis 0 of the rPPG model output.
        - ``cycles``: the segments from ``cycle_cut``, concatenated along
          axis 0.
        - ``cycle_f``: the ``ppg_model`` feature for those segments.
    """
    img_batch = imgs_clip.transpose((3, 0, 1, 2))  # (C, T, H, W)
    n_channels, n_frames = img_batch.shape[:2]
    img_batch = img_batch.reshape(n_channels, n_frames, -1)  # (C, T, H*W)

    # Randomly permute pixel positions. The permutation length must equal
    # the flattened pixel count H*W (not H*H). Otherwise non-square frames
    # silently lose pixels or raise IndexError.
    n_pixels = img_batch.shape[2]
    img_batch = img_batch[:, :, np.random.permutation(n_pixels)]
    img_batch = np.transpose(img_batch, (0, 2, 1))  # (C, N, T)
    img_batch = img_batch[np.newaxis].astype('float32')  # add batch axis
    img_batch = torch.tensor(img_batch).to(device)

    _, rppg = model(img_batch)
    # Scale by the training config's 'reverse' factor (presumably +/-1 to
    # fix signal polarity; confirm).
    rppg = config_train['reverse'] * rppg

    cycle_list = cycle_cut(rppg, fs, length=90)  # periodic segments
    cycles = torch.cat(cycle_list, 0)
    _, cycle_f = ppg_model(cycles)
    # cls_head is deliberately not evaluated: its output is not returned.
    return (
        rppg[0].detach().cpu().numpy(),
        cycles.detach().cpu().numpy(),
        cycle_f.detach().cpu().numpy(),
    )
# Run inference on every test recording, then plot the predicted rPPG
# signal together with the per-cycle features.
fs = config_train['fs']  # same for every recording, so read it once
num_samples_to_plot = 500  # number of leading samples shown in each subplot

for h5_path in test_list:
    h5_path = str(h5_path)
    # Only the read needs the file; close it before running the models.
    with h5py.File(h5_path, 'r') as f:
        imgs = f['imgs'][:]

    # A single forward pass over the whole clip.
    rppg_sig, cyc, cycle_f = dl_model(imgs, fs)
    rppg_list = np.array([rppg_sig])
    cyc_list = np.array([cyc])  # periodic segments of the rPPG signal
    cycle_f_list = np.array([cycle_f])
    # Class predictions are not needed: only the rPPG signal and the cycle
    # features are plotted.

    # Plot the first `num_samples_to_plot` samples of each series.
    rppg_plot_data = rppg_list.flatten()[:num_samples_to_plot]
    # NOTE: cycle_f holds one feature vector per periodic segment, not a
    # waveform. Flattening concatenates those vectors, so the second subplot
    # is not a (c)PPG signal. The predicted rPPG waveform is the first one.
    cycle_f_plot_data = cycle_f_list.flatten()[:num_samples_to_plot]

    fig = plt.figure(figsize=(15, 5))
    plt.subplot(2, 1, 1)
    plt.plot(rppg_plot_data, label='rPPG Signal')
    plt.legend()
    plt.title('rPPG Signal')
    plt.subplot(2, 1, 2)
    plt.plot(cycle_f_plot_data, label='Cycle Feature Signal')
    plt.legend()
    plt.title('Cycle Feature Signal')
    plt.tight_layout()
    plt.show()
    # Release the figure. On non-interactive backends plt.show() returns
    # immediately, and one open figure per recording would accumulate.
    plt.close(fig)
This is my code, using your trained weights. I would like to obtain the predicted cPPG waveform. Is it `cycle_f_plot_data`? The plot below shows the predicted waveform.
Hi, the predicted rPPG signal is the first row of your figure. `cycle_f` contains the feature vector for each periodic segment, not the rPPG signal.