Lightning-Universe / lightning-flash

Your PyTorch AI Factory - Flash enables you to easily configure and run complex AI recipes for over 15 tasks across 7 data domains
https://lightning-flash.readthedocs.io
Apache License 2.0

Models trained with the meta-learning interface do not report test metrics other than accuracy #1379

Open castelojb opened 2 years ago

castelojb commented 2 years ago

🐛 Bug

First of all, congratulations on the high-level interface built on learn2learn. The bug: when a model is trained with the meta-learning strategy and then passed to trainer.test, it ignores the extra metrics added to test_metrics and reports only accuracy and loss. Also, calling the model directly, model(x), does not return predictions as usual; it returns None.
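Both symptoms can be illustrated with a toy model in plain Python. This is a hypothetical sketch, not Flash's or learn2learn's code: `StandardTaskSketch` stands in for a regular classifier whose `test_step` evaluates every entry in `test_metrics`, and `MetaAdapterSketch` stands in for a meta-learning adapter whose `test_step` hard-codes accuracy and whose `forward` has no `return` statement. That combination is one plausible way to get exactly the behavior reported here; the class names, the `error_rate` metric and the toy "x mod 2" predictor are all invented for illustration.

```python
from typing import Callable, Dict, List, Optional, Sequence

# A metric maps (predictions, targets) to a float score.
Metric = Callable[[Sequence[int], Sequence[int]], float]


def accuracy(preds: Sequence[int], targets: Sequence[int]) -> float:
    """Fraction of predictions that match the targets."""
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


def error_rate(preds: Sequence[int], targets: Sequence[int]) -> float:
    """A user-added extra metric (hypothetical): 1 - accuracy."""
    return 1.0 - accuracy(preds, targets)


class StandardTaskSketch:
    """Hypothetical regular classifier: test_step evaluates every registered metric."""

    def __init__(self) -> None:
        self.test_metrics: Dict[str, Metric] = {"accuracy": accuracy}

    def forward(self, xs: List[int]) -> Optional[List[int]]:
        # Toy "model": the predicted class is x mod 2.
        return [x % 2 for x in xs]

    def test_step(self, xs: List[int], ys: List[int]) -> Dict[str, float]:
        preds = self.forward(xs)
        # Every entry in test_metrics is computed and logged with a "test_" prefix.
        return {f"test_{name}": fn(preds, ys) for name, fn in self.test_metrics.items()}


class MetaAdapterSketch(StandardTaskSketch):
    """Hypothetical meta-learning adapter reproducing the two reported symptoms."""

    def forward(self, xs: List[int]) -> Optional[List[int]]:
        # Prediction only happens inside test_step; this method falls off the
        # end without a return statement, so Python returns None.
        self.last_inputs = xs

    def test_step(self, xs: List[int], ys: List[int]) -> Dict[str, float]:
        preds = [x % 2 for x in xs]
        # Accuracy is hard-coded; self.test_metrics is never consulted.
        return {"test_accuracy": accuracy(preds, ys)}


for task in (StandardTaskSketch(), MetaAdapterSketch()):
    # Mirrors model_trn.test_metrics['F1-Score'] = ... in the repro.
    task.test_metrics["error_rate"] = error_rate
    print(type(task).__name__, task.test_step([1, 2, 3, 4], [1, 0, 1, 1]), task.forward([1, 2]))
```

In this toy, the standard task reports both `test_accuracy` and `test_error_rate` and returns predictions from `forward`, while the adapter reports only `test_accuracy` and returns `None`, matching the two outputs shown in the repro.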

To Reproduce

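import torch
import flash
from flash.image import ImageClassificationData, ImageClassifier
from pytorch_lightning import seed_everything

# `train`, `validate` and `test` are the reporter's DataFrames (not shown)
seed_everything(42)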

datamodule = ImageClassificationData.from_data_frame(
    "path",
    "class",
    train_data_frame=train,
    val_data_frame=validate,
    test_data_frame=test,
    transform_kwargs=dict(image_size=(128, 128)),
    batch_size=2,
)

model = ImageClassifier(
    backbone="resnet18",
    training_strategy="maml",
    pretrained=False,
    training_strategy_kwargs={
        "epoch_length": 50,
        "meta_batch_size": 2,
        "num_tasks": 50,
        "test_num_tasks": 50,
        "ways": datamodule.num_classes,
        "shots": 2,
        "test_ways": 2,
        "test_shots": 1,
        # "test_queries": 15,
    },
    optimizer=torch.optim.Adam,
    learning_rate=0.001,
)

trainer = flash.Trainer(
    max_epochs=50,
    precision=16,
    accelerator="ddp_shared",
    gpus=int(torch.cuda.is_available()),
)

trainer.fit(model, datamodule=datamodule)

# Save the model
trainer.save_checkpoint(path)

# Reload the model saved above
model_trn = ImageClassifier.load_from_checkpoint(path)

from torchmetrics import F1Score, Precision, Recall, Accuracy

model_trn.test_metrics['F1-Score'] = F1Score(6, average='macro')

model_trn.test_metrics['Precision'] = Precision(6, average='macro')

model_trn.test_metrics['Recall'] = Recall(6, average='macro')

trainer.test(model_trn, dataloaders=datamodule.test_dataloader())

>>> [{'test_accuracy': 0.8799999952316284, 'test_loss': 0.280563086271286}]

test_loader = datamodule.test_dataloader()

data_iter = iter(test_loader)
sample_ = next(data_iter)
input = sample_['input']

model_trn(input)
>>> None
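The `training_strategy_kwargs` in the repro configure episodic (N-way, K-shot) sampling: training tasks draw `ways` classes with `shots` examples each, and test tasks draw `test_ways` classes with `test_shots` examples each, plus query examples (`test_queries`, commented out above). The following is a hypothetical, simplified sampler in plain Python, assuming the common convention (used, for example, by learn2learn's `NWays`, `KShots` and `RemapLabels` task transforms) that labels are remapped to `0..ways-1` inside each episode. If the MAML model is evaluated this way, each test episode only involves 2 remapped labels, so a metric configured for 6 classes such as `F1Score(6, ...)` may not map cleanly onto it; this is an assumption, not a confirmed diagnosis. `sample_episode` and the file names are invented for illustration.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple

# (hypothetical image path, original class label)
Sample = Tuple[str, int]


def sample_episode(
    data: List[Sample], ways: int, shots: int, queries: int, rng: random.Random
) -> Tuple[List[Sample], List[Sample]]:
    """Draw one N-way K-shot episode and remap its labels to 0..ways-1."""
    by_class: Dict[int, List[str]] = defaultdict(list)
    for path, label in data:
        by_class[label].append(path)

    # Only classes with enough images for both the support and query sets qualify.
    eligible = sorted(c for c, paths in by_class.items() if len(paths) >= shots + queries)
    if len(eligible) < ways:
        raise ValueError(f"need {ways} classes with >= {shots + queries} samples each")

    support: List[Sample] = []
    query: List[Sample] = []
    for new_label, cls in enumerate(rng.sample(eligible, ways)):
        paths = rng.sample(by_class[cls], shots + queries)
        # The original class id is discarded; the model only ever sees new_label.
        support += [(p, new_label) for p in paths[:shots]]
        query += [(p, new_label) for p in paths[shots:]]
    return support, query


# Six classes with five images each, sampled like test_ways=2, test_shots=1.
data = [(f"img_{c}_{i}.png", c) for c in range(6) for i in range(5)]
support, query = sample_episode(data, ways=2, shots=1, queries=1, rng=random.Random(42))
print(len(support), len(query), sorted({label for _, label in support + query}))
# prints: 2 2 [0, 1]
```

Whichever 2 of the 6 original classes are drawn, the episode's labels are always 0 and 1, which is why per-episode predictions carry no information about the original 6-class label space.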

Expected behavior

I ran the same test procedure on a resnet18 model trained without meta-learning, and it behaved as expected: all of the added metrics are reported, and calling the model directly returns a (2, 6) output.

model_trn = ImageClassifier.load_from_checkpoint(path)

from torchmetrics import F1Score, Precision, Recall

model_trn.test_metrics['F1-Score'] = F1Score(6, average='macro')

model_trn.test_metrics['Precision'] = Precision(6, average='macro')

model_trn.test_metrics['Recall'] = Recall(6, average='macro')

trainer.test(model_trn, dataloaders=datamodule.test_dataloader())

>>> [{'test_F1-Score': 0.5142857432365417,
  'test_Precision': 0.4933333694934845,
  'test_Recall': 0.5800000429153442,
  'test_accuracy': 0.5882353186607361,
  'test_cross_entropy': 1.0256577730178833}]

test_loader = datamodule.test_dataloader()

data_iter = iter(test_loader)
sample_ = next(data_iter)
input = sample_['input']

model_trn(input).shape
>>> (2, 6)
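With `average='macro'`, each metric is computed per class and then averaged with equal weight across classes. The following plain-Python sketch (not torchmetrics' implementation) shows that computation; `macro_scores` is a hypothetical helper, and it scores an undefined 0/0 ratio as 0, a case that libraries handle in different ways. Because macro F1 averages the per-class F1 values, it generally differs from the harmonic mean of macro precision and macro recall, which is consistent with the reported 0.514 versus the roughly 0.533 that 0.493 and 0.580 would give.

```python
from typing import Dict, List, Sequence


def macro_scores(preds: Sequence[int], targets: Sequence[int], num_classes: int) -> Dict[str, float]:
    """Macro precision/recall/F1: score each class, then take the unweighted mean."""
    precisions: List[float] = []
    recalls: List[float] = []
    f1s: List[float] = []
    pairs = list(zip(preds, targets))
    for c in range(num_classes):
        tp = sum(p == c and t == c for p, t in pairs)
        fp = sum(p == c and t != c for p, t in pairs)
        fn = sum(p != c and t == c for p, t in pairs)
        # Convention used here: an undefined ratio (0/0) counts as 0.
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    return {
        "precision": sum(precisions) / num_classes,
        "recall": sum(recalls) / num_classes,
        "f1": sum(f1s) / num_classes,  # mean of per-class F1, not F1 of the means
    }


scores = macro_scores(preds=[0, 0, 1, 2, 2, 2], targets=[0, 1, 1, 2, 2, 0], num_classes=3)
print({name: round(value, 3) for name, value in scores.items()})
# prints: {'precision': 0.722, 'recall': 0.667, 'f1': 0.656}
```

Here the harmonic mean of 0.722 and 0.667 would be about 0.693, not 0.656, illustrating the same gap seen in the expected output above.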

Environment

Everything was tested on Google Colab with a GPU. I used the following commands to install the libraries:

!pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git'
!pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[image]'

For reasons I don't understand, importing and using the "maml" training strategy only works after also running these installations:

!pip install learn2learn
!pip install Pillow==7.1.2

Borda commented 1 year ago

@castelojb would you be interested in helping us debug/extend this case? :otter:

castelojb commented 1 year ago

Sounds good to me! But how can I help with this?

Borda commented 1 year ago

Let's ping/pair you with @krshrimali or @ethanwharris to give you some insight... :)