awaelchli opened 1 year ago
@awaelchli I can take this up if it's still open?
@Atharva-Phatak Yes, it's in the discussion phase. There are three options listed. B) is the most involved option, where the dtype would have to be inferred from the input data. It was proposed by @justusschock in the linked issue. I think it is possible; however, it requires validation and error messaging when the input contains several dtypes.
Option C is the one that leaves the user with the most control, as they can decide what happens. It's also closest to the default torch experience.
What if we let the user define an output_transform_fn, which would be applied to the output tensor in the forward pass? Let the user handle it, maybe? By default it could be output_transform_fn = lambda x: x.
Don't you think this would be more customizable?
CC: @carmocca @awaelchli
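(For illustration only: a minimal sketch of what this hypothetical output_transform_fn hook could look like on the wrapper. The names and placement are assumptions from this proposal, not an actual Lightning API.)

```python
import torch
from torch import nn


class WrappedModule(nn.Module):
    """Hypothetical wrapper following the output_transform_fn proposal above."""

    def __init__(self, module: nn.Module, output_transform_fn=None):
        super().__init__()
        self.module = module
        # identity by default: outputs pass through untouched
        self.output_transform_fn = output_transform_fn or (lambda x: x)

    def forward(self, *args, **kwargs):
        output = self.module(*args, **kwargs)
        # the user decides what (if anything) happens to the output
        return self.output_transform_fn(output)


# e.g. a user who always wants float32 outputs, regardless of the model's dtype:
model = WrappedModule(nn.Linear(4, 2).to(torch.bfloat16), output_transform_fn=lambda t: t.float())
out = model(torch.randn(3, 4, dtype=torch.bfloat16))
assert out.dtype == torch.float32
```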
@Atharva-Phatak I wouldn't do that. Passing lambdas is problematic with pickling, and I am not a fan of passing functions for this in general. It is too much of the Trainer approach we want to break out of with Lite.
So far the approach has been: if you want to customize something, you subclass and override it. The option proposed by @awaelchli is to not do anything, so that users can handle it on their own afterwards. They wouldn't need to pass a function at all; they could just transform the output themselves.
@justusschock Thank you for letting me know :). I did not know it had issues with pickling :(. So what should be the final approach from the ones proposed?
> I did not know it had issues with pickling :(
That actually depends on which strategy you use. It could, for example, be an issue with DDP spawn, as lambdas in general cannot be pickled.
Regarding the final approach: that is yet to be decided. I personally tend towards option B).
What do you think, @awaelchli @carmocca?
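(Side note for context: the pickling limitation is easy to reproduce in plain Python. This snippet is illustrative, not from the thread.)

```python
import pickle

fn = lambda x: x

try:
    pickle.dumps(fn)
except (pickle.PicklingError, AttributeError) as err:
    # lambdas have no importable qualified name, so pickle cannot serialize them;
    # DDP spawn pickles the objects it sends to worker processes, hence the failure
    print(f"cannot pickle lambda: {err}")
```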
I think the lambda is not needed, as in Lite the user is in full control of the loop and forward pass. A user can already add a dtype conversion anywhere they want. I tend towards A) or C), as these are the simpler approaches. For B), I'm thinking it could be confusing that the output dtype would depend on the input dtype. This inconsistency could be confusing, as the user has to "remember" that Lightning does this in the background.
EDIT: I meant A+C, not A+B as I originally wrote. I am in favor of A or C.
If we do B), we need to do it consistently. I.e., for every tensor in the input, store the dtype, and afterwards convert every tensor that was a floating-point tensor in the beginning back to its original dtype.
That's what we have apply_to_collection for.
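(A minimal sketch of how that consistent option B could look, assuming the apply_to_collection utility as importable from lightning_utilities.core.apply_func in current versions. This is an illustration of the idea, not the actual implementation.)

```python
from torch import Tensor
from lightning_utilities.core.apply_func import apply_to_collection


def convert_back(inputs, output):
    """Option B, done consistently: restore float outputs to the float dtype of the inputs."""
    dtypes = set()

    def record(t: Tensor) -> Tensor:
        if t.is_floating_point():
            dtypes.add(t.dtype)
        return t

    # walk the (possibly nested) input collection and collect all float dtypes
    apply_to_collection(inputs, dtype=Tensor, function=record)

    if not dtypes:
        return output  # no floating-point inputs -> nothing to restore
    if len(dtypes) > 1:
        # the validation / error-messaging cost mentioned earlier in the thread
        raise ValueError(f"Inputs mix several floating-point dtypes: {dtypes}")
    (target,) = dtypes

    def restore(t: Tensor) -> Tensor:
        return t.to(target) if t.is_floating_point() else t

    return apply_to_collection(output, dtype=Tensor, function=restore)
```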
To me, C) is the simplest solution (do nothing!).
I don't think we have a good use case to believe any of them is superior, but C) is the simplest to implement. A) and B) have the problem of fighting against the framework if the user doesn't want this conversion: we would need an opt-out or have to accept the inefficiency, whereas with C) the user can always choose to do the conversion later.
If in the future we see that this has become repeated boilerplate across our users, then we can re-think whether A) or B) is better.
🚀 Feature
Discussion: Should the forward of LightningLite's _LiteModule wrapper convert the outputs back to default precision?
https://github.com/Lightning-AI/lightning/blob/a3258391416c8f96c05c044562f9ad7f6880733b/src/lightning_lite/wrappers.py#L115-L118
Motivation
The original motivation was to have a precision-agnostic module wrapper, so that the user does not need to convert outputs when switching between precision backends. On the other hand, it takes control away from the user.
Pitch
A) Keep as is (convert the output to the default type)
B) Convert it back to the type the input had (here, input refers to the input of the wrapper, NOT the inner module)
C) Do nothing. Return the output as it was returned by the inner forward
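(To make the options concrete, here is an illustrative sketch of where each one would act in the wrapper's forward. This is a paraphrase for discussion, not the actual _LiteModule source.)

```python
import torch
from torch import Tensor
from lightning_utilities.core.apply_func import apply_to_collection


def wrapper_forward(inner_module, *args, option="A", **kwargs):
    output = inner_module(*args, **kwargs)

    if option == "A":
        # keep as is: cast floating-point outputs to the default dtype (usually float32)
        def to_default(t: Tensor) -> Tensor:
            return t.to(torch.get_default_dtype()) if t.is_floating_point() else t

        return apply_to_collection(output, dtype=Tensor, function=to_default)

    if option == "B":
        # cast back to the float dtype of the *wrapper* inputs (single-dtype case for brevity)
        input_dtype = next(t.dtype for t in args if isinstance(t, Tensor) and t.is_floating_point())

        def to_input(t: Tensor) -> Tensor:
            return t.to(input_dtype) if t.is_floating_point() else t

        return apply_to_collection(output, dtype=Tensor, function=to_input)

    # option "C": do nothing, return the output exactly as the inner forward produced it
    return output
```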
Additional context
Raised in https://github.com/Lightning-AI/lightning/pull/14792#discussion_r977698959