nmichlo / disent

🧶 Modular VAE disentanglement framework for python built with PyTorch Lightning ▸ Including metrics and datasets ▸ With strongly supervised, weakly supervised and unsupervised methods ▸ Easily configured and run with Hydra config ▸ Inspired by disentanglement_lib
https://disent.michlo.dev
MIT License

[BUG]: Frameworks not loading as LightningModule when using trainer.fit #47

Closed yuvrajsingh2 closed 2 months ago

yuvrajsingh2 commented 1 year ago

Describe the bug: When running the QuickStart example for frameworks, I could not get the Ae module to load correctly. Initially the callback line was giving me trouble, but even after removing it I still could not pass the module to the trainer as a LightningModule.

After running this:

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.frameworks.ae import Ae
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.dataset.transform import ToImgTensorF32
from disent.util import is_test_run  # you can ignore and remove this

# prepare the data
data = XYObjectData()
dataset = DisentDataset(data, transform=ToImgTensorF32())
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# create the pytorch lightning system
module: pl.LightningModule = Ae(
    model=AutoEncoder(
        encoder=EncoderConv64(x_shape=data.x_shape, z_size=6),
        decoder=DecoderConv64(x_shape=data.x_shape, z_size=6),
    ),
    cfg=Ae.cfg(optimizer='adam', optimizer_kwargs=dict(lr=1e-3), loss_reduction='mean_sum')
)

# train the model
trainer = pl.Trainer(logger=False, fast_dev_run=is_test_run())
trainer.fit(module, dataloader)

I got a TypeError. It happens with any framework I try.

TypeError                                 Traceback (most recent call last)
c:\Users\yuviu\Desktop\Uni Work\Thesis\disent\experiment\examples\example.ipynb Cell 4 line 2
     20 # train the model
     21 trainer = pl.Trainer(logger=False, fast_dev_run=is_test_run())
---> 22 trainer.fit(module, dataloader)

File c:\Users\yuviu\anaconda3\envs\disent_env\lib\site-packages\pytorch_lightning\trainer\trainer.py:529, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    495 def fit(
    496     self,
    497     model: "pl.LightningModule",
   (...)
    501     ckpt_path: Optional[str] = None,
    502 ) -> None:
    503     r"""Runs the full optimization routine.
    504 
    505     Args:
   (...)
    527 
    528     """
--> 529     model = _maybe_unwrap_optimized(model)
    530     self.strategy._lightning_module = model
    531     _verify_strategy_supports_compile(model, self.strategy)

File c:\Users\yuviu\anaconda3\envs\disent_env\lib\site-packages\pytorch_lightning\utilities\compile.py:126, in _maybe_unwrap_optimized(model)
    124 if isinstance(model, pl.LightningModule):
    125     return model
--> 126 raise TypeError(
    127     f"`model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `{type(model).__qualname__}`"
    128 )

TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `Ae`

To Reproduce: Run the quick start examples in the documentation.

Expected behaviour: The model starts training.

Additional context: I installed the package from source in a conda env. I also tried a new env with v2.0.0 of lightning, but I am still facing this issue.

nmichlo commented 2 months ago

Hi there, my apologies for missing this issue.

The problem is the pytorch_lightning import.

- import pytorch_lightning as pl

+ import lightning as pl
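
For anyone wondering why the import matters, here is a minimal sketch of the failure mode. It assumes, going by the fix above, that the disent frameworks subclass the `LightningModule` from the `lightning` package, while the `Trainer` in the traceback checks against the class from `pytorch_lightning`. The names `make_namespace`, `fake_lightning` and `fake_pytorch_lightning` are hypothetical stand-ins so that the snippet runs without either package installed; they are not the real Lightning code.

```python
import types


def make_namespace(name):
    """Build a stand-in package with its own LightningModule and Trainer."""

    class LightningModule:
        pass

    class Trainer:
        def fit(self, model):
            # Mirrors the isinstance check shown in the traceback: only
            # subclasses of *this* package's LightningModule are accepted.
            if not isinstance(model, LightningModule):
                raise TypeError(
                    f"`model` must be a `LightningModule`, "
                    f"got `{type(model).__qualname__}`"
                )
            return "training started"

    return types.SimpleNamespace(
        __name__=name, LightningModule=LightningModule, Trainer=Trainer
    )


# Two packages that expose classes with the same name but different identity.
fake_lightning = make_namespace("lightning")
fake_pytorch_lightning = make_namespace("pytorch_lightning")


# Assumption: the framework is built on the `lightning` flavour of the base class.
class Ae(fake_lightning.LightningModule):
    pass


module = Ae()

# Trainer from the same package as the base class: accepted.
print(fake_lightning.Trainer().fit(module))

# Trainer from the other package: rejected, even though the class names match.
try:
    fake_pytorch_lightning.Trainer().fit(module)
except TypeError as err:
    print(err)
```

The two `LightningModule` classes share a name but are different objects, so `isinstance` fails whenever the module and the trainer come from different packages. Importing `lightning as pl` keeps both on the same side.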