drprojects / superpoint_transformer

Official PyTorch implementation of Superpoint Transformer introduced in [ICCV'23] "Efficient 3D Semantic Segmentation with Superpoint Transformer" and SuperCluster introduced in [3DV'24 Oral] "Scalable 3D Panoptic Segmentation As Superpoint Graph Clustering"
MIT License

How to train or validate over s3dis and s3disroom? #49

Closed gardiens closed 8 months ago

gardiens commented 8 months ago

Hello, suppose we have two datamodules, for instance s3dis and s3disroom. I would like to train a model on s3dis and validate on both s3dis and s3disroom, to check how the model behaves on a subsample of s3dis.

Do you know if there is an easy way to do so?

What I already tried:

from typing import Any, Dict
import torch
from lightning import LightningDataModule
from src.loader.dataloader import DataLoader
from src.data import NAGBatch
from lightning.pytorch.utilities.combined_loader import CombinedLoader

from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
class ConcatenateDatamodule1(LightningDataModule):
    KEY_MAIN=0
    def __init__(self,l_datamodule,**kwargs):
        super().__init__()
        self.save_hyperparameters(logger=False)
        self.kwargs = kwargs

        assert len(l_datamodule)>0,"l_datamodule must contain at least one datamodule"
        self.l_datamodule=l_datamodule
        self.main_datamodule=self.l_datamodule[self.KEY_MAIN]
        self.train_dataset=[]
        self.val_dataset=[]
        self.test_dataset=[]
        self.pred_dataset=[]

        self.pre_transform = []
        self.train_transform = []
        self.val_transform = []
        self.test_transform = []
        self.on_device_train_transform = []
        self.on_device_val_transform = []
        self.on_device_test_transform = []
        self.on_device_pred_transform = []
        self.pred_transform=[]
        self.set_transforms()

    def set_transforms(self):
        for datamodule in self.l_datamodule:
            KEY=["pre_transform","train_transform","val_transform","test_transform","on_device_train_transform","on_device_val_transform","on_device_test_transform","on_device_pred_transform"]
            for key in KEY: 
                self.__getattribute__(key).append(datamodule.__getattribute__(key))

    @property
    def dataset_class(self):
        """Return the LightningDataModule's Dataset class.
        """
        # dataset_class, train_stage and val_stage are assumed to be
        # properties of the wrapped datamodule, so they are read, not called
        return self.main_datamodule.dataset_class
    @property
    def train_stage(self):
        return self.main_datamodule.train_stage
    @property
    def val_stage(self):
        return self.main_datamodule.val_stage
    def prepare_data(self):
        for datamodule in self.l_datamodule:
            datamodule.prepare_data() #?
    def setup(self,stage:str):
        KEY=["train_dataset","val_dataset","test_dataset"]
        for datamodule in self.l_datamodule:
            datamodule.setup(stage=stage)
            for key in KEY:
                self.__getattribute__(key).append(datamodule.__getattribute__(key))

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        l=[]
        for datamodule in self.l_datamodule:
            l.append(datamodule.train_dataloader())
        return l
    def val_dataloader(self) -> TRAIN_DATALOADERS:
        l=[]
        for datamodule in self.l_datamodule:
            l.append(datamodule.val_dataloader())
        return l
    def test_dataloader(self) -> TRAIN_DATALOADERS:
        l=[]
        for datamodule in self.l_datamodule:
            l.append(datamodule.test_dataloader())
        return l
    def predict_dataloader(self):
        # any iterable or collection of iterables
        l=[]
        for datamodule in self.l_datamodule:
            l.append(datamodule.predict_dataloader())
        # Return one predict dataloader per datamodule, like the other stages
        return l

    def teardown(self, stage: str) -> None:
        pass 
    def state_dict(self):
        return {}
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        pass 

    @torch.no_grad()
    def on_after_batch_transfer(self, nag_list, dataloader_idx):
        """Intended to call on-device operations. Typically,
        NAGBatch.from_nag_list and some Transforms like SampleSubNodes
        and SampleSegments are faster on GPU, and we may prefer
        executing those on GPU rather than in CPU-based DataLoader.

        Use self.on_device_<stage>_transform, to benefit from this hook.
        """
        # Since NAGBatch.from_nag_list takes a bit of time, we asked
        # src.loader.DataLoader to simply pass a list of NAG objects,
        # waiting to be batched on device.
        nag = NAGBatch.from_nag_list(nag_list)
        del nag_list

        # Here we run on_device_transform, which contains NAG transforms
        # that we could not / did not want to run using CPU-based
        # DataLoaders
        if self.trainer.training:
            on_device_transform = self.on_device_train_transform[dataloader_idx]
        elif self.trainer.validating:
            on_device_transform = self.on_device_val_transform[dataloader_idx]
        elif self.trainer.testing:
            on_device_transform = self.on_device_test_transform[dataloader_idx]
        elif self.trainer.predicting:
            #TODO: predict_dataloader
            on_device_transform=self.on_device_pred_transform[dataloader_idx]
        elif self.trainer.evaluating:
            on_device_transform = self.on_device_test_transform[dataloader_idx]
        elif self.trainer.sanity_checking:
            on_device_transform = self.on_device_train_transform[dataloader_idx]
        else:
            print(
                'Unsure which stage we are in, defaulting to '
                'self.on_device_train_transform')
            on_device_transform = self.on_device_train_transform[dataloader_idx]

        # Skip on_device_transform if None
        if on_device_transform is None:
            return nag

        # Apply on_device_transform only once when in training mode and
        # if no test-time augmentation is required
        if self.trainer.training \
                or self.hparams.tta_runs is None \
                or self.hparams.tta_runs == 1 or \
                (self.trainer.validating and not self.hparams.tta_val):
            return on_device_transform(nag)

        # We return the input NAG as well as the augmentation transform
        # and the number of runs. Those will be used by
        # `LightningModule.step` to accumulate multiple augmented runs
        return nag, on_device_transform, self.hparams.tta_runs

    def __repr__(self):
        return f'{self.main_datamodule.__class__.__name__ } ' 

But when I run trainer=trainer.predict(model=model,datamodule=ConcatenateDatamodule(datamodule,datamodule)), I get this shady error: TypeError: An invalid dataloader was passed to Trainer.predict(dataloaders=...). Found S3DISDatamodule

I know this isn't exactly a superpoint_transformer question, but it is closely related.

Sincerely, Pierrick Bournez

drprojects commented 8 months ago

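Hi, to be sure I understand:

  • are you trying to build a validation set that is made of the union of the validation sets of S3DIS and S3DISRoom, or are you looking for a way of computing validation metrics for each dataset independently?
  • do you need to track these validation metrics during training, or would a simple post-training evaluation suffice?

Remark

> to check how the model behave on a subsample of s3dis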

S3DISRoom is not a subsample of S3DIS. Both datasets have the same size and contain the same data. The difference is that S3DISRoom produces batch items in a room-by-room fashion, while S3DIS aggregates all rooms of the same Area together and produces random samplings of this large scene. We argue the latter is better practice for indoor scene parsing, since a room subdivision is not always available in all acquisitions, can be ambiguous (think open spaces), and requires human heuristics. Besides, we show in our paper that SPT performs better on S3DIS than with the room subdivisions of S3DISRoom, as it leverages larger scenes at once.

gardiens commented 8 months ago

> Hi, to be sure I understand:
>
>   • are you trying to build a validation set that is made of the union of the validation sets of S3DIS and S3DISRoom, or are you looking for a way of computing validation metrics for each dataset independently?
>   • do you need to track these validation metrics during training, or would a simple post-training evaluation suffice?

I would like to compute the validation metrics of S3DIS and S3DISRoom independently while training on S3DIS. I'm interested in the impact of the dataset's range when a larger range is given for training than for validation. In this use case, I strongly suspect that the model is overfitting. Remark: in this case, it might be simpler to change the _on_device_validation_transform, but I'd like to apply this to a completely different datamodule (e.g. KITTI with class remapping, or S3DISRoom, or whatever).

drprojects commented 8 months ago

I see. Off the top of my head, I do not see a clear, easy way of doing this.

It would seem Lightning's multi-dataloader feature could be a good direction. But it is unclear whether it concatenates the dataloaders into a single one or computes metrics for each. I don't find the docs very clear in this regard, but I think it concatenates/mixes them, which is not what you are trying to do. Have you tested this functionality in another project?
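
For reference, as far as the Lightning documentation describes it, when `val_dataloader()` returns a list the validation loaders are run one after another and `validation_step` receives a `dataloader_idx` argument, so metrics can be kept separate per dataset. The sketch below shows only that bookkeeping, in plain Python. `PerLoaderConfusion` and its methods are hypothetical names, not part of Lightning or of this repository, and the labels are toy integers.

```python
from collections import defaultdict


class PerLoaderConfusion:
    """Keep one confusion matrix per validation dataloader index."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        # dataloader_idx -> num_classes x num_classes matrix of counts,
        # rows are ground-truth labels, columns are predicted labels
        self._matrices = defaultdict(self._empty)

    def _empty(self):
        return [[0] * self.num_classes for _ in range(self.num_classes)]

    def update(self, dataloader_idx, preds, targets):
        matrix = self._matrices[dataloader_idx]
        for p, t in zip(preds, targets):
            matrix[t][p] += 1

    def miou(self, dataloader_idx):
        matrix = self._matrices[dataloader_idx]
        ious = []
        for c in range(self.num_classes):
            tp = matrix[c][c]
            fn = sum(matrix[c]) - tp
            fp = sum(row[c] for row in matrix) - tp
            denom = tp + fp + fn
            if denom > 0:  # skip classes absent from both preds and targets
                ious.append(tp / denom)
        return sum(ious) / len(ious) if ious else 0.0


# Inside a LightningModule, the hook would look roughly like this
# (sketch only, not this repository's code):
#
#     def validation_step(self, batch, batch_idx, dataloader_idx=0):
#         preds, targets = ...
#         self.val_cm.update(dataloader_idx, preds, targets)

tracker = PerLoaderConfusion(num_classes=2)
tracker.update(0, preds=[0, 1, 1, 0], targets=[0, 1, 1, 0])  # e.g. S3DIS
tracker.update(1, preds=[0, 0, 1, 1], targets=[0, 1, 1, 1])  # e.g. S3DISRoom
print(tracker.miou(0), tracker.miou(1))
```

Keying the accumulator on `dataloader_idx` keeps the two datasets' scores from being mixed, whatever order the loaders are listed in.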

However, you only seem to be interested in final performance, and tracking the validation metrics on two separate datasets as training goes does not seem necessary. If so, you can just train your model and evaluate on different datasets afterwards. Much simpler use case then.
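
The post-training route can be sketched as a small loop. `Trainer.validate(model=..., datamodule=...)` is Lightning's standard evaluation entry point and returns a list of metric dictionaries. The helper `evaluate_on_all`, the dataset names, and the `val/miou` key are hypothetical, and a stand-in trainer is used so the snippet runs without Lightning installed.

```python
def evaluate_on_all(trainer, model, datamodules, ckpt_path=None):
    """Run one validation pass per datamodule and collect the metrics.

    `datamodules` maps a name of your choosing to a datamodule instance.
    """
    results = {}
    for name, datamodule in datamodules.items():
        # Trainer.validate returns a list with one metrics dict per
        # validation dataloader
        results[name] = trainer.validate(
            model=model, datamodule=datamodule, ckpt_path=ckpt_path)
    return results


class FakeTrainer:
    """Stand-in for lightning.Trainer so the example is self-contained."""

    def validate(self, model=None, datamodule=None, ckpt_path=None):
        # Made-up metric name and value, for illustration only
        return [{"val/miou": datamodule["miou"]}]


# Placeholder datamodules: plain dicts carrying a made-up score
datamodules = {"s3dis": {"miou": 0.5}, "s3disroom": {"miou": 0.25}}
results = evaluate_on_all(FakeTrainer(), model=None, datamodules=datamodules)
print(results)
```

With a real `Trainer`, each datamodule would be an already configured instance, and `ckpt_path` would point to the trained checkpoint.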

Also, regarding your intuition, if I understood you correctly, I may have some information to share. If you are saying you suspect SPT may overfit if given S3DIS batch samples with too large a radius, you are correct. The resulting superpoint graphs are so large that the model can recognize which part of the building it is in. We found that the 7 m radius was a good trade-off in this regard. As explained in our paper, the superpoint paradigm reduces the size of the problem so much that overfitting is never far. It is likely that SPT would benefit from training on much larger 3D datasets.
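
To make the role of the sampling radius concrete, here is a toy sketch of radius-based cropping: pick a random center and keep the points within a given distance. This is a simplified stand-in and not the transform used in this repository. `sample_sphere` and the corridor data are hypothetical.

```python
import math
import random


def sample_sphere(points, radius, rng):
    """Return the points within `radius` of a randomly chosen center point."""
    center = rng.choice(points)
    return [p for p in points if math.dist(p, center) <= radius]


rng = random.Random(0)
# Toy "area": points spread along a 40 m corridor, one every metre
points = [(float(x), 0.0, 0.0) for x in range(41)]

small = sample_sphere(points, radius=2.0, rng=rng)
large = sample_sphere(points, radius=7.0, rng=rng)
print(len(small), len(large))
```

The larger the radius, the larger the share of the scene each sample covers, which is the context the comment above links to overfitting.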

drprojects commented 8 months ago

May I close this issue?