Open carmocca opened 7 months ago
As of now, we cannot support data-dependent ops, alas...
@carmocca , looking at the code I think the solution could be modifying the model in the package. The result of `topk` can be sorted, and then we do not need to apply `where` at all. This will also eliminate the device sync (syncs, actually) caused by `where`.
@nikitaved Faster and better code is very welcome in LitGPT. I benchmarked a few different implementations when this was added and this came out to be the best in general (see description and discussion in https://github.com/Lightning-AI/litgpt/pull/823). It would be useful to see them compared to whatever you propose.
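For concreteness, a rough sketch of the sort-based dispatch suggested above. This is hypothetical and not the LitGPT implementation; the names (`gate`, `experts`, `n_expert_per_token`) only mirror `LLaMAMoE` for readability:

```python
# Hypothetical sketch: sort the flattened (token, expert) pairs by expert id so
# each expert processes a contiguous slice, avoiding torch.where(mask) entirely.
import torch

def moe_forward_sorted(x, gate, experts, n_expert_per_token):
    B, T, C = x.shape
    x = x.view(-1, C)                                         # (B*T, C)
    router = gate(x)                                          # (B*T, n_expert)
    probs, indices = torch.topk(router, n_expert_per_token)   # both (B*T, k)
    probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)

    # Flatten the (token, expert) pairs and sort by expert id.
    flat_expert = indices.flatten()                           # (B*T*k,)
    flat_token = torch.arange(x.size(0), device=x.device).repeat_interleave(n_expert_per_token)
    flat_prob = probs.flatten()
    order = torch.argsort(flat_expert)
    flat_expert, flat_token, flat_prob = flat_expert[order], flat_token[order], flat_prob[order]

    # A single host sync for the per-expert counts replaces the per-expert sync
    # that torch.where(mask) incurs inside the loop.
    counts = torch.bincount(flat_expert, minlength=len(experts)).tolist()

    y = torch.zeros_like(x)
    start = 0
    for expert, count in zip(experts, counts):
        if count:
            tok = flat_token[start:start + count]
            y.index_add_(0, tok, flat_prob[start:start + count, None] * expert(x[tok]))
        start += count
    return y.view(B, T, C)
```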
The error message is not friendly and doesn't clearly say that `torch.where(condition)` is not supported:
In [1]: import torch
In [2]: import thunder
In [3]: from litgpt import Config
In [4]: from litgpt.model import LLaMAMoE
In [5]: config = Config.from_name("Mixtral-8x7B-v0.1")
In [6]: model = LLaMAMoE(config).to(dtype=torch.bfloat16, device="cuda")
In [7]: jit_model = thunder.jit(model)
In [8]: x = torch.randn(2, config.block_size, config.n_embd, dtype=torch.bfloat16, device="cuda")
In [9]: jit_model(x);
Traceback:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[9], line 1
----> 1 jit_model(x);
File ~/dev/pytorch/main/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/dev/pytorch/main/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/dev/lightning-thunder/thunder/__init__.py:194, in ThunderModule.forward(self, *args, **kwargs)
193 def forward(self, *args, **kwargs):
--> 194 res = self._forward_fn(*args, **kwargs)
195 return res
File ~/dev/lightning-thunder/thunder/__init__.py:629, in jit.<locals>.fn_(*args, **kwargs)
626 cs.last_trace_host_start = time.time_ns()
627 cs.calls += 1
--> 629 cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
630 cs.last_trace_host_execution_start = time.time_ns()
632 result = cache_entry.computation_fn(*inps)
File ~/dev/lightning-thunder/thunder/__init__.py:262, in _with_cache_info_ctx.<locals>.cache_info_wrapper(*args, **kwargs)
260 tok = _cache_info_ctx.set({})
261 try:
--> 262 res = fn(*args, **kwargs)
263 finally:
264 _cache_info_ctx.reset(tok)
File ~/dev/lightning-thunder/thunder/__init__.py:504, in jit.<locals>.get_computation_and_inputs(*args, **kwargs)
502 prologue_trc: TraceCtx
503 computation_trc: TraceCtx
--> 504 prologue_trc, computation_trc, *maybe_epilogue = interpreter(
505 fn, args, kwargs, sharp_edges=cd.sharp_edges
506 )
508 if maybe_epilogue:
509 epilogue_traces = maybe_epilogue
File ~/dev/lightning-thunder/thunder/__init__.py:175, in _general_frontend(fn, args, kwargs, sharp_edges)
174 def _general_frontend(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> tuple[TraceCtx, TraceCtx]:
--> 175 return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1440, in thunder_general_jit(fn, args, kwargs, sharp_edges)
1438 with general_jit_ctx(ctx):
1439 with tracectx(computation_trace):
-> 1440 result = jfn(*args, **kwargs)
1441 prims.python_return(result)
1442 process_recorded_modifications(ctx, epilogue_trace)
File ~/dev/lightning-thunder/thunder/core/interpreter.py:6684, in interpret.<locals>.fn_(*args, **kwargs)
6682 assert isinstance(e, BaseException), e
6683 runtimectx.curexc = None
-> 6684 raise e
6686 return interpretation_result
File ~/dev/lightning-thunder/thunder/core/interpreter.py:6647, in interpret.<locals>.fn_.<locals>.getfn.<locals>.fn_2()
6646 def fn_2(args, kwargs):
-> 6647 return fn(*args, **kwargs)
File ~/dev/lightning-thunder/thunder/core/interpreter.py:6046, in _call_dispatch.<locals>._impl()
6045 def _impl(fn, *args, **kwargs):
-> 6046 return fn.__func__(fn.__self__, *args, **kwargs)
File ~/dev/pytorch/main/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl()
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/dev/lightning-thunder/thunder/core/interpreter.py:6046, in _call_dispatch.<locals>._impl()
6045 def _impl(fn, *args, **kwargs):
-> 6046 return fn.__func__(fn.__self__, *args, **kwargs)
File ~/dev/pytorch/main/torch/nn/modules/module.py:1520, in Module._call_impl()
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/dev/lightning-thunder/thunder/core/interpreter.py:6046, in _call_dispatch.<locals>._impl()
6045 def _impl(fn, *args, **kwargs):
-> 6046 return fn.__func__(fn.__self__, *args, **kwargs)
File ~/dev/litgpt/litgpt/model.py:347, in LLaMAMoE.forward()
345 y = torch.zeros_like(x) # (B*T, C)
346 for mask, expert in zip(masks, self.experts):
--> 347 token_idx, expert_idx = torch.where(mask)
348 y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
349 return y.view(B, T, C)
File ~/dev/lightning-thunder/thunder/core/interpreter.py:1258, in interpreter_needs_wrap.<locals>.wrapping_wrapper(*args, **kwargs)
1255 ukwargs = kwargs
1257 try:
-> 1258 res = ufn(*uargs, **ukwargs)
1260 # If result is a WrappedValue, we trust its provenance record
1261 if isinstance(res, WrappedValue):
File ~/dev/lightning-thunder/thunder/core/symbol.py:250, in Symbol.__call__(self, *args, **kwargs)
248 else:
249 trace.push_scope(subsymbols)
--> 250 result = self.meta(*args, **kwargs)
251 trace.pop_scope()
253 bsym = self.bind(*args, **kwargs, output=result, subsymbols=subsymbols)
File ~/dev/lightning-thunder/thunder/core/langctxs.py:124, in langctx.__call__.<locals>._fn(*args, **kwargs)
122 try:
123 tok = set_langctx(self.langctx)
--> 124 result = fn(*args, **kwargs)
125 return result
126 finally:
TypeError: where() missing 2 required positional arguments: 'a' and 'b'
@IvanYashchuk , looks like we should update the meta function for `where`. To be frank, I did not even know about this overload...
Might be a very nice issue for external contributors...
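For context, the single-argument overload is documented to be equivalent to `torch.nonzero(condition, as_tuple=True)`, which is why its output shape is data-dependent; a quick eager-mode check:

```python
import torch

cond = torch.tensor([[True, False], [False, True]])

# torch.where(condition) returns one index tensor per dimension; the length of
# each tensor depends on how many elements are True, so it cannot be known
# before the data is inspected.
rows, cols = torch.where(cond)
rows2, cols2 = torch.nonzero(cond, as_tuple=True)
assert torch.equal(rows, rows2) and torch.equal(cols, cols2)
print(rows, cols)  # tensor([0, 1]) tensor([0, 1])
```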
triage review:
Could the `torch.where(condition)` in Mixtral use a hypothetical `shape` parameter to `nonzero` to make the output shape known at compile time? Something like `nonzero(..., shape=...)`.
`nonzero` doesn't have a `shape=` argument. Did you mean `as_tuple=`?
We were referring to a parameter that would be analogous to `jax.numpy.nonzero`'s `size` parameter: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.nonzero.html
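For illustration, the JAX behavior being referenced: a static `size` argument fixes the output shape at trace time, padding with `fill_value` or truncating as needed.

```python
import jax.numpy as jnp

mask = jnp.array([True, False, True, True, False])

# `size` fixes the output length: missing entries are padded with `fill_value`,
# surplus true entries are dropped.
(idx,) = jnp.nonzero(mask, size=4, fill_value=0)
print(idx)  # [0 2 3 0]
```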
This prototype adds `torch.where(boolean_tensor)`: https://github.com/Lightning-AI/lightning-thunder/pull/303
@IvanYashchuk @kshitij12345 Is this still unsupported when using ThunderFX? If `torch.where(condition)` is supported when using ThunderFX (because the operator is sent to PyTorch for execution?), then maybe we can close or amend this issue to refer more specifically to using `torch.where(condition)` with the Thunder interpreter as the entrypoint?
`torch.where(condition)` works with the ThunderFX path by sending it to PyTorch. We also have a test for the same.
Will update the issue title to reflect that `torch.where(condition)` is not supported by the `thunder.jit` entrypoint.
🚀 Feature
Motivation
Mixtral uses it: https://github.com/mistralai/mistral-src/blob/8598cf582091a596671be31990448e0620017851/moe_one_file_ref.py#L215
Minimal Repro
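A minimal sketch that hits the unsupported overload under `thunder.jit` (a small stand-in for the full Mixtral example above):

```python
import torch
import thunder

def f(mask):
    # Single-argument overload: output length depends on the data.
    return torch.where(mask)

jf = thunder.jit(f)
jf(torch.tensor([True, False, True]))
# TypeError: where() missing 2 required positional arguments: 'a' and 'b'
```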
Pitch
Support the `torch.where(condition)` overload from https://pytorch.org/docs/stable/generated/torch.where.html
Additional context
We already support `torch.where(condition, input, other)`: https://github.com/search?q=repo%3ALightning-AI%2Flightning-thunder+%22def+where%22&type=code
cc @apaz-cli
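The three-argument overload traces fine because its output shape follows broadcasting rather than the data; a quick check (assuming a CPU run is representative):

```python
import torch
import thunder

def g(cond, a, b):
    return torch.where(cond, a, b)  # output shape follows broadcasting, not the data

jg = thunder.jit(g)
print(jg(torch.tensor([True, False]), torch.ones(2), torch.zeros(2)))  # tensor([1., 0.])
```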