# Issue status: open — opened by wangsale 1 year ago.
import torch from models import pointnet2_part_seg_msg
def to_categorical(y, num_classes):
    """One-hot encode a tensor of integer class indices.

    Args:
        y: integer tensor of class indices, each in ``[0, num_classes)``.
        num_classes: width of the one-hot axis.

    Returns:
        A float tensor of shape ``y.shape + (num_classes,)`` on the same
        device as ``y``.

    Raises:
        IndexError: if ``y`` is a floating-point tensor, or (on CPU) if an
            index is out of range.
    """
    # Float values are not valid class indices; fail with a clear message
    # instead of an opaque indexing error deep inside torch.
    if y.is_floating_point():
        raise IndexError(
            f"to_categorical expects an integer tensor of class indices, got dtype {y.dtype}"
        )
    # Build the identity matrix directly on y's device (rather than CPU +
    # .cuda(), which targets the *current* CUDA device, not necessarily y's).
    eye = torch.eye(num_classes, device=y.device)
    # Advanced indexing picks one identity row per index -> y.shape + (num_classes,).
    return eye[y.long()]
# Build the part-segmentation model, load trained weights, and export it to
# TorchScript by tracing it on dummy inputs.
# NOTE(review): the meaning of get_model's arguments (4, False) is defined in
# models/pointnet2_part_seg_msg, which is not visible here — confirm.
model = pointnet2_part_seg_msg.get_model(4, False)
# Put the model in eval mode so any dropout/batch-norm layers are traced with
# inference behaviour.
model.eval()

# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
z = torch.load('b2.pth', map_location='cpu')
# Some training scripts save a wrapper dict around the weights; unwrap it if
# present, otherwise treat the loaded object as a plain state dict.
if isinstance(z, dict) and 'model_state_dict' in z:
    z = z['model_state_dict']
model.load_state_dict(z)

# Dummy point cloud used only to drive the trace.
# NOTE(review): presumably (batch, channels, num_points) — confirm against the model.
example = torch.rand(1, 3, 2048)

# Class labels are indices and must be an integer tensor: the previous
# torch.rand(1, 1) produced floats, which to_categorical cannot index with.
# NOTE(review): confirm NUM_CATEGORIES matches the one-hot width the model's
# forward expects for its class-label input.
NUM_CATEGORIES = 1
label = torch.zeros(1, 1, dtype=torch.long)

traced_script_module = torch.jit.trace(model, (example, to_categorical(label, NUM_CATEGORIES)))
traced_script_module.save("b2.pt")
import torch from models import pointnet2_part_seg_msg
def to_categorical(y, num_classes):
    """One-hot encode a tensor of integer class indices.

    Args:
        y: integer tensor of class indices, each in ``[0, num_classes)``.
        num_classes: width of the one-hot axis.

    Returns:
        A float tensor of shape ``y.shape + (num_classes,)`` on the same
        device as ``y``.

    Raises:
        IndexError: if ``y`` is a floating-point tensor, or (on CPU) if an
            index is out of range.
    """
    # Float values are not valid class indices; fail with a clear message
    # instead of an opaque indexing error deep inside torch.
    if y.is_floating_point():
        raise IndexError(
            f"to_categorical expects an integer tensor of class indices, got dtype {y.dtype}"
        )
    # Build the identity matrix directly on y's device (rather than CPU +
    # .cuda(), which targets the *current* CUDA device, not necessarily y's).
    eye = torch.eye(num_classes, device=y.device)
    # Advanced indexing picks one identity row per index -> y.shape + (num_classes,).
    return eye[y.long()]
# Build the part-segmentation model, load trained weights, and export it to
# TorchScript by tracing it on dummy inputs.
# NOTE(review): the meaning of get_model's arguments (4, False) is defined in
# models/pointnet2_part_seg_msg, which is not visible here — confirm.
model = pointnet2_part_seg_msg.get_model(4, False)
# Put the model in eval mode so any dropout/batch-norm layers are traced with
# inference behaviour.
model.eval()

# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
z = torch.load('b2.pth', map_location='cpu')
# Some training scripts save a wrapper dict around the weights; unwrap it if
# present, otherwise treat the loaded object as a plain state dict.
if isinstance(z, dict) and 'model_state_dict' in z:
    z = z['model_state_dict']
model.load_state_dict(z)

# Dummy point cloud used only to drive the trace.
# NOTE(review): presumably (batch, channels, num_points) — confirm against the model.
example = torch.rand(1, 3, 2048)

# Class labels are indices and must be an integer tensor: the previous
# torch.rand(1, 1) produced floats, which to_categorical cannot index with.
# NOTE(review): confirm NUM_CATEGORIES matches the one-hot width the model's
# forward expects for its class-label input.
NUM_CATEGORIES = 1
label = torch.zeros(1, 1, dtype=torch.long)

traced_script_module = torch.jit.trace(model, (example, to_categorical(label, NUM_CATEGORIES)))
traced_script_module.save("b2.pt")