Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Functional JIT loading closures sharp edge #132

Open riccardofelluga opened 6 months ago

riccardofelluga commented 6 months ago

Strategy required

This issue resumes from PR2410: we need to decide on a strategy for the closures sharp edge. Let's start simple. I think we can all agree that this is a sharp edge if we jit foo:

x = 5
def foo():
    return x

And that's because we are using a variable from outside the jitted scope. Here is where things get interesting, though: should we consider the following a sharp edge?

def foo(x):
    def bar():
        return x
    return bar()
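A minimal sketch of how the two snippets differ at the bytecode level, using only the standard `dis` module. The helper `outside_scope_reads` is a hypothetical name for illustration, not a thunder API: it collects the names that a function, or any function nested in it, reads from outside the jitted function's own scope. It assumes every `LOAD_GLOBAL` counts as such a read (so builtins like `len` would be reported too) and treats the jitted function's own free variables as captured from an enclosing scope.

```python
import dis
import types


def outside_scope_reads(fn):
    """Names that fn, or any function nested in it, reads from outside fn's own scope."""
    code = fn.__code__
    # Free variables of fn itself come from an enclosing scope.
    reads = set(code.co_freevars)
    pending = [code]
    while pending:
        current = pending.pop()
        for instr in dis.get_instructions(current):
            # LOAD_GLOBAL resolves a name in the module globals or builtins.
            if instr.opname == "LOAD_GLOBAL":
                reads.add(instr.argval)
        # Nested functions are stored as code objects in co_consts.
        pending.extend(c for c in current.co_consts if isinstance(c, types.CodeType))
    return reads


x = 5

def foo_global():
    return x

def foo_closure(x):
    def bar():
        return x
    return bar()

print(outside_scope_reads(foo_global))   # {'x'}
print(outside_scope_reads(foo_closure))  # set()
```

In the second snippet, `x` is a cell variable of the jitted function and only a free variable of `bar`, so nothing is read from outside the jitted scope.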

I assume that, since we captured x when jitting foo, this should not be a sharp edge for bar, because the variable was declared in the jitted scope (or, in this case, captured). To handle this case, we can remember which variables we captured and look them up when we encounter a freevar. However, @mruberry raises an interesting point: what happens if the variable gets deleted? How can we deal with something like:

def foo():
  a = 5

  def bar():
    nonlocal a
    del a

  bar()

  return a  # raises NameError here: bar() deleted a
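Both the "remember what we captured and look it up at a freevar" idea and the deletion case can be observed directly on function objects: `__code__.co_freevars` names a function's free variables, and `__closure__` holds one cell per name in the same order. A minimal sketch, where `freevar_values` and the `EMPTY` sentinel are hypothetical names for illustration: it reads the current cell contents, and a cell emptied by `del` through `nonlocal` raises `ValueError` on access, which the sketch reports as `EMPTY`.

```python
EMPTY = object()  # hypothetical sentinel for a cell whose variable was deleted


def freevar_values(fn):
    """Map each free variable of fn to the current contents of its closure cell."""
    values = {}
    for name, cell in zip(fn.__code__.co_freevars, fn.__closure__ or ()):
        try:
            values[name] = cell.cell_contents
        except ValueError:  # the cell is empty, e.g. after `del` via `nonlocal`
            values[name] = EMPTY
    return values


def make_deleter():
    a = 5

    def bar():
        nonlocal a
        del a

    return bar


bar = make_deleter()
print(freevar_values(bar))                # {'a': 5}
bar()                                     # empties the cell shared with make_deleter
print(freevar_values(bar)["a"] is EMPTY)  # True
```

Because the cell is shared by the enclosing function and `bar`, any value remembered when the outer function was jitted would be stale once `bar()` runs.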

In conclusion, what do you think the definition of a sharp edge should be in this context?

cc @apaz-cli @t-vi @mruberry

mruberry commented 6 months ago

I was thinking more about this, and I wonder if we should just define thunder.functional.jit to bind all global and nonlocal accesses at compile time. This would be consistent with how it wants to treat functions. That way, if a global int with the value 5 is loaded, the 5 would simply become a constant on future calls of the function returned by thunder.functional.jit. This behavior is similar to Numba's, and I think the functional jit and Numba have a lot of UX overlap.
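To make the proposed semantics concrete, here is a minimal pure-Python sketch; it is not thunder's implementation. The hypothetical `bind_at_compile_time` rebuilds a function with a snapshot of its globals and closure cells taken when it is "compiled", so later rebinding of a global or nonlocal is not seen. The snapshot is shallow: a mutable object such as a list is still shared, so only rebinding is frozen.

```python
import types


def _copy_cell(cell):
    """Return a new cell holding the same value (or an empty cell if cell is empty)."""
    try:
        return types.CellType(cell.cell_contents)
    except ValueError:  # the variable was deleted, so the cell is empty
        return types.CellType()


def bind_at_compile_time(fn):
    """Return a copy of fn whose globals and nonlocals are frozen at this moment."""
    frozen_globals = dict(fn.__globals__)  # shallow snapshot of the module globals
    frozen_cells = None
    if fn.__closure__ is not None:
        frozen_cells = tuple(_copy_cell(c) for c in fn.__closure__)
    bound = types.FunctionType(fn.__code__, frozen_globals, fn.__name__,
                               fn.__defaults__, frozen_cells)
    bound.__kwdefaults__ = fn.__kwdefaults__
    return bound


x = 5

def foo():
    return x

compiled = bind_at_compile_time(foo)
x = 7
print(compiled(), foo())  # 5 7


def make_reader():
    n = 0
    def read():
        return n
    def bump():
        nonlocal n
        n += 1
    return read, bump

read, bump = make_reader()
frozen_read = bind_at_compile_time(read)
bump()
print(read(), frozen_read())  # 1 0
```

Under these semantics the global `5` and the nonlocal `0` behave as constants of the compiled function, which is the Numba-like behavior described above.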

apaz-cli commented 5 months ago

@riccardofelluga

My opinion is that the interpreter should work the way CPython does whenever possible. If it can't, we should invalidate the cache and re-trace. It remains to be seen how far we want to go with this idea (it may be slow in some situations), but I think this behavior is fairly easy to emulate. If we can track global variable inputs, we can certainly model any __closure__ cells that are accessed as inputs and invalidate the cache accordingly. Provenance tracking shouldn't prove too big an issue. After all, the interpreter accesses and uses these values in a particular order, and we're tracking their downstream uses.
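A minimal sketch of "treat globals and closures as inputs and re-trace when they change", in plain Python. The decorator `guarded_jit` and its `trace_count` attribute are hypothetical and do not describe thunder's actual cache: before each call, the sketch gathers the current values of the globals named in the function's bytecode and of its closure cells, and counts a re-trace whenever any of them has been rebound (compared by identity). In-place mutation of a shared object is not detected by this guard.

```python
import functools

_MISSING = object()  # stands in for a name that is not (yet) a global, or an empty cell


def _cell_value(cell):
    try:
        return cell.cell_contents
    except ValueError:  # empty cell
        return _MISSING


def guarded_jit(fn):
    """Count a re-trace of fn whenever a global or closure input it reads is rebound."""
    # co_names also holds attribute names; tracking those too is merely conservative.
    names = fn.__code__.co_names
    state = {"guard": None, "traces": 0}

    def current_inputs():
        globs = tuple(fn.__globals__.get(n, _MISSING) for n in names)
        cells = tuple(_cell_value(c) for c in (fn.__closure__ or ()))
        return globs + cells

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        inputs = current_inputs()
        guard = state["guard"]
        if guard is None or any(old is not new for old, new in zip(guard, inputs)):
            state["guard"] = inputs
            state["traces"] += 1  # a real JIT would re-trace and cache a new program here
        return fn(*args, **kwargs)

    wrapper.trace_count = lambda: state["traces"]
    return wrapper


scale = 2

def f(v):
    return v * scale

jf = guarded_jit(f)
jf(1)
jf(2)
print(jf.trace_count())         # 1: scale was not rebound between calls
scale = 3
print(jf(2), jf.trace_count())  # 6 2
```

Unlike compile-time binding, this keeps CPython semantics (the new `scale` is seen) at the cost of checking the tracked inputs on every call.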

Notably, though, there are two cases: either the closure is attached to the function we're compiling (in co_consts, etc.), or it's attached to some object we end up encountering later. Both need to be modeled as inputs to the function for the purposes of cache invalidation if we want this to work.

I would argue that the first case is a subset of the second. The second case is the one whose support we need to decide on. Personally, I'm also unsure why loading nonlocals is a sharp edge. We know exactly how to do it, and it isn't any more difficult than loading globals. If loading nonlocals is a sharp edge, then the same should be true of loading globals and accessing attributes. It's just a matter of provenance tracking and deciding when to invalidate the cache. Ultimately, though, the cache invalidation decision should be based on how the inputs are used, not where they come from: as an argument, from the code object, as a global, as a nonlocal, or otherwise.

An example of the former case:

import thunder

def outer():
  lst = None
  def inner():
    nonlocal lst
    if lst is None: # CONTROL FLOW CHANGES BASED ON NONLOCAL INPUT
      lst = []
    lst.append("#")
    return lst
  tj = thunder.jit(inner) # inner's code object comes from outer's co_consts
  tj() # Returns ['#']
  return tj() # Returns ['#', '#']

thunder.jit(outer)() # Returns ['#', '#']

And an example of the latter:

import thunder

global_lst = []
def unrelated():
  lst = None
  def inner():
    nonlocal lst
    if lst is None: # CONTROL FLOW CHANGES BASED ON NONLOCAL INPUT
      lst = global_lst
    lst.append("#")
    return lst
  return inner

def foo():
  inner_fn = unrelated() # inner_fn has nonlocals and comes from calling a function loaded via LOAD_GLOBAL, so we couldn't have known it beforehand.
  lst1 = inner_fn()
  print(lst1) # ['#']
  lst2 = inner_fn()
  print(lst2) # ['#', '#']
  print(lst1 is global_lst) # True
  print(lst2 is global_lst) # True

thunder.jit(foo)()
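The two cases can also be told apart mechanically, which is one way to see why the first is a subset of the second. In the first example, `inner`'s code object is a constant nested inside the compiled function's code and could be found ahead of time; in the second, it lives in another function's code and only shows up at runtime. Either way, the resulting function carries a `__closure__` that would have to be modeled as an input. A sketch using only the standard library; `nested_code_objects` and `is_nested_in` are hypothetical helpers, and thunder is left out so the snippet runs on its own.

```python
import types


def nested_code_objects(code):
    """Yield every code object nested, at any depth, in code.co_consts."""
    for const in code.co_consts:
        if isinstance(const, types.CodeType):
            yield const
            yield from nested_code_objects(const)


def is_nested_in(fn, compiled_fn):
    """True if fn's code object is visible inside compiled_fn's code before running it."""
    return any(c is fn.__code__ for c in nested_code_objects(compiled_fn.__code__))


def outer():
    lst = None
    def inner():
        nonlocal lst
        lst = [] if lst is None else lst
        lst.append("#")
        return lst
    return inner


def foo():
    inner_fn = outer()  # the closure is created by code that lives outside foo
    return inner_fn()


print(is_nested_in(outer(), outer))  # True: the former case
print(is_nested_in(outer(), foo))    # False: the latter case
print(outer().__closure__ is not None)  # True: both cases carry a closure
```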
mruberry commented 5 months ago

I think it's a sharp edge for the functional jit, @apaz-cli, because it's functional: no surprise inputs allowed.