Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

implement zip lookaside in Python interpreter (enables e.g. thunder.jit with zip from LitGPT LLaMAMoE) #284

Open IvanYashchuk opened 4 months ago

IvanYashchuk commented 4 months ago

🐛 Bug

Here's a simplified version of LitGPT's LLaMAMoE without data-dependent shapes; it fails somewhere in the general jit:

NotImplementedError: unpacking from OPAQUE <slot wrapper '__next__' of 'zip' objects> ProvenanceRecord(

To reproduce:

import torch
import thunder
from torch import nn

class Test(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.n_expert = 8
        self.n_expert_per_token = 2
        self.C = 2
        self.gate = nn.Linear(self.C, self.n_expert, bias=False)
        self.experts = nn.ModuleList(nn.Linear(2, 2, bias=False) for _ in range(self.n_expert))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        x = x.view(-1, C)  # (B*T, C)
        router = self.gate(x)  # (B*T, n_expert)
        probs, indices = torch.topk(router, self.n_expert_per_token)  # (B*T, n_expert_per_token)
        probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
        masks = indices.unsqueeze(-1) == torch.arange(self.n_expert, device=x.device)
        masks = masks.permute(2, 0, 1)  # (n_expert, B*T, n_expert_per_token)
        y = torch.zeros_like(x)  # (B*T, C)
        for (mask, expert) in zip(masks, self.experts):
            token_idx, expert_idx = torch.arange(B*T, device=x.device), torch.arange(B*T, device=x.device)
            pprobs = probs[token_idx, expert_idx]
            pprobs = pprobs.unsqueeze(-1)
            eexpert = expert(x[token_idx])
            y = torch.index_add(y, 0, token_idx, pprobs * eexpert)
        return y.view(B, T, C)

model = Test()
model = thunder.jit(model)

x = torch.randn(2, 3, 2)
y = model(x)

raises:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1272, in unpack_inputs.<locals>.unpack(v)
   1271 try:
-> 1272     from_provenance(p.history)
   1273 except Exception as e:

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1264, in unpack_inputs.<locals>.unpack.<locals>.from_provenance(provenance, new_output)
   1263     raise NotImplementedError(f"Unpacking from {inst} {provenance}")
-> 1264 res = unpack_fn(provenance, new_output=new_output)
   1265 provenance.proxy = res

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1195, in unpack_inputs.<locals>.unpack.<locals>.from_binary_subscr(provenance, new_output)
   1194 def from_binary_subscr(provenance, *, new_output=False):
-> 1195     inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
   1196     obj, idx = inputs

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1195, in <listcomp>(.0)
   1194 def from_binary_subscr(provenance, *, new_output=False):
-> 1195     inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
   1196     obj, idx = inputs

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1264, in unpack_inputs.<locals>.unpack.<locals>.from_provenance(provenance, new_output)
   1263     raise NotImplementedError(f"Unpacking from {inst} {provenance}")
-> 1264 res = unpack_fn(provenance, new_output=new_output)
   1265 provenance.proxy = res

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1195, in unpack_inputs.<locals>.unpack.<locals>.from_binary_subscr(provenance, new_output)
   1194 def from_binary_subscr(provenance, *, new_output=False):
-> 1195     inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
   1196     obj, idx = inputs

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1195, in <listcomp>(.0)
   1194 def from_binary_subscr(provenance, *, new_output=False):
-> 1195     inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
   1196     obj, idx = inputs

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1264, in unpack_inputs.<locals>.unpack.<locals>.from_provenance(provenance, new_output)
   1263     raise NotImplementedError(f"Unpacking from {inst} {provenance}")
-> 1264 res = unpack_fn(provenance, new_output=new_output)
   1265 provenance.proxy = res

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1178, in unpack_inputs.<locals>.unpack.<locals>.from_load_attr(provenance, new_output)
   1177 is_pure = False
-> 1178 inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
   1179 if new_output:

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1178, in <listcomp>(.0)
   1177 is_pure = False
-> 1178 inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
   1179 if new_output:

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1264, in unpack_inputs.<locals>.unpack.<locals>.from_provenance(provenance, new_output)
   1263     raise NotImplementedError(f"Unpacking from {inst} {provenance}")
-> 1264 res = unpack_fn(provenance, new_output=new_output)
   1265 provenance.proxy = res

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1195, in unpack_inputs.<locals>.unpack.<locals>.from_binary_subscr(provenance, new_output)
   1194 def from_binary_subscr(provenance, *, new_output=False):
-> 1195     inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
   1196     obj, idx = inputs

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1195, in <listcomp>(.0)
   1194 def from_binary_subscr(provenance, *, new_output=False):
-> 1195     inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
   1196     obj, idx = inputs

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1264, in unpack_inputs.<locals>.unpack.<locals>.from_provenance(provenance, new_output)
   1263     raise NotImplementedError(f"Unpacking from {inst} {provenance}")
-> 1264 res = unpack_fn(provenance, new_output=new_output)
   1265 provenance.proxy = res

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1241, in unpack_inputs.<locals>.unpack.<locals>.from_opaque(provenance, new_output)
   1232     return from_provenance(
   1233         ProvenanceRecord(
   1234             PseudoInst.LOAD_ATTR,
   (...)
   1239         )
   1240     )
-> 1241 raise NotImplementedError(f"unpacking from OPAQUE {fn.value} {provenance}")

NotImplementedError: unpacking from OPAQUE <slot wrapper '__next__' of 'zip' objects> ProvenanceRecord(
  i1 = INPUT_FN()
  i2 = LOAD_ATTR(i1, '__dict__')
  i3 = BINARY_SUBSCR(i2, '_modules')
  i4 = BINARY_SUBSCR(i3, 'experts')
  i5 = INPUT_ARGS()
  i6 = BINARY_SUBSCR(i5, 0)
  i7 = LOAD_ATTR(i6, '__getattr__')
  i8 = LOAD_ATTR(i7, '__func__')
  i9 = Instruction(opname='CALL_FUNCTION_KW', opcode=141, arg=2, argval=2, argrepr='', offset=102, starts_line=None, is_jump_target=False)()
  i10 = LOAD_ATTR(i1, 'n_expert_per_token')
  i11 = BINARY_SUBSCR(i3, 'gate')
  i12 = LOAD_ATTR(i11, '__dict__')
  i13 = BINARY_SUBSCR(i12, '_parameters')
  i14 = BINARY_SUBSCR(i13, 'bias')
  i15 = BINARY_SUBSCR(i13, 'weight')
  i16 = BUILD_TUPLE('view', i6)
  i17 = OPAQUE(i8, i16, CONSTANT({}))
  i18 = LOAD_ATTR(i17, 'func')
  i19 = Instruction(opname='BINARY_ADD', opcode=23, arg=None, argval=None, argrepr='', offset=10, starts_line=None, is_jump_target=False)()
  i20 = BINARY_SUBSCR(i19, 2)
  i21 = BINARY_SUBSCR(i19, 1)
  i22 = LOAD_ATTR(i17, 'args')
  i23 = BINARY_SUBSCR(i22, 0)
  i24 = BUILD_TUPLE(i20, i21, i23)
  i25 = OPAQUE(i18, i24, CONSTANT({}))
  i26 = BUILD_TUPLE(i14, i15, i25)
  i27 = OPAQUE(CONSTANT([Symbol name=linear]), i26, CONSTANT({}))
  i28 = BUILD_TUPLE(i10, i27)
  i29 = OPAQUE(CONSTANT([Symbol name=topk]), i28, CONSTANT({}))
  i30 = BINARY_SUBSCR(i29, 1)
  i31 = BUILD_TUPLE('unsqueeze', i30)
  i32 = OPAQUE(i8, i31, CONSTANT({}))
  i33 = LOAD_ATTR(i32, 'func')
  i34 = BUILD_TUPLE(i21, i30)
  i35 = OPAQUE(i33, i34, CONSTANT({}))
  i36 = Instruction(opname='COMPARE_OP', opcode=107, arg=2, argval='==', argrepr='==', offset=104, starts_line=None, is_jump_target=False)(i9, i35)
  i37 = BUILD_TUPLE('permute', i36)
  i38 = OPAQUE(i8, i37, CONSTANT({}))
  i39 = LOAD_ATTR(i38, 'func')
  i40 = BINARY_SUBSCR(i19, 3)
  i41 = BUILD_TUPLE(i40, i20, i21, i36)
  i42 = OPAQUE(i39, i41, CONSTANT({}))
  i43 = LOAD_ATTR(i1, 'forward')
  i44 = LOAD_ATTR(i43, '__func__')
  i45 = LOAD_ATTR(i44, '__globals__')
  i46 = BINARY_SUBSCR(i45, '__builtins__')
  i47 = LOAD_ATTR(i46, 'zip')
  i48 = BUILD_TUPLE(i4, i42, i47)
  i49 = OPAQUE(CONSTANT(<built-in method __new__ of type object at 0x55c1c13de340>), i48, CONSTANT({}))
  i50 = BUILD_TUPLE(i49)
  i51 = OPAQUE(CONSTANT(<slot wrapper '__next__' of 'zip' objects>), i50, CONSTANT({}))
)

The above exception was the direct cause of the following exception:

NotImplementedError                       Traceback (most recent call last)
Cell In[1], line 35
     32 model = thunder.jit(model)
     34 x = torch.randn(2, 3, 2)
---> 35 y = model(x)

File ~/miniforge3/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/dev/lightning-thunder/thunder/__init__.py:209, in ThunderModule.forward(self, *args, **kwargs)
    208 def forward(self, *args, **kwargs):
--> 209     res = self._forward_fn(*args, **kwargs)
    210     return res

File ~/dev/lightning-thunder/thunder/__init__.py:661, in jit.<locals>.fn_(*args, **kwargs)
    658 cs.last_trace_host_start = time.time_ns()
    659 cs.calls += 1
--> 661 cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
    662 cs.last_trace_host_execution_start = time.time_ns()
    664 result = cache_entry.computation_fn(*inps)

File ~/dev/lightning-thunder/thunder/__init__.py:277, in _with_cache_info_ctx.<locals>.cache_info_wrapper(*args, **kwargs)
    275 tok = _cache_info_ctx.set({})
    276 try:
--> 277     res = fn(*args, **kwargs)
    278 finally:
    279     _cache_info_ctx.reset(tok)

File ~/dev/lightning-thunder/thunder/__init__.py:538, in jit.<locals>.get_computation_and_inputs(*args, **kwargs)
    536 prologue_trc: TraceCtx
    537 computation_trc: TraceCtx
--> 538 jit_results: TraceResults = interpreter(
    539     fn, args, kwargs, record_history=record_history, sharp_edges=cd.sharp_edges
    540 )
    541 prologue_trc = jit_results.prologue_trace
    542 computation_trc = jit_results.computation_trace

File ~/dev/lightning-thunder/thunder/__init__.py:190, in _general_frontend(fn, args, kwargs, record_history, sharp_edges)
    181 def _general_frontend(
    182     fn: Callable,
    183     args: tuple[Any, ...],
   (...)
    188     sharp_edges: SHARP_EDGES_OPTIONS,
    189 ) -> TraceResults:
--> 190     return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1481, in thunder_general_jit(fn, args, kwargs, record_history, sharp_edges)
   1478 else:
   1479     epilogue_trace = None
-> 1481 pro_to_comp_proxies, pro_to_epi_proxies = unpack_inputs(
   1482     ctx, prologue_trace, pro_to_comp, pro_to_epi, args, kwargs, has_epilogue=epilogue_trace is not None
   1483 )
   1485 proxy_order = {id(p): i for i, p in enumerate(pro_to_comp_proxies)}
   1486 pro_to_comp = tuple(sorted(pro_to_comp, key=lambda v: proxy_order[id(v.proxy)]))

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1301, in unpack_inputs(ctx, prologue_trace, pro_to_comp_inps, pro_to_epi_inps, args, kwargs, has_epilogue)
   1298             pro_kwargs_proxy = output
   1300 pro_to_epi = tuple(sorted((unpack(v) for v in pro_to_epi_inps), key=lambda x: param_ordering[id(x)][1]))
-> 1301 pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1]))
   1303 with tracectx(prologue_trace):
   1304     for prim, *args in ctx._constraints:

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1301, in <genexpr>(.0)
   1298             pro_kwargs_proxy = output
   1300 pro_to_epi = tuple(sorted((unpack(v) for v in pro_to_epi_inps), key=lambda x: param_ordering[id(x)][1]))
-> 1301 pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1]))
   1303 with tracectx(prologue_trace):
   1304     for prim, *args in ctx._constraints:

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1274, in unpack_inputs.<locals>.unpack(v)
   1272         from_provenance(p.history)
   1273     except Exception as e:
-> 1274         raise NotImplementedError(f"Exception occured unpacking object from {p.history}") from e
   1276 already_unpacked[id(p)] = p
   1278 # Adds cache constraints
   1279 # TODO Consider refactoring these contraints
   1280 # TODO Constrain on rank, device, and dtype

NotImplementedError: Exception occured unpacking object from ProvenanceRecord(
  i1 = INPUT_FN()
  i2 = LOAD_ATTR(i1, '__dict__')
  i3 = BINARY_SUBSCR(i2, '_modules')
  i4 = BINARY_SUBSCR(i3, 'experts')
  i5 = INPUT_ARGS()
  i6 = BINARY_SUBSCR(i5, 0)
  i7 = LOAD_ATTR(i6, '__getattr__')
  i8 = LOAD_ATTR(i7, '__func__')
  i9 = Instruction(opname='CALL_FUNCTION_KW', opcode=141, arg=2, argval=2, argrepr='', offset=102, starts_line=None, is_jump_target=False)()
  i10 = LOAD_ATTR(i1, 'n_expert_per_token')
  i11 = BINARY_SUBSCR(i3, 'gate')
  i12 = LOAD_ATTR(i11, '__dict__')
  i13 = BINARY_SUBSCR(i12, '_parameters')
  i14 = BINARY_SUBSCR(i13, 'bias')
  i15 = BINARY_SUBSCR(i13, 'weight')
  i16 = BUILD_TUPLE('view', i6)
  i17 = OPAQUE(i8, i16, CONSTANT({}))
  i18 = LOAD_ATTR(i17, 'func')
  i19 = Instruction(opname='BINARY_ADD', opcode=23, arg=None, argval=None, argrepr='', offset=10, starts_line=None, is_jump_target=False)()
  i20 = BINARY_SUBSCR(i19, 2)
  i21 = BINARY_SUBSCR(i19, 1)
  i22 = LOAD_ATTR(i17, 'args')
  i23 = BINARY_SUBSCR(i22, 0)
  i24 = BUILD_TUPLE(i20, i21, i23)
  i25 = OPAQUE(i18, i24, CONSTANT({}))
  i26 = BUILD_TUPLE(i14, i15, i25)
  i27 = OPAQUE(CONSTANT([Symbol name=linear]), i26, CONSTANT({}))
  i28 = BUILD_TUPLE(i10, i27)
  i29 = OPAQUE(CONSTANT([Symbol name=topk]), i28, CONSTANT({}))
  i30 = BINARY_SUBSCR(i29, 1)
  i31 = BUILD_TUPLE('unsqueeze', i30)
  i32 = OPAQUE(i8, i31, CONSTANT({}))
  i33 = LOAD_ATTR(i32, 'func')
  i34 = BUILD_TUPLE(i21, i30)
  i35 = OPAQUE(i33, i34, CONSTANT({}))
  i36 = Instruction(opname='COMPARE_OP', opcode=107, arg=2, argval='==', argrepr='==', offset=104, starts_line=None, is_jump_target=False)(i9, i35)
  i37 = BUILD_TUPLE('permute', i36)
  i38 = OPAQUE(i8, i37, CONSTANT({}))
  i39 = LOAD_ATTR(i38, 'func')
  i40 = BINARY_SUBSCR(i19, 3)
  i41 = BUILD_TUPLE(i40, i20, i21, i36)
  i42 = OPAQUE(i39, i41, CONSTANT({}))
  i43 = LOAD_ATTR(i1, 'forward')
  i44 = LOAD_ATTR(i43, '__func__')
  i45 = LOAD_ATTR(i44, '__globals__')
  i46 = BINARY_SUBSCR(i45, '__builtins__')
  i47 = LOAD_ATTR(i46, 'zip')
  i48 = BUILD_TUPLE(i4, i42, i47)
  i49 = OPAQUE(CONSTANT(<built-in method __new__ of type object at 0x55c1c13de340>), i48, CONSTANT({}))
  i50 = BUILD_TUPLE(i49)
  i51 = OPAQUE(CONSTANT(<slot wrapper '__next__' of 'zip' objects>), i50, CONSTANT({}))
  i52 = BINARY_SUBSCR(i51, 1)
  i53 = LOAD_ATTR(i52, '__dict__')
  i54 = BINARY_SUBSCR(i53, '_parameters')
  i55 = BINARY_SUBSCR(i54, 'weight')
)
t-vi commented 4 months ago

Thank you @IvanYashchuk. The underlying issue is "need lookaside for zip in interpreter".

t-vi commented 4 months ago

This seems like a good issue for someone who wants to take a look at our great Python interpreter (thunder/core/interpreter.py); it's not trivial, but it should be relatively self-contained.

lantiga commented 4 months ago

@t-vi what would be a similar lookaside to start from for anyone wanting to approach this?

nikitaved commented 4 months ago

I would assume that functools.reduce is not that bad to look at, specifically because of its test coverage (generic iterables, custom iterables, etc.).
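
For instance, a zip lookaside should be exercised with custom iterables as well, not just lists and tuples. A hypothetical minimal test case (the Squares class is made up for illustration):

class Squares:
    # A custom iterable, to check that the lookaside is not
    # specialized to built-in sequence types.
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return (i * i for i in range(self.n))

assert list(zip(Squares(3), "abc")) == [(0, "a"), (1, "b"), (4, "c")]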

IvanYashchuk commented 4 months ago

Doesn't the error message NotImplementedError: unpacking from OPAQUE <slot wrapper '__next__' of 'zip' objects> mean that the problem is in the interpretation of __next__ and not in zip itself?

riccardofelluga commented 4 months ago

@t-vi to me it looks like there are some errors in the unpacking rather than in zip; might it be the opaque ModuleList container?

Nachiket18 commented 1 month ago

I would like to participate in the solution for this issue. I have some knowledge of the ast library and have worked on the Clang static analyzer before.

t-vi commented 1 month ago

Hi @Nachiket18 ,

great! So the main issue here seems to be that we would want the Thunder Python Interpreter to be able to "see through" zip (i.e. link the items yielded by the zip iteration with the arguments to zip). I think it should be as easy as implementing a zip function in pure Python and putting it in a lookaside. If you look at def _enumerate_lookaside(obj: Iterable, start: int = 0): in thunder/core/interpreter.py, that would give you a good idea of how to do it. Tests should be added to thunder/tests/test_interpreter.py checking that the new zip gives the same results as the builtin. The repro @IvanYashchuk has in the issue could go in test_jit_general.py.
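
A minimal sketch of what the lookaside body could look like (the name _zip_lookaside and the registration details are assumptions; the real wiring should mirror how _enumerate_lookaside is registered in thunder/core/interpreter.py):

def _zip_lookaside(*iterables):
    # Pure-Python equivalent of builtins.zip. Because the body is plain
    # Python, the interpreter can trace through the iteration instead of
    # stopping at the opaque zip.__next__ slot wrapper.
    iterators = [iter(it) for it in iterables]
    while iterators:  # zip() with no arguments yields nothing
        items = []
        for it in iterators:
            try:
                items.append(next(it))
            except StopIteration:
                # shortest input exhausted; stop, matching builtins.zip
                return
        yield tuple(items)

As plain Python this already matches the builtin for the common cases, e.g. list(_zip_lookaside([1, 2], "ab")) == list(zip([1, 2], "ab")). Note that the strict= keyword argument zip gained in Python 3.10 is not handled in this sketch and would need to be added in a real implementation.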

Please let me know if there is anything else you need to get started. Also don't hesitate to reach out if you find you're stuck somewhere.