Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.

Support `torch.where(condition)` with thunder.jit #124

Open carmocca opened 7 months ago

carmocca commented 7 months ago

🚀 Feature

Motivation

Mixtral uses it: https://github.com/mistralai/mistral-src/blob/8598cf582091a596671be31990448e0620017851/moe_one_file_ref.py#L215

Minimal Repro

import torch
import thunder

def fn(cond):
    return torch.where(cond)

thunder.jit(fn)(torch.randn(3) > 0)

Pitch

Support the single-argument overload torch.where(condition), as described at https://pytorch.org/docs/stable/generated/torch.where.html
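
For context, the single-argument overload is documented to be equivalent to torch.nonzero(condition, as_tuple=True); a quick eager-mode illustration (plain PyTorch, nothing Thunder-specific):

import torch

cond = torch.tensor([[True, False], [False, True]])

# The single-argument overload returns a tuple of 1-D index tensors,
# one per dimension of `cond`.
rows, cols = torch.where(cond)
print(rows)  # tensor([0, 1])
print(cols)  # tensor([0, 1])
assert torch.equal(rows, torch.nonzero(cond, as_tuple=True)[0])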

Additional context

We already support torch.where(condition, input, other): https://github.com/search?q=repo%3ALightning-AI%2Flightning-thunder+%22def+where%22&type=code

cc @apaz-cli

nikitaved commented 7 months ago

As of now, we cannot support data-dependent ops, alas...
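
To illustrate what makes this overload data-dependent, here is a small eager-mode example (plain PyTorch): the output shape depends on the values of the input, not just its shape.

import torch

a = torch.tensor([True, False, True])
b = torch.tensor([True, True, True])

# Same input shape, different output shapes: a trace cannot assign a
# static shape to the result of torch.where(cond).
print(torch.where(a)[0].shape)  # torch.Size([2])
print(torch.where(b)[0].shape)  # torch.Size([3])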

nikitaved commented 7 months ago

@carmocca, looking at the code, I think the solution could be to modify the model in the package. The result of topk can be sorted, and then we do not need to apply where at all. This would also eliminate the device sync (syncs, actually) caused by where.
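
For illustration, here is one where-free formulation of the routing loop. This is a minimal sketch, not necessarily the sorted-topk restructuring proposed above; the shapes follow the Mixtral-style routing referenced in the issue description, but the helper itself is hypothetical. It trades the data-dependent gather for computing every expert on every token.

import torch

def moe_forward_dense(x, router_logits, experts, k):
    # x: (N, C), router_logits: (N, n_expert)
    probs, indices = torch.topk(router_logits, k)         # (N, k)
    probs = probs.softmax(dim=1)
    y = torch.zeros_like(x)
    for i, expert in enumerate(experts):
        mask = indices == i                               # (N, k) bool, stays on device
        weight = (probs * mask).sum(dim=1, keepdim=True)  # (N, 1)
        y = y + weight * expert(x)                        # no torch.where, no host sync
    return y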

carmocca commented 7 months ago

@nikitaved Faster and better code is very welcome in LitGPT. I benchmarked a few different implementations when this feature was added, and this one came out as the best overall (see the description and discussion in https://github.com/Lightning-AI/litgpt/pull/823). It would be useful to see them compared against whatever you propose.
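
If someone wants to run that comparison, a minimal timing harness could look like this (a sketch using torch.utils.benchmark; `impls` is a placeholder mapping labels to the candidate implementations, not an existing helper):

import torch
from torch.utils import benchmark

def compare(impls, *args):
    # impls: dict mapping a label to a callable that takes *args
    results = []
    for label, fn in impls.items():
        timer = benchmark.Timer(
            stmt="fn(*args)",
            globals={"fn": fn, "args": args},
            label="MoE forward",
            sub_label=label,
        )
        results.append(timer.blocked_autorange(min_run_time=1.0))
    benchmark.Compare(results).print()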

IvanYashchuk commented 7 months ago

The error message is not friendly and does not make it clear that torch.where(condition) is not supported:

In [1]: import torch

In [2]: import thunder

In [3]: from litgpt import Config

In [4]: from litgpt.model import LLaMAMoE

In [5]: config = Config.from_name("Mixtral-8x7B-v0.1")

In [6]: model = LLaMAMoE(config).to(dtype=torch.bfloat16, device="cuda")

In [7]: jit_model = thunder.jit(model)

In [8]: x = torch.randn(2, config.block_size, config.n_embd, dtype=torch.bfloat16, device="cuda")

In [9]: jit_model(x);

Traceback:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[9], line 1
----> 1 jit_model(x);

File ~/dev/pytorch/main/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/dev/pytorch/main/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/dev/lightning-thunder/thunder/__init__.py:194, in ThunderModule.forward(self, *args, **kwargs)
    193 def forward(self, *args, **kwargs):
--> 194     res = self._forward_fn(*args, **kwargs)
    195     return res

File ~/dev/lightning-thunder/thunder/__init__.py:629, in jit.<locals>.fn_(*args, **kwargs)
    626 cs.last_trace_host_start = time.time_ns()
    627 cs.calls += 1
--> 629 cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
    630 cs.last_trace_host_execution_start = time.time_ns()
    632 result = cache_entry.computation_fn(*inps)

File ~/dev/lightning-thunder/thunder/__init__.py:262, in _with_cache_info_ctx.<locals>.cache_info_wrapper(*args, **kwargs)
    260 tok = _cache_info_ctx.set({})
    261 try:
--> 262     res = fn(*args, **kwargs)
    263 finally:
    264     _cache_info_ctx.reset(tok)

File ~/dev/lightning-thunder/thunder/__init__.py:504, in jit.<locals>.get_computation_and_inputs(*args, **kwargs)
    502     prologue_trc: TraceCtx
    503     computation_trc: TraceCtx
--> 504     prologue_trc, computation_trc, *maybe_epilogue = interpreter(
    505         fn, args, kwargs, sharp_edges=cd.sharp_edges
    506     )
    508 if maybe_epilogue:
    509     epilogue_traces = maybe_epilogue

File ~/dev/lightning-thunder/thunder/__init__.py:175, in _general_frontend(fn, args, kwargs, sharp_edges)
    174 def _general_frontend(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> tuple[TraceCtx, TraceCtx]:
--> 175     return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1440, in thunder_general_jit(fn, args, kwargs, sharp_edges)
   1438 with general_jit_ctx(ctx):
   1439     with tracectx(computation_trace):
-> 1440         result = jfn(*args, **kwargs)
   1441         prims.python_return(result)
   1442         process_recorded_modifications(ctx, epilogue_trace)

File ~/dev/lightning-thunder/thunder/core/interpreter.py:6684, in interpret.<locals>.fn_(*args, **kwargs)
   6682     assert isinstance(e, BaseException), e
   6683     runtimectx.curexc = None
-> 6684     raise e
   6686 return interpretation_result

File ~/dev/lightning-thunder/thunder/core/interpreter.py:6647, in interpret.<locals>.fn_.<locals>.getfn.<locals>.fn_2()
   6646 def fn_2(args, kwargs):
-> 6647     return fn(*args, **kwargs)

File ~/dev/lightning-thunder/thunder/core/interpreter.py:6046, in _call_dispatch.<locals>._impl()
   6045 def _impl(fn, *args, **kwargs):
-> 6046     return fn.__func__(fn.__self__, *args, **kwargs)

File ~/dev/pytorch/main/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl()
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/dev/lightning-thunder/thunder/core/interpreter.py:6046, in _call_dispatch.<locals>._impl()
   6045 def _impl(fn, *args, **kwargs):
-> 6046     return fn.__func__(fn.__self__, *args, **kwargs)

File ~/dev/pytorch/main/torch/nn/modules/module.py:1520, in Module._call_impl()
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/dev/lightning-thunder/thunder/core/interpreter.py:6046, in _call_dispatch.<locals>._impl()
   6045 def _impl(fn, *args, **kwargs):
-> 6046     return fn.__func__(fn.__self__, *args, **kwargs)

File ~/dev/litgpt/litgpt/model.py:347, in LLaMAMoE.forward()
    345 y = torch.zeros_like(x)  # (B*T, C)
    346 for mask, expert in zip(masks, self.experts):
--> 347     token_idx, expert_idx = torch.where(mask)
    348     y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
    349 return y.view(B, T, C)

File ~/dev/lightning-thunder/thunder/core/interpreter.py:1258, in interpreter_needs_wrap.<locals>.wrapping_wrapper(*args, **kwargs)
   1255     ukwargs = kwargs
   1257 try:
-> 1258     res = ufn(*uargs, **ukwargs)
   1260     # If result is a WrappedValue, we trust its provenance record
   1261     if isinstance(res, WrappedValue):

File ~/dev/lightning-thunder/thunder/core/symbol.py:250, in Symbol.__call__(self, *args, **kwargs)
    248 else:
    249     trace.push_scope(subsymbols)
--> 250     result = self.meta(*args, **kwargs)
    251     trace.pop_scope()
    253 bsym = self.bind(*args, **kwargs, output=result, subsymbols=subsymbols)

File ~/dev/lightning-thunder/thunder/core/langctxs.py:124, in langctx.__call__.<locals>._fn(*args, **kwargs)
    122 try:
    123     tok = set_langctx(self.langctx)
--> 124     result = fn(*args, **kwargs)
    125     return result
    126 finally:

TypeError: where() missing 2 required positional arguments: 'a' and 'b'

nikitaved commented 7 months ago

@IvanYashchuk, it looks like we should update the meta function for where. To be frank, I did not even know about this overload...

Might be a very nice issue for external contributors...
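
An interim improvement could be to detect the single-argument call in the meta function and raise a clear error. A rough sketch (hypothetical code, not Thunder's actual where meta; the real definition lives in Thunder's torch language module and may differ):

def where_meta(pred, a=None, b=None):
    # Hypothetical: reject the unsupported overload with an actionable message.
    if a is None and b is None:
        raise NotImplementedError(
            "thunder.jit does not yet support the data-dependent overload "
            "torch.where(condition); use torch.where(condition, input, other), "
            "or call torch.nonzero(condition, as_tuple=True) outside the jitted region."
        )
    if (a is None) != (b is None):
        raise TypeError("torch.where expects either one or three arguments")
    # ... existing three-argument meta logic continues here ...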

mruberry commented 6 months ago

triage review:

carmocca commented 6 months ago

nonzero doesn't have a shape= argument. Did you mean as_tuple=?

mruberry commented 6 months ago

> nonzero doesn't have a shape= argument. Did you mean as_tuple=?

We were referring to a parameter that would be analogous to jax.numpy.nonzero's size parameter:

https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.nonzero.html
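
For reference, this is how the size parameter keeps the output shape static in JAX (jax.numpy.nonzero; entries beyond the true count are padded with fill_value):

import jax.numpy as jnp

cond = jnp.array([True, False, True, False])

# With size= the result has a fixed length regardless of how many
# elements are True; missing entries are padded with fill_value.
(idx,) = jnp.nonzero(cond, size=4, fill_value=-1)
print(idx)  # [ 0  2 -1 -1]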

IvanYashchuk commented 6 months ago

This prototype adds torch.where(boolean_tensor) https://github.com/Lightning-AI/lightning-thunder/pull/303

mruberry commented 1 week ago

@IvanYashchuk @kshitij12345 Is this still unsupported when using ThunderFX? If torch.where(condition) is supported when using ThunderFX (because the operator is sent to PyTorch for execution), then maybe we can close this issue, or amend it to refer more specifically to using torch.where(condition) with the Thunder interpreter as the entrypoint.

kshitij12345 commented 1 week ago

torch.where(condition) works with the ThunderFX path, which sends it to PyTorch for execution. We also have a test for this:

https://github.com/Lightning-AI/lightning-thunder/blob/b28d5b3536e60fb0b30896bdd4df6e288cf6a5c8/thunder/tests/test_dynamo.py#L346-L349
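
For anyone hitting this with thunder.jit today, the ThunderFX path looks roughly like this (a sketch assuming the ThunderCompiler backend exercised in the test linked above; see that file for the exact usage):

import torch
from thunder.dynamo import ThunderCompiler

def fn(cond):
    return torch.where(cond)

# torch.compile partitions the graph: the unsupported op runs in PyTorch,
# while the supported parts are handed to Thunder.
backend = ThunderCompiler()
cfn = torch.compile(fn, backend=backend)
cfn(torch.randn(3) > 0)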

I will update the issue title to reflect that the request is about torch.where(condition) not being supported by the thunder.jit entrypoint.