Closed FerusAndBeyond closed 3 years ago
Hi @FerusAndBeyond ,
From my view this sounds like a Jupyter issue and nothing on our side. However, if you can provide a reproducible example, I would be more than happy to look into it :)
I've never gotten the issue for any other classes I edit in files aside from pytorch lightning modules.
Put this in a file pl.py
import torch.nn as nn
import torch
import pytorch_lightning as pl

mse_loss = nn.MSELoss()

class FNN(pl.LightningModule):
    def __init__(self, in_dim):
        super(FNN, self).__init__()
        self.in_dim = in_dim
        print("ok")
        self.ff = nn.Linear(in_dim, in_dim)

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self(inputs)
        loss = mse_loss(outputs, targets)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def test_epoch_end(self, outputs):
        print(outputs)

    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self(inputs)
        loss = mse_loss(outputs, targets)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

    def forward(self, x):
        return self.ff(x)
Then open a notebook:
First cell:
%load_ext autoreload
%autoreload 2
import pandas as pd
from pl import FNN
Second cell
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch import nn
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import pytorch_lightning as pl
import numpy as np
import pandas as pd
import torch

def get_early_stopping_cb(minimize=True):
    return EarlyStopping(
        monitor='val_loss',
        mode='min' if minimize else "max",
        patience=10,
    )

def get_model_checkpoint_cb(minimize=True):
    return ModelCheckpoint(
        save_top_k=1,
        monitor='val_loss',
        mode='min' if minimize else "max",
        save_weights_only=True
    )

def train(model, train_loader, val_loader, minimize=True, max_epochs=100000):
    checkpoint_callback = get_model_checkpoint_cb(minimize)
    early_stopping = get_early_stopping_cb(minimize)
    trainer = pl.Trainer(logger=False, max_epochs=max_epochs, callbacks=[checkpoint_callback, early_stopping])
    trainer.fit(model, train_loader, val_loader)
    model.load_state_dict(torch.load(checkpoint_callback.best_model_path)["state_dict"])
    model.eval()
    return model

def to_torch(x):
    if isinstance(x, np.ndarray):
        return torch.Tensor(x)
    return x

def to_loader(x, y, *args, **kwargs):
    ds = TensorDataset(to_torch(x), to_torch(y))
    return DataLoader(ds, *args, **kwargs)
Third cell
import numpy as np
train_loader = to_loader(np.random.randn(1000, 10), np.random.randn(1000, 10))
val_loader = to_loader(np.random.randn(1000, 10), np.random.randn(1000, 10))
fnn = FNN(10)
train(fnn, train_loader, val_loader)
Now run all the cells, but interrupt the kernel in the middle of the training process. Then open pl.py, edit the learning rate in the Adam optimizer, and rerun the last cell. After that I get:
TypeError: super(type, obj): obj must be an instance or subtype of type
I could reproduce this with the following Notebook: https://colab.research.google.com/drive/1Nx6y1boVk7fj7CgImtEkX8ZiVye_Pytr?usp=sharing
There are several known issues with autoreload and classes when references to those classes already exist, for example https://github.com/ipython/ipython/issues/11588 and https://github.com/ipython/ipython/issues/12553 .
That being said, I could resolve the error by changing super(FNN, self).__init__ to super().__init__.
However, I still don't know why this happens but I'm pretty sure this is not on our side since we don't do anything special here (just normal inheritance).
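One plausible mechanism, sketched here without Jupyter or pytorch-lightning: `super(FNN, self)` looks up the name `FNN` in the module globals each time `__init__` runs, so if a reload rebinds that name to a new class object while an older class object is still referenced elsewhere, the old class's instances no longer match it. Zero-argument `super()` uses the class captured at definition time instead. The class names below are hypothetical, and the "reload" is simulated by simply redefining the classes; whether autoreload leaves a stale class in exactly this way is an assumption, not something confirmed in this thread.

```python
class Base:
    def __init__(self):
        self.ready = True


class ExplicitChild(Base):
    def __init__(self):
        # The name ExplicitChild is resolved in globals at call time.
        super(ExplicitChild, self).__init__()


class ImplicitChild(Base):
    def __init__(self):
        # Zero-argument super() uses the class this method was defined in.
        super().__init__()


# Keep references to the original class objects, as a notebook would
# after `from pl import FNN`.
old_explicit = ExplicitChild
old_implicit = ImplicitChild


# Simulated reload (assumption): the same names now point to new class objects.
class ExplicitChild(Base):
    def __init__(self):
        super(ExplicitChild, self).__init__()


class ImplicitChild(Base):
    def __init__(self):
        super().__init__()


# Instantiating the stale explicit class fails: its __init__ passes the
# NEW ExplicitChild to super(), but self is an instance of the OLD one.
try:
    old_explicit()
    explicit_error = None
except TypeError as err:
    explicit_error = str(err)

# The stale implicit class still works, because super() is bound to itself.
implicit_ok = old_implicit().ready

print("explicit super:", explicit_error)
print("implicit super ok:", implicit_ok)
```

If this is indeed what happens, it would be consistent with `super().__init__` avoiding the error while `super(FNN, self).__init__` triggers it.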
super().__init__ works for me too. I will try that in the future, thanks. Alright, I'll close this if it's not pytorch-lightning related.
🐛 Bug
When editing a file and then rerunning a cell in a Jupyter notebook using autoreload, I get TypeError: super(type, obj): obj must be an instance or subtype of type, either immediately or after a few reruns, for the pl.LightningModule class. It's very annoying and I've gotten it multiple times in different projects. I must restart the kernel or sometimes even restart the entire notebook to fix the issue.