drprojects / superpoint_transformer

Official PyTorch implementation of Superpoint Transformer introduced in [ICCV'23] "Efficient 3D Semantic Segmentation with Superpoint Transformer" and SuperCluster introduced in [3DV'24 Oral] "Scalable 3D Panoptic Segmentation As Superpoint Graph Clustering"
MIT License

inference error #61

Closed: bkabelik closed this issue 7 months ago

bkabelik commented 7 months ago

Hi, I am trying to run inference on the DALES dataset with your notebook code.

At the model loading step, I get an error (also with your pretrained model) on this line:

model = model.load_from_checkpoint(cfg.ckpt_path, net=model.net, criterion=model.criterion)

The error is:

Exception has occurred: TypeError
The classmethod `PointSegmentationModule.load_from_checkpoint` cannot be called on an instance. Please call it on the class type and make sure the return value is used.
  File "/home/daktar/asustor/01_Kabelik_gmbh/25_training/11_spt/01_infernce.py", line 46, in <module>
    model = model.load_from_checkpoint(cfg.ckpt_path, net=model.net, criterion=model.criterion)
TypeError: The classmethod `PointSegmentationModule.load_from_checkpoint` cannot be called on an instance. Please call it on the class type and make sure the return value is used.
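This message appears to come from PyTorch Lightning, which wraps `load_from_checkpoint` so that it only works when called on the class. The method builds and returns a new model; it never loads weights into an existing object. The sketch below shows how such a guard can work. It is a simplified, hypothetical illustration (`restricted_classmethod` and `ToyModule` are made-up names), not Lightning's actual code:

```python
import functools


class restricted_classmethod:
    """Hypothetical classmethod replacement that refuses instance calls.

    Simplified sketch for illustration; not Lightning's implementation.
    """

    def __init__(self, method):
        self.method = method

    def __get__(self, instance, owner):
        @functools.wraps(self.method)
        def wrapper(*args, **kwargs):
            # Accessed through an instance (e.g. model.load_from_checkpoint): refuse.
            if instance is not None:
                raise TypeError(
                    f"The classmethod `{owner.__name__}.{self.method.__name__}` "
                    "cannot be called on an instance."
                )
            # Accessed through the class: behave like an ordinary classmethod.
            return self.method(owner, *args, **kwargs)

        return wrapper


class ToyModule:
    def __init__(self, weights=None):
        self.weights = weights

    @restricted_classmethod
    def load_from_checkpoint(cls, path):
        # A real loader would read `path`; this toy only fakes the restored weights.
        return cls(weights=f"loaded from {path}")


model = ToyModule()

try:
    model.load_from_checkpoint("epoch_279.ckpt")  # instance call: rejected
except TypeError as err:
    print("instance call failed:", err)

model = ToyModule.load_from_checkpoint("epoch_279.ckpt")  # class call: works
print(model.weights)
```

The guard catches a common mistake: calling the method on an instance suggests it loads weights in place, even though it actually returns a separate object.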

Please help.

Here is my code (it should be the same as in your notebook):

import os
import sys

sys.path.append("/home/daktar/superpoint_transformer")

from omegaconf import OmegaConf
OmegaConf.register_new_resolver("eval", eval)

import hydra
from src.utils import init_config
import torch
from src.visualization import show
from src.datasets.dales import CLASS_NAMES, CLASS_COLORS
from src.datasets.dales import DALES_NUM_CLASSES as NUM_CLASSES
from src.transforms import *

cfg = init_config(overrides=[
    "experiment=dales",
    "ckpt_path=/home/daktar/superpoint_transformer/logs/train/runs/2024-01-05_11-00-33/checkpoints/epoch_279.ckpt"
])

datamodule = hydra.utils.instantiate(cfg.datamodule)
datamodule.prepare_data()
datamodule.setup()

model = hydra.utils.instantiate(cfg.model)

model = model.load_from_checkpoint(cfg.ckpt_path, net=model.net, criterion=model.criterion)  # raises the TypeError shown above
model = model.eval().cuda()

dataset = datamodule.val_dataset

for t in dataset.on_device_transform.transforms:
    if isinstance(t, NAGAddKeysTo):
        t.delete_after = False

nag = dataset[0]

nag = dataset.on_device_transform(nag.cuda())

logits = model(nag)

if model.multi_stage_loss:
    logits = logits[0]

l1_preds = torch.argmax(logits, dim=1).detach()
l0_preds = l1_preds[nag[0].super_index]

nag[0].pred = l0_preds

show(
    nag, 
    class_names=CLASS_NAMES, 
    ignore=NUM_CLASSES,
    class_colors=CLASS_COLORS,
    max_points=100000
)
drprojects commented 7 months ago

Hi,

PS: If you are using this project, don't forget to give it a ⭐; it matters to us!

bkabelik commented 7 months ago

Thank you very much! The link helped me a lot, and I was able to fix my problem.

What I did was import PointSegmentationModule with:

from src.models.segmentation import PointSegmentationModule

and change the checkpoint loading line to:

model = PointSegmentationModule.load_from_checkpoint(cfg.ckpt_path, net=model.net, criterion=model.criterion)
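A possible variant that avoids the extra import is to call the method on `type(model)`. This assumes that the object built by `hydra.utils.instantiate(cfg.model)` is a `PointSegmentationModule`, which is worth confirming with `print(type(model))`. The toy sketch below uses a hypothetical `ToyModule` to show that `type(obj).method(...)` is the same as `ClassName.method(...)`. It also shows why the return value must be assigned: a classmethod loader builds a new object and leaves the existing one unchanged.

```python
class ToyModule:
    """Hypothetical stand-in for a module with a classmethod loader."""

    def __init__(self, weights=None):
        self.weights = weights

    @classmethod
    def load_from_checkpoint(cls, path, **kwargs):
        # Builds and returns a NEW object; extra kwargs (e.g. net=..., criterion=...)
        # are ignored in this toy. It never modifies an existing instance.
        return cls(weights=f"loaded from {path}")


model = ToyModule()  # plays the role of hydra.utils.instantiate(cfg.model)

# type(model) is ToyModule, so this is equivalent to
# ToyModule.load_from_checkpoint(...) without importing the class by name.
loaded = type(model).load_from_checkpoint("epoch_279.ckpt")

print(model.weights)   # None: the original instance is unchanged
print(loaded.weights)  # loaded from epoch_279.ckpt
model = loaded         # rebind, as the fixed line does with `model = ...`
```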