Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It lets you use different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0
1.15k stars 77 forks source link

thunder.torch.randn should accept any Sequence of integers as a shape argument #604

Closed · IvanYashchuk closed this issue 3 weeks ago

IvanYashchuk commented 3 months ago

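🐛 Bug

Inside a `thunder.jit`-compiled function, `torch.randn` fails when its shape is passed as a list instead of a tuple, even though eager PyTorch accepts a list here. The reproduction below raises `ValueError: [1, 4, 40, 84, 84] had an unexpected type <class 'list'>. Supported types are <class 'tuple'>`.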

import thunder
import torch

@thunder.jit
def uniform(shape, device):
    """Create a random tensor of the given ``shape`` on ``device`` under thunder.jit.

    Despite its name, this samples via ``torch.randn`` (standard normal
    distribution), not a uniform distribution.

    This is the minimal reproduction for the issue: ``shape`` is passed as a
    single positional argument, so when it is a ``list`` the list reaches
    ``prims.randn``'s meta function unchanged, and that function only accepts
    a ``tuple`` and raises ``ValueError``.
    """
    # Keep the call form exactly as-is: thunder traces this bytecode, and
    # passing the shape as one sequence argument is what triggers the bug.
    return torch.randn(shape, device=device)

# A plain Python list as the shape. Eager PyTorch accepts this, but under
# thunder.jit it raises:
#   ValueError: [1, 4, 40, 84, 84] had an unexpected type <class 'list'>.
#   Supported types are <class 'tuple'>
shape = [1, 4, 40, 84, 84]
# The failing check is on the *type of the shape* (prims._randn_meta), not on
# the device, so fall back to CPU when no GPU is present to keep the repro
# runnable on any machine.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
t1 = uniform(shape, device=device)
print(t1)

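Traceback (the shape check in `prims.randn`'s meta function, `_randn_meta`, accepts only a `tuple`, so the list forwarded by `thunder.torch.randn` is rejected):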

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 13
     10     return torch.randn(shape, device=device)
     12 shape = [1, 4, 40, 84, 84]
---> 13 t1 = uniform(shape, device="cuda:0")
     14 print(t1)

File ~/dev/lightning-thunder/thunder/__init__.py:660, in jit.<locals>.fn_(*args, **kwargs)
    657 cs.last_trace_host_start = time.time_ns()
    658 cs.calls += 1
--> 660 cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
    661 cs.last_trace_host_execution_start = time.time_ns()
    663 result = cache_entry.computation_fn(*inps)

File ~/dev/lightning-thunder/thunder/__init__.py:217, in _with_cache_info_ctx.<locals>.cache_info_wrapper(*args, **kwargs)
    215 tok = _cache_info_ctx.set({})
    216 try:
--> 217     res = fn(*args, **kwargs)
    218 finally:
    219     _cache_info_ctx.reset(tok)

File ~/dev/lightning-thunder/thunder/__init__.py:496, in jit.<locals>.get_computation_and_inputs(*args, **kwargs)
    494 prologue_trc: TraceCtx
    495 computation_trc: TraceCtx
--> 496 jit_results: TraceResults = interpreter(
    497     fn, args, kwargs, record_history=record_history, sharp_edges=cd.sharp_edges
    498 )
    499 prologue_trc = jit_results.prologue_trace
    500 computation_trc = jit_results.computation_trace

File ~/dev/lightning-thunder/thunder/__init__.py:205, in _general_frontend(fn, args, kwargs, record_history, sharp_edges)
    196 def _general_frontend(
    197     fn: Callable,
    198     args: tuple[Any, ...],
   (...)
    203     sharp_edges: SHARP_EDGES_OPTIONS,
    204 ) -> TraceResults:
--> 205     return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1605, in thunder_general_jit(fn, args, kwargs, record_history, sharp_edges)
   1603 with general_jit_ctx(ctx):
   1604     with tracectx(computation_trace):
-> 1605         result = jfn(*args, **kwargs)
   1606         prims.python_return(result)
   1607         computation_trace.set_current_source_location(None, None)

File ~/dev/lightning-thunder/thunder/core/interpreter.py:6696, in interpret.<locals>.fn_(*args, **kwargs)
   6694     assert isinstance(e, BaseException), e
   6695     runtimectx.curexc = None
-> 6696     raise e
   6698 return interpretation_result

File ~/dev/lightning-thunder/thunder/core/interpreter.py:6664, in interpret.<locals>.fn_.<locals>.getfn.<locals>.fn_2()
   6663 def fn_2(args, kwargs):
-> 6664     return fn(*args, **kwargs)

Cell In[1], line 10, in uniform()
      8 @thunder.jit
      9 def uniform(shape, device):
---> 10     return torch.randn(shape, device=device)

File ~/dev/lightning-thunder/thunder/core/interpreter.py:1273, in interpreter_needs_wrap.<locals>.wrapping_wrapper(*args, **kwargs)
   1270     ukwargs = kwargs
   1272 try:
-> 1273     res = ufn(*uargs, **ukwargs)
   1275     # If result is a WrappedValue, we trust its provenance record
   1276     if isinstance(res, WrappedValue):

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:700, in record_source_loc_in_symbol_header.<locals>.wrapper(*args, **kwargs)
    698 ctx: GeneralJitCtx = get_general_jit_ctx()
    699 ctx._computation_trace.set_current_source_location(filename, positions)
--> 700 return fn(*args, **kwargs)

File ~/dev/lightning-thunder/thunder/core/symbol.py:268, in Symbol.__call__(self, *args, **kwargs)
    266 else:
    267     trace.push_scope(subsymbols)
--> 268     result = self.meta(*args, **kwargs)
    269     trace.pop_scope()
    271 bsym = self.bind(*args, **kwargs, output=result, subsymbols=subsymbols)

File ~/dev/lightning-thunder/thunder/core/langctxs.py:132, in langctx.__call__.<locals>._fn(*args, **kwargs)
    130 try:
    131     tok = set_langctx(self.langctx)
--> 132     result = fn(*args, **kwargs)
    133     return result
    134 finally:

File ~/dev/lightning-thunder/thunder/torch/__init__.py:654, in randn(generator, dtype, device, layout, requires_grad, pin_memory, out, *shape)
    652 dtype = to_dtype(dtype)
    653 shape = utils.extract_shape_from_varargs(shape)
--> 654 return prims.randn(shape, device=device, dtype=dtype)

File ~/dev/lightning-thunder/thunder/core/symbol.py:264, in Symbol.__call__(self, *args, **kwargs)
    261         return self.meta(*args, **kwargs)
    263     trace.push_scope(None)  # BUG: This is wrong, push_scope only accepts lists. What should this be instead?
--> 264     result = self.meta(*args, **kwargs)
    265     trace.pop_scope()
    266 else:

File ~/dev/lightning-thunder/thunder/core/langctxs.py:132, in langctx.__call__.<locals>._fn(*args, **kwargs)
    130 try:
    131     tok = set_langctx(self.langctx)
--> 132     result = fn(*args, **kwargs)
    133     return result
    134 finally:

File ~/dev/lightning-thunder/thunder/core/prims.py:2773, in _randn_meta(shape, device, dtype)
   2771 utils.check_type(device, devices.Device)
   2772 utils.check_type(dtype, dtypes.dtype)
-> 2773 utils.check_type(shape, tuple)
   2774 utils.check_valid_shape(shape)
   2775 return TensorProxy(shape=shape, device=device, dtype=dtype, requires_grad=False)

File ~/dev/lightning-thunder/thunder/core/baseutils.py:107, in check_type(x, types)
    106 def check_type(x: Any, types: type | Sequence[type]):
--> 107     check(
    108         isinstance(x, types),
    109         lambda: f"{x} had an unexpected type {type(x)}. Supported types are {types}",
    110         exception_type=ValueError,
    111     )

File ~/dev/lightning-thunder/thunder/core/baseutils.py:103, in check(cond, s, exception_type)
     98 """Helper function for raising an error_type (default: RuntimeError) if a boolean condition fails.
     99 
    100 s is a callable producing a string to avoid string construction if the error check is passed.
    101 """
    102 if not cond:
--> 103     raise exception_type(s())

ValueError: [1, 4, 40, 84, 84] had an unexpected type <class 'list'>. Supported types are <class 'tuple'>

cc @apaz-cli