Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Guarantee call order for callbacks #10260

Open z-a-f opened 2 years ago

z-a-f commented 2 years ago

🚀 Feature

It is often beneficial to have an explicit guarantee that the callbacks will be called in the order in which they were added to the callback list. If I understand it right, the order is currently preserved, but that is more of an implementation detail than a hard rule.

Motivation and Pitch

Our use case is quantization + sparsity. During the initial setup we prepare the model for quantization by adding quantization configurations to it. We also prepare the modules for sparsification by adding parametrizations. However, once training is complete, the sparse masks have to be merged into the weights BEFORE quantization. Otherwise, quantization tries to quantize the non-masked tensors.

Another possible use case is letting the user decide on the order of quantization and masking. If the user at some point decides to try "quantize -> sparsify" (instead of the currently implemented "sparsify -> quantize"), we would want them to be able to do that.
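As a rough sketch of what we would like to rely on (the callbacks here are placeholders for our real sparsity / quantization logic, not code we are proposing):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback

# Placeholder callbacks standing in for the real sparsity / quantization logic.
class SparsityCallback(Callback):
    def on_fit_end(self, trainer, pl_module):
        print("merge the sparse masks into the weights")

class QuantizationCallback(Callback):
    def on_fit_end(self, trainer, pl_module):
        print("convert the prepared model to a quantized model")

# Currently implemented flow, "sparsify -> quantize": the masks are merged before
# quantization runs -- this only works if the list order is honoured.
trainer = Trainer(callbacks=[SparsityCallback(), QuantizationCallback()])

# The alternative a user might ask for, "quantize -> sparsify":
trainer = Trainer(callbacks=[QuantizationCallback(), SparsityCallback()])
```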



carmocca commented 2 years ago

We do guarantee this already, with the only exception of the ModelCheckpoint callback which gets moved to last.
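Roughly, the effective order looks like this (an illustration of the behaviour, not the actual internals):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

def effective_order(callbacks):
    # The user-provided order is kept, except checkpoint callbacks move to the end.
    others = [cb for cb in callbacks if not isinstance(cb, ModelCheckpoint)]
    checkpoints = [cb for cb in callbacks if isinstance(cb, ModelCheckpoint)]
    return others + checkpoints
```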

Although we recommend not relying on it if possible.

Are you asking this after reading some piece of docs that says otherwise?

z-a-f commented 2 years ago

No, I chatted with @kandluis and he mentioned that the order is preserved, but not guaranteed. If I understand it right, the order is an artifact of the current implementation. However, applications such as quantization combined with sparsity require a guarantee that the order is preserved -- the functionality and the results will differ depending on the order. Without an explicit guarantee that the order is preserved, we won't be able to guarantee deterministic behavior.

qq -- is there a reason why there is a recommendation not to rely on the callback order?

tchaton commented 2 years ago

Dear @z-a-f,

Lightning executes the callbacks in the order they were provided. If you provide them in the same way on reload, the behaviour should be deterministic.

However, we are currently discussing how reloading should work, since quantization / sparsity preparation needs to happen before / after the model is restored.

@Borda

Best, T.C

z-a-f commented 2 years ago

However, we are currently discussing how reloading should work, since quantization / sparsity preparation needs to happen before / after the model is restored.

In that case, it would be worth thinking about the order in which quantization and sparsification are called as well. We can see use cases for either "quantize-after-sparsifying" or "sparsify-after-quantizing". Both cases are valid, and (imho) should be deterministically controlled by the user.

Will be happy to discuss it in more detail, and/or provide callbacks for sparse/quant

kandluis commented 2 years ago

@tchaton , if we want to explicitly guarantee ordering (aside from ModelCheckpoint), which IMO seems reasonable, we should update the best practices in the documentation here: https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html#best-practices

Specifically:

Your callback should not rely on the behavior of other callbacks in order to work properly. Whenever possible, your callbacks should not depend on the order in which they are executed.

It seems fine to me to guarantee that callbacks will be called in the order the user has provided them. A few questions that come to mind:

(1) Do we need to guarantee this across workers/nodes as well, or is a within-process order guarantee sufficient? (2) Do we lose anything useful by keeping this guarantee (e.g., would Lightning ever find it useful to re-order callbacks)?

tchaton commented 2 years ago

Hey @kandluis,

(1): The workers / nodes don't have an impact on the Trainer arguments (at least for the vast majority of users), so a within-process guarantee is sufficient. (2): Lightning won't re-order the callbacks. That would be a breach of the users' trust.

Note: Lightning already tracks the state of the callbacks for restoration. But we could also store the callback order and raise a warning if the callbacks are not provided back in the same order. What are your thoughts @awaelchli ?
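Something along these lines (a rough sketch of the proposed check, not existing Lightning code):

```python
import warnings

def check_callback_order(saved_order, current_callbacks):
    # `saved_order` would come from the callback state stored in the checkpoint.
    current_order = [type(cb).__name__ for cb in current_callbacks]
    if current_order != saved_order:
        warnings.warn(
            "Callbacks were provided in a different order than at save time: "
            f"{saved_order} -> {current_order}. Order-dependent callbacks "
            "(e.g. sparsity + quantization) may behave differently."
        )
```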

Best, T.C

kargarisaac commented 1 year ago

We do guarantee this already, with the only exception of the ModelCheckpoint callback which gets moved to last.

Although we recommend not relying on it if possible.

Are you asking this after reading some piece of docs that says otherwise?

I have a problem with logging the model using the Lightning MLflow logger. I want to log the saved model using mlflow.log_artifact() after ModelCheckpoint has logged the best model in each epoch, but I cannot find any method which will be executed after ModelCheckpoint. What do you recommend? I use pytorch-lightning==1.7.7.

Separately, I want to log the model (or maybe convert it to ONNX or TorchScript), log it in MLflow, and load it from my endpoint, without copying the model definition or anything else from my training code. Do you also have any recommendation for that?
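For the checkpoint-logging part, this is roughly what I am trying to do (rough sketch, not working code; LogCheckpointToMLflow is just a name I made up):

```python
import mlflow
from pytorch_lightning.callbacks import Callback

class LogCheckpointToMLflow(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # I would like this to run *after* ModelCheckpoint has saved the best
        # model for the epoch, but since ModelCheckpoint is moved to last, this
        # hook fires before the checkpoint for the current epoch is written.
        checkpoint_cb = trainer.checkpoint_callback
        if checkpoint_cb is not None and checkpoint_cb.best_model_path:
            mlflow.log_artifact(checkpoint_cb.best_model_path)
```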

salvaRC commented 1 year ago

I have the same issue as @kargarisaac . Is there any way to run a callback after ModelCheckpoint in order to upload those checkpoints to the cloud?

CobSammich commented 3 months ago

I have the same issue as @kargarisaac . Is there any way to run a callback after ModelCheckpoint in order to upload those checkpoints to the cloud?

Did you ever find a way to do this? @kargarisaac