Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

A reference cycle is detected related to the ThunderModule #1034

Open · kiya00 opened this issue 2 months ago

kiya00 commented 2 months ago


🐛 Bug

To Reproduce

import torch
import thunder
import weakref
import gc

mod = torch.nn.ReLU()
ref = weakref.ref(mod, lambda _: print("mod deleted!"))
opt_mod = thunder.jit(mod)
# opt_mod = torch.compile(mod)
ref_opt_mod = weakref.ref(opt_mod, lambda _: print("opt_mod deleted!"))
x = torch.randn(10, 10)
refx = weakref.ref(x, lambda _: print("x deleted!"))
opt_mod(x)
del x
del mod
del opt_mod
# gc.collect()
print("done!")  # done!

With the torch.compile() line (instead of thunder.jit), it outputs:

x deleted!
opt_mod deleted!
mod deleted!
done!

With thunder.jit it outputs:

x deleted!
done!
mod deleted!
opt_mod deleted!
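For what it's worth, a minimal follow-up check (same setup as the repro, on the assumption that the objects are only kept alive by a collectable reference cycle) is to force a collection and re-test the weak references. If that assumption holds, both weak references should be dead after gc.collect():

import gc
import weakref
import torch
import thunder

mod = torch.nn.ReLU()
opt_mod = thunder.jit(mod)
ref_mod = weakref.ref(mod)
ref_opt_mod = weakref.ref(opt_mod)
opt_mod(torch.randn(10, 10))

del mod, opt_mod
print(ref_mod() is None, ref_opt_mod() is None)  # expected: False False (held by the cycle)
gc.collect()
print(ref_mod() is None, ref_opt_mod() is None)  # expected: True True (cycle reclaimed by the GC)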

@kshitij12345 detected there's a reference cycle:

import torch
import thunder
import weakref
import gc

mod = torch.nn.ReLU()
ref = weakref.ref(mod, lambda _: print("mod deleted!"))
opt_mod = thunder.jit(mod)
# opt_mod = torch.compile(mod)
ref_opt_mod = weakref.ref(opt_mod, lambda _: print("opt_mod deleted!"))
x = torch.randn(10, 10)
refx = weakref.ref(x, lambda _: print("x deleted!"))
opt_mod(x)
del x
del mod
del opt_mod
# gc.collect()
print("done!")  # done!

if ref_opt_mod() is not None:
    import refcycle
    graph = refcycle.snapshot()

    try:
        cycle = graph.shortest_cycle(ref_opt_mod())
        print("CYCLE FOUND FROM MOD")
    except ValueError:
        print("NO CYCLE FROM MOD")
        pass
    # Save the latest cycle.
    cycle.export_json("cycle.json")
    cycle.export_image("cycle.png")

[image: refcycle graph of the detected cycle]

cc @apaz-cli

kshitij12345 commented 2 months ago

Second ref cycle (one of the objects here holds onto the user module): [image: refcycle graph of the second cycle]

Repro

def foo():
    import torch
    import thunder
    import weakref
    import gc

    mod = torch.nn.ReLU()
    ref_mod = weakref.ref(mod, lambda _: print("mod deleted!"))
    opt_mod = thunder.jit(mod)
    ref_opt_mod = weakref.ref(opt_mod, lambda _: print("opt_mod deleted!"))
    x = torch.randn(10, 10)
    refx = weakref.ref(x, lambda _: print("x deleted!"))
    opt_mod(x)
    del x
    del mod
    del opt_mod
    # gc.collect()
    print("done!")  # done!

    if ref_mod() is not None:
        import refcycle
        graph = refcycle.snapshot()

        try:
            cycle = graph.shortest_cycle(ref_mod())
            print("CYCLE FOUND FROM MOD")
        except ValueError:
            print("NO CYCLE FROM MOD")
            pass

        # More cycles are found here
        for anc in graph.ancestors(ref_mod()):
            try:
                cycle = graph.shortest_cycle(anc)
                print("CYCLE FOUND FROM ANCESTOR")
                print(anc)

                # Check the cycle from above
                # print(anc["prologue"].__wrapped__.__wrapped__.__wrapped__.__globals__["prologue"] is anc["prologue"])  # True
                # print(anc["prologue"].__wrapped__.__wrapped__.__wrapped__.__globals__["__function_obj"])
                break
            except ValueError:
                pass

        # for obj in cycle:
        #     print(obj)

        # Save the latest cycle.
        cycle.export_json("cycle.json")
        cycle.export_image("cycle.png")

foo()
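The commented-out check above (anc["prologue"].__wrapped__...__globals__["prologue"] is anc["prologue"] evaluating to True) suggests the shape of this cycle: a function object whose __globals__ dict also stores the function itself. This is not Thunder's actual code, but a standalone sketch of the same cycle shape shows why such objects survive refcounting and are only reclaimed by the cyclic garbage collector:

import gc
import weakref

# Compile a function into a fresh dict that serves as its __globals__ and
# also stores the function, mirroring the prologue -> __globals__ -> prologue cycle.
namespace = {}
exec("def prologue(x): return x", namespace)
fn = namespace["prologue"]
assert fn.__globals__["prologue"] is fn  # the cycle

ref = weakref.ref(fn)
del fn, namespace
print(ref() is None)  # False: still alive, held only by the cycle
gc.collect()
print(ref() is None)  # True: reclaimed by the cyclic garbage collector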
t-vi commented 1 month ago

So regarding the priority, as discussed in Slack: from what I can see, this cycle keeps modules that go out of scope from being collected promptly (they are only reclaimed when the cyclic garbage collector eventually runs). Not nice, but for the most part I don't think we will be compiling short-lived modules, so it might not be a game-breaker right now.

I looked at this a bit. In general:

WDYT?

IvanYashchuk commented 1 month ago

Not nice, but for the most part, I don't think we will be compiling short-lived modules, so it might not be a game-breaker right now.

It's a game-breaker because it blocks the usage of the Thunder-optimized dropout layer in a larger module as

self.dropout = thunder.jit(nn.Dropout(p=0.5))
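For context, a hedged sketch of that usage pattern (the Block class and its layers are illustrative, not from the issue): a Thunder-compiled dropout embedded as a submodule of a larger nn.Module. With the cycle described in this issue, each compiled dropout would stay alive until the cyclic garbage collector runs, even after the surrounding module is dropped.

import torch
import torch.nn as nn
import thunder

class Block(nn.Module):  # illustrative wrapper module
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
        # Only the dropout submodule is Thunder-compiled.
        self.dropout = thunder.jit(nn.Dropout(p=0.5))

    def forward(self, x):
        return self.dropout(self.linear(x))

block = Block()
y = block(torch.randn(4, 10))
del block  # the compiled dropout may linger until gc.collect() because of the cycle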
t-vi commented 1 month ago

It's a game-breaker because it blocks the usage of the Thunder-optimized dropout layer in a larger module as

I would like to understand this more. Is it a game-breaker because you disagree that it is not as relevant for long-lived modules or because you expect the modules to be short-lived?

IvanYashchuk commented 1 month ago

I'm sorry I confused this issue with https://github.com/Lightning-AI/lightning-thunder/issues/1074. I don't have an important use case for fixing this bug.