drprojects / superpoint_transformer

Official PyTorch implementation of Superpoint Transformer introduced in [ICCV'23] "Efficient 3D Semantic Segmentation with Superpoint Transformer" and SuperCluster introduced in [3DV'24 Oral] "Scalable 3D Panoptic Segmentation As Superpoint Graph Clustering"

inference error #60

Closed: bkabelik closed this issue 7 months ago

bkabelik commented 7 months ago

Hi, I am trying to run inference on the DALES dataset with your notebook code.

At the model loading step I get an error (also with your pretrained model) on this line:

```python
model = model.load_from_checkpoint(cfg.ckpt_path, net=model.net, criterion=model.criterion)
```

```
Exception has occurred: TypeError
The classmethod PointSegmentationModule.load_from_checkpoint cannot be called on an instance. Please call it on the class type and make sure the return value is used.
  File "/home/daktar/asustor/01_Kabelik_gmbh/25_training/11_spt/01_infernce.py", line 46, in <module>
    model = model.load_from_checkpoint(cfg.ckpt_path, net=model.net, criterion=model.criterion)
TypeError: The classmethod PointSegmentationModule.load_from_checkpoint cannot be called on an instance. Please call it on the class type and make sure the return value is used.
```

Please help.
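For reference, the traceback itself points at the likely fix: in recent PyTorch Lightning versions, `load_from_checkpoint` is a restricted classmethod, so it must be called on the class rather than on a model instance, and the module it returns must be kept. A minimal sketch of the adjusted call, reusing the same `cfg` and `model` objects as in the code below:

```python
# Call the classmethod on the class (here via type(model)) instead of
# the instance, and keep the module it returns
model = type(model).load_from_checkpoint(
    cfg.ckpt_path, net=model.net, criterion=model.criterion
)
```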

Here is my code (it should be the same as in your notebook):

```python
import os
import sys

# Add the project's files to the python path
file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # for .py script
file_path = os.path.dirname(os.path.abspath(''))  # for .ipynb notebook
sys.path.append(file_path)
sys.path.append("/home/daktar/superpoint_transformer")

# Necessary for advanced config parsing with hydra and omegaconf
from omegaconf import OmegaConf
OmegaConf.register_new_resolver("eval", eval)

import hydra
import torch
from src.utils import init_config
from src.visualization import show
from src.datasets.dales import CLASS_NAMES, CLASS_COLORS
from src.datasets.dales import DALES_NUM_CLASSES as NUM_CLASSES
from src.transforms import *

cfg = init_config(overrides=[
    "experiment=dales",
    "ckpt_path=/home/daktar/superpoint_transformer/logs/train/runs/2024-01-05_11-00-33/checkpoints/epoch_279.ckpt"
])

# Instantiate the datamodule
datamodule = hydra.utils.instantiate(cfg.datamodule)
datamodule.prepare_data()
datamodule.setup()

# Instantiate the model
model = hydra.utils.instantiate(cfg.model)

# Load pretrained weights from a checkpoint file
model = model.load_from_checkpoint(cfg.ckpt_path, net=model.net, criterion=model.criterion)

# Pick among train, val, and test datasets. It is important to note that
# the train dataset produces augmented spherical samples of large
# scenes, while the val and test datasets prepare entire tiles
# dataset = datamodule.train_dataset
# dataset = datamodule.val_dataset
dataset = datamodule.test_dataset

# For the sake of visualization, we require that NAGAddKeysTo does not
# remove input Data attributes after moving them to Data.x, so we may
# visualize them
for t in dataset.on_device_transform.transforms:
    if isinstance(t, NAGAddKeysTo):
        t.delete_after = False

# Load a dataset item. This will return the hierarchical partition of an
# entire tile, within a NAG object
nag = dataset[0]

# Apply on-device transforms on the NAG object. For the train dataset,
# this will select a spherical sample of the larger tile and apply some
# data augmentations. For the validation and test datasets, this will
# prepare an entire tile for inference
nag = dataset.on_device_transform(nag.cuda())

# Inference
logits = model(nag)

# If the model outputs multi-stage predictions, we take the first one,
# corresponding to level-1 predictions
if model.multi_stage_loss:
    logits = logits[0]

# Compute the level-0 (pointwise) predictions based on the predictions
# on level-1 superpoints
l1_preds = torch.argmax(logits, dim=1).detach()
l0_preds = l1_preds[nag[0].super_index]

# Save predictions for visualization in the level-0 Data attributes
nag[0].pred = l0_preds

# Visualize the hierarchical partition
show(
    nag,
    class_names=CLASS_NAMES,
    ignore=NUM_CLASSES,
    class_colors=CLASS_COLORS,
    max_points=100000
)

# Pick a center and radius for the spherical sample
center = torch.tensor([[40, 115, 0]]).to(nag.device)
radius = 10

# Create a mask on level-0 (i.e. points) to be used for indexing the NAG
# structure
mask = torch.where(torch.linalg.norm(nag[0].pos - center, dim=1) < radius)[0]

# Subselect the hierarchical partition based on the level-0 mask
nag_visu = nag.select(0, mask)

# Visualize the sample
show(
    nag_visu,
    class_names=CLASS_NAMES,
    ignore=NUM_CLASSES,
    class_colors=CLASS_COLORS,
    max_points=100000
)
```
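As a side note on the level-1 to level-0 upsampling in the snippet above: it relies only on tensor indexing with `super_index`. A self-contained toy example of the same pattern (the values here are illustrative, not from DALES):

```python
import torch

# Each level-0 point stores the index of the level-1 superpoint it
# belongs to; indexing the superpoint predictions with super_index
# broadcasts each superpoint's label back onto its points
l1_preds = torch.tensor([3, 1, 2])           # predicted class per superpoint
super_index = torch.tensor([0, 0, 1, 2, 1])  # superpoint id of each point
l0_preds = l1_preds[super_index]             # tensor([3, 3, 1, 2, 1])
```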