learnables / learn2learn

A PyTorch Library for Meta-learning Research
http://learn2learn.net
MIT License

LightningPrototypicalNetworks example fails #346

Closed: farzam-khodajoo closed this issue 2 years ago

farzam-khodajoo commented 2 years ago

Hi there, why is this example not working? Why does cross_entropy fail?

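import learn2learn as l2l
import pytorch_lightning as pl
from learn2learn.algorithms import LightningPrototypicalNetworks
# EpisodicBatcher normally comes from learn2learn.utils.lightning; in the notebook its code
# was copy/pasted instead, because that import fails on newer pytorch_lightning (see below).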

tasksets = l2l.vision.benchmarks.get_tasksets('omniglot', root="D:/datasets/omniglot/")
features = l2l.vision.models.OmniglotCNN()
protonet = LightningPrototypicalNetworks(features)
episodic_data = EpisodicBatcher(tasksets.train, tasksets.validation, tasksets.test)
trainer = pl.Trainer()
trainer.fit(protonet, episodic_data)

Cell output:

in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    648         return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
    649     else:
--> 650         return trainer_fn(*args, **kwargs)
    651 # TODO(awaelchli): Unify both exceptions below, where `KeyboardError` doesn't re-raise
    652 except KeyboardInterrupt as exception:
...
   3012 if size_average is not None or reduce is not None:
   3013     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3014 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

IndexError: Target 1 is out of bounds.
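The IndexError comes from how cross-entropy uses its inputs: each target is treated as a column index into that example's row of logits, so every target must lie in [0, num_classes). The sketch below is a plain-Python stand-in for torch.nn.functional.cross_entropy (the helper name cross_entropy is hypothetical, and it only mimics the default mean reduction), written to show where a message like "Target 1 is out of bounds." arises: logits with a single column accept only target 0.

```python
import math


def cross_entropy(logits, targets):
    """Mean cross-entropy over a batch (hypothetical plain-Python helper).

    logits: list of rows, one per example, each holding num_classes scores.
    targets: list of integer class indices, one per example.
    """
    losses = []
    for row, t in zip(logits, targets):
        num_classes = len(row)
        # The target is used as a column index, so it must be in [0, num_classes).
        if not 0 <= t < num_classes:
            raise IndexError(f"Target {t} is out of bounds.")
        # Numerically stable log-sum-exp: subtract the row maximum first.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        # Per-example loss is -log softmax(row)[t].
        losses.append(log_sum_exp - row[t])
    return sum(losses) / len(losses)
```

With five query targets 0..4 but logits of width 1, cross_entropy([[0.0]] * 5, [0, 1, 2, 3, 4]) fails on the second target with the same message as the traceback above. In other words, the number of logit columns the model produces must be at least one more than the largest target label.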
farzam-khodajoo commented 2 years ago

@seba-1511 Hi, sorry for the direct mention, but I'm still struggling to get the Lightning module to work.

seba-1511 commented 2 years ago

Hello @arshamkhodajoo,

Could you put this in a Colab? A priori this should work, but it looks like either the support or the query targets have indices that are too large.
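To make that diagnosis concrete: a prototypical network builds one prototype per class by averaging that class's support embeddings, then scores each query by the negative squared Euclidean distance to every prototype. The logits therefore have exactly `ways` columns, and support and query labels must be relabeled to 0..ways-1 before the loss is computed. Below is a minimal plain-Python sketch, assuming the embeddings are already computed; remap_labels and prototype_logits are hypothetical helpers, not learn2learn's API.

```python
def remap_labels(labels):
    """Relabel arbitrary class ids to 0..ways-1 in order of first appearance."""
    mapping = {}
    for y in labels:
        if y not in mapping:
            mapping[y] = len(mapping)
    return [mapping[y] for y in labels], mapping


def prototype_logits(support, support_labels, queries, ways):
    """Return a len(queries) x ways list of negative squared distances.

    support: list of embedding vectors (lists of floats).
    support_labels: one int per support vector, expected in [0, ways).
    queries: list of embedding vectors with the same dimension as support.
    """
    dim = len(support[0])
    sums = [[0.0] * dim for _ in range(ways)]
    counts = [0] * ways
    for x, y in zip(support, support_labels):
        # A label >= ways is exactly the "too large index" problem.
        if not 0 <= y < ways:
            raise IndexError(f"support label {y} not in [0, {ways})")
        counts[y] += 1
        for i, v in enumerate(x):
            sums[y][i] += v
    if 0 in counts:
        raise ValueError("every class needs at least one support example")
    # Prototype = mean embedding of each class's support examples.
    protos = [[s / counts[c] for s in sums[c]] for c in range(ways)]
    # Logit for (query, class) = -||query - prototype||^2.
    return [
        [-sum((qi - pi) ** 2 for qi, pi in zip(q, p)) for p in protos]
        for q in queries
    ]
```

If the tasks contain more classes than the `ways` the model was configured with, the labels fall outside [0, ways) and training fails on an out-of-range index, which is the kind of mismatch the error above points to.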

farzam-khodajoo commented 2 years ago

@seba-1511 Thanks for your response. Link to Colab: https://colab.research.google.com/drive/15nJFuO0c2f_bGQvrG0XT14ZjaPD6-RSW?usp=sharing

I had to copy/paste the code for EpisodicBatcher because importing it from learn2learn.utils.lightning fails: [screenshot]

It turns out this line throws the error because pl.callbacks.ProgressBar no longer exists: https://github.com/learnables/learn2learn/blob/f099ddc9ce0c10cff901ecb1acee2838d171272e/learn2learn/utils/lightning.py#L70
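For the import failure: newer PyTorch Lightning releases expose the TQDM-based progress bar as pl.callbacks.TQDMProgressBar, and the old pl.callbacks.ProgressBar name is gone. One way to stay compatible with both older and newer releases is to look the class up by name instead of referencing pl.callbacks.ProgressBar directly. This is a hedged sketch: resolve_progress_bar is a hypothetical helper, and it has not been checked against every Lightning version.

```python
def resolve_progress_bar(callbacks):
    """Return the TQDM progress-bar base class from a callbacks namespace.

    Prefers the newer name TQDMProgressBar and falls back to the older
    ProgressBar, so a subclass can be declared on either API.
    """
    for name in ("TQDMProgressBar", "ProgressBar"):
        cls = getattr(callbacks, name, None)
        if cls is not None:
            return cls
    raise ImportError("no TQDM progress bar class found in callbacks")


# Usage (requires pytorch_lightning; MyProgressBar is an illustrative name):
# import pytorch_lightning as pl
# ProgressBarBase = resolve_progress_bar(pl.callbacks)
# class MyProgressBar(ProgressBarBase):
#     ...
```

Checking TQDMProgressBar first means the newer name wins whenever a release ships both.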