yanx27 / Pointnet_Pointnet2_pytorch

PointNet and PointNet++ implemented in PyTorch (pure Python), on ModelNet, ShapeNet and S3DIS.
MIT License
3.75k stars, 904 forks

Converting .pth to .pt fails #221

Open wangsale opened 1 year ago

wangsale commented 1 year ago

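```python
import torch
from models import pointnet2_part_seg_msg


def to_categorical(y, num_classes):
    """ 1-hot encodes a tensor """
    # Indexes rows of the identity matrix, so y must hold integer class ids.
    new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
    if (y.is_cuda):
        return new_y.cuda()
    return new_y


model = pointnet2_part_seg_msg.get_model(4, False)

model.eval()
# If b2.pth is a checkpoint written by the repo's training script, the weights
# may be stored under z['model_state_dict'] rather than at the top level.
z = torch.load('b2.pth')
model.load_state_dict(z)

example = torch.rand(1, 3, 2048)

# Likely problem: torch.rand gives floats, which cannot index torch.eye().
label = torch.rand(1, 1)

# Likely problem: this model reshapes cls_label with view(B, 16, 1), so it
# probably expects a 16-class one-hot rather than num_classes=1.
traced_script_module = torch.jit.trace(model, (example, to_categorical(label, 1)))
traced_script_module.save("b2.pt")
```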
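The two suspected problems in the script can be checked without the repository. Below is a minimal sketch under stated assumptions: `TinyPartSeg` is a hypothetical stand-in, not the repository's network, that reproduces only the `cls_label.view(B, 16, 1)` reshape, and `NUM_OBJECT_CLASSES = 16` assumes the 16 ShapeNet object categories used by the part-segmentation models. It builds the class label from an integer tensor, one-hot encodes it to 16 channels, traces with `torch.jit.trace`, saves a `.pt` file, reloads it with `torch.jit.load`, and compares the result against eager mode.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Assumption: the repo's part-seg models reshape cls_label with view(B, 16, 1).
NUM_OBJECT_CLASSES = 16


def to_categorical(y, num_classes):
    """One-hot encode integer class ids (y must be an integer tensor)."""
    # Each class id selects one row of the identity matrix.
    return torch.eye(num_classes, device=y.device)[y]


class TinyPartSeg(nn.Module):
    """Hypothetical stand-in sharing only the forward(xyz, cls_label) signature."""

    def __init__(self, num_parts):
        super().__init__()
        # 16 one-hot channels + 3 xyz channels -> per-point part scores.
        self.head = nn.Conv1d(NUM_OBJECT_CLASSES + 3, num_parts, kernel_size=1)

    def forward(self, xyz, cls_label):
        B, _, N = xyz.shape
        # Same reshape as the repo models: needs exactly B * 16 label values.
        one_hot = cls_label.view(B, NUM_OBJECT_CLASSES, 1).repeat(1, 1, N)
        scores = self.head(torch.cat([one_hot, xyz], dim=1))
        return torch.log_softmax(scores, dim=1).permute(0, 2, 1)


model = TinyPartSeg(num_parts=4).eval()
example = torch.rand(1, 3, 2048)
label = torch.zeros(1, 1, dtype=torch.long)  # integer class id, not torch.rand
cls_one_hot = to_categorical(label, NUM_OBJECT_CLASSES)  # shape (1, 1, 16)

with torch.no_grad():
    traced = torch.jit.trace(model, (example, cls_one_hot))

# Save and reload the TorchScript file, then compare against eager mode.
path = os.path.join(tempfile.mkdtemp(), "b2.pt")
traced.save(path)
reloaded = torch.jit.load(path)

with torch.no_grad():
    same = torch.allclose(model(example, cls_one_hot), reloaded(example, cls_one_hot))
print(tuple(cls_one_hot.shape), same)
```

To export the real network, one would presumably swap `TinyPartSeg` for `pointnet2_part_seg_msg.get_model(4, False)` with its weights loaded. Whether tracing then succeeds also depends on the sampling and grouping operations inside the model, which this sketch does not exercise.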

wangsale commented 1 year ago

[Image attachment]