ThomasVercetti closed this issue 1 year ago.
Hi, thanks for your interest in our work!
First, please add a `forward_features` method to the `P2P` class in models/p2p.py:
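```python
class P2P(nn.Module):
    ...

    def forward_features(self, pc, original_pc):
        # Encode the point cloud into an image-like tensor with P2P's encoder
        img = self.enc(original_pc, pc)
        # Extract features from the 2D backbone (no classification head, per the usual timm convention)
        feat = self.base_model.forward_features(img)
        return feat
```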
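For readers unfamiliar with the `forward_features` convention: in recent timm versions, backbones split inference so that `forward` is roughly `forward_head(forward_features(x))`, and `forward_features` stops before the classifier. Below is a minimal, self-contained sketch of that split on a hypothetical toy model; `ToyNet`, its layers, and all sizes are illustrative assumptions, not P2P's actual architecture.

```python
import torch
import torch.nn as nn


class ToyNet(nn.Module):
    """Hypothetical stand-in for a backbone + classifier; not the P2P model."""

    def __init__(self, in_ch: int = 3, feat_dim: int = 16, num_classes: int = 40):
        super().__init__()
        # Feature extractor: produces a spatial feature map (B, feat_dim, H, W)
        self.backbone = nn.Sequential(
            nn.Conv2d(in_ch, feat_dim, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # Classification head: pool the map to a vector, then project to logits
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(feat_dim, num_classes),
        )

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        # Stop before the head: return the intermediate representation
        return self.backbone(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Full pipeline: features followed by the classification head
        return self.head(self.forward_features(x))


model = ToyNet().eval()
x = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    feat = model.forward_features(x)
    logits = model(x)
print(feat.shape, logits.shape)  # torch.Size([2, 16, 32, 32]) torch.Size([2, 40])
```

The design point is that `forward` reuses `forward_features`, so the features you extract are exactly what the classifier sees.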
Then you can extract representation features from a given point cloud with the demo script below. A zipped demo point cloud, demo_sample.npy, is attached. Hopefully this resolves your concern.
```python
import os
import sys
from collections import OrderedDict

import numpy as np
import torch

# Make the repository root importable (run this script from the repo root)
sys.path.append(os.path.abspath(os.getcwd()))
from models.p2p import P2P
from util import config
from util.rotate import rotate_theta_phi, rotate_point_clouds

SAMPLE_PATH = 'demo_sample.npy'  # path to point cloud sample
CONFIG_PATH = 'config/ModelNet40/p2p_ConvNeXt-T-1k.yaml'  # path to model config
CHECKPOINT_PATH = 'pretrained/reproduce/ckpt/ModelNet40/ConvNeXt-T-1k-ModelNet40.pth'  # path to pre-trained weights

# Load point cloud data (adding a batch dimension) and config
sample = torch.from_numpy(np.load(SAMPLE_PATH)).unsqueeze(dim=0).float().cuda()
cfg = config.load_cfg_from_cfg_file(CONFIG_PATH)

# Construct model and load pre-trained weights
model = P2P(cfg, is_test=True).cuda()
checkpoint = torch.load(CHECKPOINT_PATH, map_location=lambda storage, loc: storage.cuda())
# Strip the "module." prefix that (Distributed)DataParallel adds to parameter names
state_dict = OrderedDict({key.replace("module.", ""): value for key, value in checkpoint['state_dict'].items()})
model.load_state_dict(state_dict, strict=True)
model.eval()

# Optional: rotate input point cloud (theta and phi are in units of pi)
theta = 0.5
phi = -0.3
v_theta, v_phi = np.meshgrid(theta, phi)
angles = np.stack([v_theta, v_phi], axis=-1).reshape(-1, 2)
angles = torch.from_numpy(angles * np.pi)  # np.pi also works on PyTorch versions without torch.pi
rotation_matrix = rotate_theta_phi(angles)[0]

# Forward pass to get features and the class prediction
input_pc = rotate_point_clouds(sample, rotation_matrix)
with torch.no_grad():
    feat = model.forward_features(input_pc, original_pc=sample)
    output = model(input_pc, original_pc=sample)
pred = torch.argmax(output, 1).detach().cpu().numpy()
print('Feature shape:', tuple(feat.shape))
print('Prediction class:', pred)
```
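A note on the `key.replace("module.", "")` line: checkpoints saved from a model wrapped in `nn.DataParallel` or `DistributedDataParallel` have every key prefixed with `module.`, which is why it is stripped before `load_state_dict`. However, `str.replace` removes the substring anywhere in a key, not only at the start, so a submodule whose name ends in `module` would be silently renamed. A prefix-only variant is sketched below; `strip_prefix` and the toy keys are hypothetical, not part of the repo. (Recent PyTorch releases also provide `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`, which strips a leading prefix in place.)

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Remove `prefix` from the start of each key only (hypothetical helper)."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Toy keys standing in for a DataParallel checkpoint (values are placeholders)
ckpt = OrderedDict([
    ("module.enc.weight", 1),
    ("module.base_model.submodule.bias", 2),
    ("head.weight", 3),
])

print(list(strip_prefix(ckpt)))
# ['enc.weight', 'base_model.submodule.bias', 'head.weight']

# The replace-based version also rewrites "submodule." in the middle of a key
naive = {k.replace("module.", ""): v for k, v in ckpt.items()}
print(list(naive))
# ['enc.weight', 'base_model.subbias', 'head.weight']
```

For the released P2P checkpoints the two approaches presumably give the same keys (the script loads with `strict=True` successfully), so this only matters if you adapt the loader to other models.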
@LavenderLA Really appreciate that! It helps.
Hello @wangzy22, thanks for sharing the model. I'm unsure whether the model can run inference on a point cloud and output a low-dimensional representation.
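On getting a single low-dimensional vector per point cloud: depending on the timm version and on how P2P builds `base_model`, the `feat` returned by `forward_features` may be either an already pooled `(B, C)` tensor or an unpooled `(B, C, H, W)` feature map (an assumption here; check `feat.shape` from the demo script). A common way to turn either into one embedding per sample is global average pooling followed by L2 normalization, as in the sketch below. The helper `to_embedding` and the `768 x 7 x 7` stand-in shape are illustrative assumptions, not values from this thread.

```python
import torch
import torch.nn.functional as F


def to_embedding(feat: torch.Tensor) -> torch.Tensor:
    """Reduce backbone output to one L2-normalized vector per sample.

    Assumes `feat` is either (B, C), already pooled, or (B, C, H, W), unpooled.
    """
    if feat.dim() == 4:
        # Global average pooling over the spatial dimensions H and W
        feat = feat.mean(dim=(2, 3))
    elif feat.dim() != 2:
        raise ValueError(f"unexpected feature shape {tuple(feat.shape)}")
    # Unit-length vectors make dot products equal to cosine similarity
    return F.normalize(feat, dim=1)


# Random stand-in for the output of model.forward_features (illustrative shape)
feat_map = torch.randn(2, 768, 7, 7)
emb = to_embedding(feat_map)
print(emb.shape)  # torch.Size([2, 768])

# Cosine similarity between the two samples
print(float(emb[0] @ emb[1]))
```

Normalizing is optional; it is useful when the embeddings will be compared by cosine similarity or fed to nearest-neighbour retrieval.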