learnables / learn2learn

A PyTorch Library for Meta-learning Research
http://learn2learn.net
MIT License

The lightning_maml method risks running out of memory #432

Open · lian-xiao opened 3 weeks ago

lian-xiao commented 3 weeks ago

In a PyTorch Lightning 2.2.1 environment, lightning_maml implements MAML by wrapping the meta_learn function in torch.enable_grad. As a result, during validation the outer loss is computed with gradient tracking enabled but is never backpropagated, and this leads to memory overflow. This can be avoided by wrapping the outer-loss computation of the validation procedure in torch.no_grad().
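The suggested fix could look roughly like the sketch below. It is written in plain PyTorch rather than against the learn2learn or Lightning APIs. It assumes that gradients must stay enabled for the inner-loop adaptation even at validation time, while the outer (query) loss is only evaluated and never backpropagated. The helper names `adapt` and `validation_step`, the cross-entropy loss, and the use of `torch.func.functional_call` (PyTorch 2.x) are illustrative assumptions, not the actual lightning_maml code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call


def adapt(model, params, x, y, inner_lr, steps, create_graph):
    """Hypothetical inner loop: a few SGD steps on the support set."""
    for _ in range(steps):
        loss = F.cross_entropy(functional_call(model, params, (x,)), y)
        grads = torch.autograd.grad(
            loss, list(params.values()), create_graph=create_graph
        )
        # Functional update: returns new tensors, model weights are untouched
        params = {name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)}
    return params


def validation_step(model, support, query, inner_lr=0.1, steps=1):
    """Hypothetical MAML validation step (assumed structure, not the library's)."""
    xs, ys = support
    xq, yq = query
    params = dict(model.named_parameters())

    # Inner loop still needs autograd, even if the caller disabled it
    with torch.enable_grad():
        adapted = adapt(model, params, xs, ys, inner_lr, steps, create_graph=False)

    # Drop the adapted weights' references to the inner-loop graph
    adapted = {name: p.detach() for name, p in adapted.items()}

    # Outer loss is never backpropagated during validation, so build no graph
    with torch.no_grad():
        logits = functional_call(model, adapted, (xq,))
        loss = F.cross_entropy(logits, yq)
        acc = (logits.argmax(dim=1) == yq).float().mean()
    return loss, acc
```

The design point is the split of the two contexts. `torch.enable_grad()` is scoped to the adaptation only, so it no longer covers the query-set evaluation. Because the returned loss has no `grad_fn`, it holds no reference to a computation graph if it is logged or accumulated across validation batches.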