Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Precision output conversion in Lite wrappers #14859

Open awaelchli opened 1 year ago

awaelchli commented 1 year ago

šŸš€ Feature

Discussion: Should the forward of LightningLite's _LiteModule wrapper convert the outputs back to default precision?

https://github.com/Lightning-AI/lightning/blob/a3258391416c8f96c05c044562f9ad7f6880733b/src/lightning_lite/wrappers.py#L115-L118
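The linked lines implement what option A) below calls the current behavior. A simplified sketch of that conversion (the helper names and the `lightning_utilities` import path are assumptions, not the exact source):

```python
import torch
from torch import Tensor
from lightning_utilities.core.apply_func import apply_to_collection


def _convert_float_tensor(t: Tensor) -> Tensor:
    # Only floating-point tensors are converted; int/bool tensors pass through.
    return t.to(torch.get_default_dtype()) if torch.is_floating_point(t) else t


def _convert_output_to_default_precision(output):
    """Walk the (possibly nested) output collection and cast every float
    tensor back to torch.get_default_dtype(), usually torch.float32."""
    return apply_to_collection(output, dtype=Tensor, function=_convert_float_tensor)
```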

Motivation

The original motivation was to have a precision-agnostic module wrapper, so that the user does not need to convert outputs when switching between precision backends. On the other hand, it takes control away from the user.

Pitch

A) Keep as is: convert the output to the default type.
B) Convert it back to the type the input had (here, input refers to the input of the wrapper, NOT the inner module).
C) Do nothing: return the output as it was returned by the inner forward.
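For concreteness, here is a sketch of where the three options would diverge in the wrapper's forward (`_forward_module`, `_cast_floats`, and `_infer_input_float_dtype` are illustrative names, not the actual implementation):

```python
import torch
from torch import Tensor
from lightning_utilities.core.apply_func import apply_to_collection


def _cast_floats(data, dtype: torch.dtype):
    # Cast only the floating-point tensors anywhere in a (nested) collection.
    return apply_to_collection(
        data,
        dtype=Tensor,
        function=lambda t: t.to(dtype) if torch.is_floating_point(t) else t,
    )


def forward(self, *args, **kwargs):  # sketch of the wrapper's forward
    with self._precision_plugin.forward_context():
        output = self._forward_module(*args, **kwargs)
    # A) current behavior: cast float outputs to the default dtype
    #    return _cast_floats(output, torch.get_default_dtype())
    # B) cast float outputs back to the dtype of the wrapper's inputs
    #    return _cast_floats(output, _infer_input_float_dtype(args, kwargs))
    # C) do nothing: hand back exactly what the inner forward returned
    return output
```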

Additional context

Raised in https://github.com/Lightning-AI/lightning/pull/14792#discussion_r977698959


Atharva-Phatak commented 1 year ago

@awaelchli I can take this up if it's still open?

awaelchli commented 1 year ago

@Atharva-Phatak Yes, it's in the discussion phase. There are three options listed. B) is the most involved option, since the dtype would have to be inferred from the input data. It was proposed by @justusschock in the linked discussion. I think it is possible; however, it requires validation and error messaging when the input contains several dtypes.

Option C is the one that leaves the user with the most control, as they can decide what happens. It's also closest to the default torch experience.
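A minimal sketch of the dtype inference and error messaging that B) would require (all names here are hypothetical):

```python
import torch
from torch import Tensor
from lightning_utilities.core.apply_func import apply_to_collection


def _infer_input_float_dtype(args, kwargs) -> torch.dtype:
    """Collect the dtypes of all floating-point tensors in the wrapper's
    inputs; raise if they disagree, since the target dtype would be ambiguous."""
    dtypes = set()
    apply_to_collection(
        (args, kwargs),
        dtype=Tensor,
        function=lambda t: dtypes.add(t.dtype) if torch.is_floating_point(t) else t,
    )
    if len(dtypes) > 1:
        raise ValueError(f"Cannot infer an output dtype: inputs mix float dtypes {dtypes}")
    return dtypes.pop() if dtypes else torch.get_default_dtype()
```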

Atharva-Phatak commented 1 year ago

What if we let the user define an output_transform_fn, which would be applied to the output in the forward pass? Let the user handle it, maybe?

By default it could be output_transform_fn = lambda x: x

Don't you think this would be more customizable?

CC: @carmocca @awaelchli
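For reference, the proposal would look roughly like this (a hypothetical API sketch; it is rejected in the next comment):

```python
import torch


class _LiteModule(torch.nn.Module):  # simplified sketch, not the real class
    def __init__(self, module, precision_plugin, output_transform_fn=lambda x: x):
        super().__init__()
        self._module = module
        self._precision_plugin = precision_plugin
        # Identity by default, so any output conversion is the user's choice.
        self._output_transform_fn = output_transform_fn

    def forward(self, *args, **kwargs):
        with self._precision_plugin.forward_context():
            output = self._module(*args, **kwargs)
        return self._output_transform_fn(output)
```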

justusschock commented 1 year ago

@Atharva-Phatak I wouldn't do that. Passing lambdas interacts badly with pickling, and I am not a fan of passing functions for this in general. It is too close to the Trainer approach we want to break away from with Lite.

So far the approach has been: if you want to customize something, you subclass and override it. The option proposed by @awaelchli is to do nothing, so that users can handle it on their own afterwards. They wouldn't need to pass a function at all; they could simply apply the conversion to the output afterwards, as in the sketch below.
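Under that option, no hook is needed at all; the user converts where they see fit (a sketch, with `lite.setup` standing in for however the module was wrapped):

```python
import torch


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2


# With option C) the wrapper returns whatever the inner forward produced,
# so the user simply converts the output themselves afterwards:
model = TinyModel()  # in practice, the module returned by lite.setup(...)
output = model(torch.randn(4, 8, dtype=torch.float16))
output = output.float()  # explicit, user-controlled cast back to float32
```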

Atharva-Phatak commented 1 year ago

@justusschock Thank you for letting me know :). I did not know it had issues with pickling :(. So which of the proposed options should be the final approach?

justusschock commented 1 year ago

> I did not know it had issues with pickling :(.

That actually depends on which strategy you use. It could, for example, be an issue with DDP spawn, as lambdas in general cannot be pickled.

Regarding the final approach: that is yet to be decided. I personally tend towards option B).

What do you think, @awaelchli @carmocca?

awaelchli commented 1 year ago

I think the lambda is not needed, since in Lite the user is in full control of the loop and the forward pass; they can already add a dtype conversion anywhere they want. I tend towards A) or C), as these are the simpler approaches. For B), I'm thinking it could be confusing that the output dtype would silently follow the input dtype: the user has to "remember" that Lightning does this conversion in the background.

EDIT: I meant A+C, not A+B as I originally wrote. I am in favor of A) or C).

justusschock commented 1 year ago

If we do B), we need to do it consistently, i.e., store the dtype of every tensor in the input and afterwards, for every output tensor that was a floating-point tensor to begin with, convert it back to the original dtype.

That's what we have apply_to_collections for; see the sketch below.
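A sketch of that round trip (using `apply_to_collection` from `lightning_utilities`; the exact utility and import path at the time may have differed):

```python
import torch
from torch import Tensor
from lightning_utilities.core.apply_func import apply_to_collection


def _cast_floats_back(output, original_dtype: torch.dtype):
    """Cast every floating-point tensor in the (nested) output back to the
    original input dtype; everything else is returned untouched."""
    return apply_to_collection(
        output,
        dtype=Tensor,
        function=lambda t: t.to(original_dtype) if torch.is_floating_point(t) else t,
    )


# Example: a forward run in float16 returns half-precision tensors, which are
# cast back to the dtype the inputs originally had; the int tensor is kept.
out = {"logits": torch.randn(2, 3, dtype=torch.float16), "ids": torch.arange(2)}
restored = _cast_floats_back(out, torch.float32)
assert restored["logits"].dtype == torch.float32
assert restored["ids"].dtype == torch.int64
```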

carmocca commented 1 year ago

To me, C) is the simplest solution (do nothing!).

I don't think we have a good use case to believe any of them is superior, but C) is the simplest to implement. A) and B) have the problem of fighting against the framework: if the user doesn't want this conversion, we would need an opt-out or have to accept the inefficiency, whereas with C), the user can always choose to do the conversion later.

If in the future we see that this has become repeated boilerplate across our users, then we can re-think whether A) or B) is better.