In a pytorch-lightning 2.2.1 environment, MAML is typically implemented by wrapping the whole meta_learn function in torch.enable_grad(), so that the inner-loop adaptation can still compute gradients under Lightning's validation context. The side effect is that gradient tracking also stays active for the outer loss during validation: each validation batch builds and retains an autograd graph, and memory eventually overflows. This can be avoided by wrapping the outer-loss computation of the validation procedure in torch.no_grad().
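A minimal sketch of the pattern described above, using plain torch rather than the actual Lightning/MAML code (validation_step, w_fast, and the linear model here are illustrative names, not the original implementation): the inner adaptation runs under an explicit torch.enable_grad(), while the surrounding torch.no_grad() ensures the outer (query) loss carries no graph.

```python
import torch
import torch.nn.functional as F

def validation_step(w, support, query, lr=0.1):
    """Illustrative MAML-style validation step (not Lightning's API)."""
    xs, ys = support
    xq, yq = query
    # Outer context: no_grad, so the outer (query) loss builds no
    # autograd graph and memory is freed after each validation batch.
    with torch.no_grad():
        # The inner adaptation still needs gradients, so re-enable
        # them locally for the inner loop only.
        with torch.enable_grad():
            w_fast = w.clone().requires_grad_(True)
            inner_loss = F.mse_loss(xs @ w_fast, ys)
            (grad,) = torch.autograd.grad(inner_loss, w_fast)
            w_fast = w_fast - lr * grad  # one inner SGD step
        # Computed under no_grad: no graph is retained here.
        outer_loss = F.mse_loss(xq @ w_fast, yq)
    return outer_loss

torch.manual_seed(0)
w = torch.zeros(3, 1)
xs, xq = torch.randn(8, 3), torch.randn(8, 3)
ys, yq = xs.sum(dim=1, keepdim=True), xq.sum(dim=1, keepdim=True)
loss = validation_step(w, (xs, ys), (xq, yq))
print(loss.requires_grad)  # False: the outer loss keeps no graph
```

If only the inner loop is wrapped in enable_grad() this way, the outer loss is a plain detached tensor during validation, so graphs do not accumulate across batches.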