Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0
1.07k stars, 60 forks

recursive error fix #626

Closed: k223kim closed this PR 5 days ago

k223kim commented 1 week ago
Before submitting
- [ ] Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
- [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/pytorch-lightning/blob/main/.github/CONTRIBUTING.md), Pull Request section?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

What does this PR do?

Fixes #461.

PR review

Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

k223kim commented 1 week ago

The issue is that HF's ModelOutput inherits from dict, so is_collection(x) always returns True for it, which results in infinite recursion within to_printable and prettyprint. For now, I have added the following condition to filter out such values of x and avoid the recursion error:

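        # Exact type check (not isinstance), so dict subclasses such as HF's
        # ModelOutput are returned unchanged instead of being recursed into
        if type(x) not in (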
            dict,
            list,
            str,
            int,
            bool,
            tuple,
            torch.dtype,
            float,
            dtypes.floating,
            devices.Device,
            torch.memory_format,
        ):
            return x
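To make the failure mode concrete, here is a minimal, self-contained sketch of how an isinstance-based collection check can recurse forever on a dict subclass, and how an exact-type guard like the one above stops it. Everything here is hypothetical: `FakeModelOutput`, `flatten`, `to_printable_buggy`, `to_printable_fixed` and `KNOWN_TYPES` are illustrative stand-ins, not Thunder's actual `to_printable` or `tree_flatten`. The key assumption is that the flattener does not recognize the subclass and hands it back as a single leaf.

```python
class FakeModelOutput(dict):
    """Hypothetical stand-in for HF's ModelOutput, which inherits from dict."""


def is_collection(x):
    # isinstance-based check: also True for dict subclasses
    return isinstance(x, (dict, list, tuple))


def flatten(x):
    # Hypothetical flattener that only recognizes the exact built-in container
    # types; anything else, including dict subclasses, comes back as one leaf.
    if type(x) is dict:
        return list(x.values())
    if type(x) in (list, tuple):
        return list(x)
    return [x]


def to_printable_buggy(x):
    if is_collection(x):
        # A FakeModelOutput flattens to [x] itself, so the recursion never
        # bottoms out and Python eventually raises RecursionError
        return [to_printable_buggy(leaf) for leaf in flatten(x)]
    return x


KNOWN_TYPES = (dict, list, tuple, str, int, float, bool)


def to_printable_fixed(x):
    # Exact type check (not isinstance): unknown types, including dict
    # subclasses, are returned unchanged before any flattening happens
    if type(x) not in KNOWN_TYPES:
        return x
    if is_collection(x):
        return [to_printable_fixed(leaf) for leaf in flatten(x)]
    return x


out = FakeModelOutput(logits=1.0)
try:
    to_printable_buggy(out)
except RecursionError:
    print("RecursionError")  # prints "RecursionError"
print(to_printable_fixed(out) is out)  # True
```

The guard runs before `is_collection`, so plain containers are still recursed into while anything the flattener cannot decompose is passed through untouched.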

I am uncertain regarding two points:

Regarding prettyprint, I think we should come up with a way to handle types like HF's ModelOutput. I tried to return m(str(x)) in such cases, but that resulted in the following error:

Exception has occurred: NameError
name 'BaseModelOutput' is not defined
  File "/Users/kaeunkim/lightning-thunder/thunder/__init__.py", line 600, in fn_
    result = cache_entry.computation_fn(*inps)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

That is why I am currently returning m(baseutils.print_base_printable(x)) in prettyprint, which does not work because x is not a base printable.
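One plausible reading of the NameError, sketched under assumptions: printing `str(x)` embeds a constructor-like repr such as `BaseModelOutput(...)` into the generated Python source of the computation function, and that source is later executed in a namespace where the class name is not bound. The sketch below does not use Thunder's real code generator; `FakeBaseModelOutput` and `build_computation_fn` are hypothetical names chosen only to illustrate the mechanism.

```python
class FakeBaseModelOutput:
    """Hypothetical stand-in for a ModelOutput subclass such as BaseModelOutput."""

    def __init__(self, last_hidden_state):
        self.last_hidden_state = last_hidden_state

    def __repr__(self):
        # Looks like a constructor call, which is what ends up in generated code
        return f"FakeBaseModelOutput(last_hidden_state={self.last_hidden_state!r})"

    def __eq__(self, other):
        return (isinstance(other, FakeBaseModelOutput)
                and self.last_hidden_state == other.last_hidden_state)


def build_computation_fn(value, namespace):
    # Hypothetical code generator: embeds str(value) directly in Python source
    source = f"def computation_fn():\n    return {str(value)}\n"
    exec(source, namespace)
    return namespace["computation_fn"]


value = FakeBaseModelOutput(last_hidden_state=1)

# The generated source compiles, but the class name is unresolved at call time
fn = build_computation_fn(value, {})
try:
    fn()
except NameError as e:
    print(e)  # name 'FakeBaseModelOutput' is not defined

# Supplying the class in the execution namespace makes the same source work
fn = build_computation_fn(value, {"FakeBaseModelOutput": FakeBaseModelOutput})
print(fn() == value)  # True
```

If this reading is right, printing such objects as source text only works when the generated program's globals also carry the corresponding class.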

k223kim commented 1 week ago

I have updated the PR so that the type check happens in tree_flatten. I am not sure whether the if condition can be written using only isinstance. (cc @t-vi, @riccardofelluga)
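On the isinstance question, a small sketch of the trade-off: `isinstance(x, dict)` is True for every dict subclass, so an isinstance-only check can reject a subclass only if that subclass is named explicitly, while an exact `type(x)` comparison rejects all of them without knowing their names. `FakeModelOutput` and `OtherDictSubclass` are hypothetical stand-ins, and the two helper functions are illustrative, not Thunder code.

```python
class FakeModelOutput(dict):
    """Hypothetical stand-in for a dict subclass such as HF's ModelOutput."""


class OtherDictSubclass(dict):
    """Another hypothetical dict subclass that nobody added to `excluded`."""


def is_plain_dict_by_type(x):
    # Exact type comparison: rejects every dict subclass automatically
    return type(x) is dict


def is_plain_dict_by_isinstance(x, excluded=(FakeModelOutput,)):
    # isinstance alone accepts subclasses, so each one to reject must be named
    return isinstance(x, dict) and not isinstance(x, excluded)


print(isinstance(FakeModelOutput(), dict))               # True
print(is_plain_dict_by_type(FakeModelOutput()))          # False
print(is_plain_dict_by_isinstance(FakeModelOutput()))    # False
print(is_plain_dict_by_isinstance(OtherDictSubclass()))  # True: slips through
```

The isinstance variant would also require importing the excluded types (for ModelOutput, that means depending on transformers), which is one reason the exact-type check can be the simpler option here.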

kshitij12345 commented 5 days ago

> Supporting dataclasses is high prio to us, so we would not want to block https://github.com/Lightning-AI/lightning-thunder/pull/632 if we decide that staying with the solution there is the way we want to go now. @kshitij12345 WDYT about not flattening the world here?

I think this fix is good and should not block #632. Thanks @k223kim @t-vi