erikwijmans / Pointnet2_PyTorch

PyTorch implementation of Pointnet2/Pointnet++
The Unlicense
1.5k stars, 340 forks

Fix bug: change the parameter name in PointNet2ClassificationSSG #103

Closed: matinjugou closed this 4 years ago

matinjugou commented 4 years ago

Change the parameter name in PointNet2ClassificationSSG.__init__ from 'args' to 'hparams' so that users can load checkpoints with LightningModule.load_from_checkpoint, like this:

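import hydra

model_path = 'path/to/checkpoint'  # placeholder: path to a saved checkpoint

@hydra.main("config/config.yaml")
def main(cfg):
    # Build the task model described by the Hydra config.
    model = hydra.utils.instantiate(cfg.task_model, cfg)
    # load_from_checkpoint is a classmethod that returns a new model restored
    # from the checkpoint; calling it on the class makes that explicit.
    model = type(model).load_from_checkpoint(model_path)
    # freeze() sets requires_grad=False on all parameters and calls eval().
    model.freeze()

if __name__ == "__main__":
    main()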
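Why the parameter name matters can be illustrated with a toy loader. This is a hypothetical sketch, not PyTorch Lightning's actual code (the exact mechanism depends on the Lightning version): it assumes the loader passes the saved hyperparameters to __init__ as a keyword argument named hparams. ToyModule, BeforeRename, AfterRename and the checkpoint contents are all made up for illustration.

```python
import inspect


class ToyModule:
    """Hypothetical stand-in for a LightningModule (not Lightning's real class)."""

    @classmethod
    def load_from_checkpoint(cls, checkpoint):
        # Assumed toy rule: saved hyperparameters are passed to __init__
        # as a keyword argument named "hparams".
        params = inspect.signature(cls.__init__).parameters
        if "hparams" not in params:
            raise TypeError(f"{cls.__name__}.__init__ has no 'hparams' parameter")
        model = cls(hparams=checkpoint["hparams"])
        model.state = checkpoint["state_dict"]  # stand-in for load_state_dict
        return model


class BeforeRename(ToyModule):
    def __init__(self, args):  # old name: the toy loader cannot match it
        self.hparams = args


class AfterRename(ToyModule):
    def __init__(self, hparams):  # new name matches what the toy loader passes
        self.hparams = hparams


# Hypothetical checkpoint contents, for illustration only.
ckpt = {"hparams": {"num_points": 4096}, "state_dict": {}}

model = AfterRename.load_from_checkpoint(ckpt)
print(model.hparams)  # {'num_points': 4096}

try:
    BeforeRename.load_from_checkpoint(ckpt)
except TypeError as err:
    print(err)  # BeforeRename.__init__ has no 'hparams' parameter
```

Under this assumption, the only difference between the two classes is the parameter name, which is the change this PR makes to PointNet2ClassificationSSG.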
erikwijmans commented 4 years ago

Thank you for the PR! Merging.