shawntan / scattermoe

Triton-based implementation of Sparse Mixture of Experts.
Apache License 2.0
150 stars · 10 forks · source link

Can't use torch.compile #12

Open shikhartuli opened 2 weeks ago

shikhartuli commented 2 weeks ago

When I compile the model with `torch.compile` and then call it (`model(input_ids=tokens)`), I get the following error. Any idea how to fix this?

Traceback (most recent call last):
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_18849/856931662.py", line 14, in <module>
    model(input_ids=tokens)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/group-volume/users/shikhar.tuli/hydra/transformers/src/transformers/models/hydra/modeling_hydra.py", line 1905, in forward
    outputs = self.model(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/group-volume/users/shikhar.tuli/hydra/transformers/src/transformers/models/hydra/modeling_hydra.py", line 1693, in forward
    logger.warning_once(
  File "/group-volume/users/shikhar.tuli/hydra/transformers/src/transformers/models/hydra/modeling_hydra.py", line 1703, in resume_in_forward
    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
  File "/group-volume/users/shikhar.tuli/hydra/transformers/src/transformers/models/hydra/modeling_hydra.py", line 1726, in resume_in_forward
    layer_outputs = decoder_layer(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/group-volume/users/shikhar.tuli/hydra/transformers/src/transformers/models/hydra/modeling_hydra.py", line 1392, in forward
    hidden_states, router_logits = self.hydra(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/group-volume/users/shikhar.tuli/hydra/transformers/src/transformers/models/hydra/modeling_hydra.py", line 1221, in forward
    return self.cuda_kernels_forward(hidden_states, cache_params)
  File "/group-volume/users/shikhar.tuli/hydra/transformers/src/transformers/models/hydra/modeling_hydra.py", line 1025, in cuda_kernels_forward
    projected_states, _, in_proj_router_logits = self.in_proj(hidden_states)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/group-volume/users/shikhar.tuli/hydra/transformers/src/transformers/models/hydra/modeling_hydra.py", line 892, in forward
    final_hidden_states = self.experts(hidden_states, routing_weights, selected_experts).to(hidden_states.dtype)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/group-volume/users/shikhar.tuli/hydra/scattermoe/scattermoe/mlp.py", line 124, in forward
    padded_block_idxs, expert_offsets = kernels.ops.padded_block_indices(sorted_expert_idxs, self.num_experts)
  File "/group-volume/users/shikhar.tuli/hydra/scattermoe/scattermoe/mlp.py", line 126, in resume_in_forward
    h = self.experts(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/group-volume/users/shikhar.tuli/hydra/scattermoe/scattermoe/parallel_experts.py", line 143, in forward
    results = ParallelLinear.apply(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/group-volume/users/shikhar.tuli/hydra/scattermoe/scattermoe/parallel_experts.py", line 14, in forward
    output = kernels.ops.scatter2scatter(
  File "/group-volume/users/shikhar.tuli/hydra/scattermoe/scattermoe/kernels/ops.py", line 139, in scatter2scatter
    with torch.cuda.device(X.device):
  File "/group-volume/users/shikhar.tuli/hydra/scattermoe/scattermoe/kernels/ops.py", line 140, in resume_in_scatter2scatter
    _scatter2scatter[grid](
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 127, in run
    self.nargs = dict(zip(self.arg_names, args))
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 127, in resume_in_run
    self.nargs = dict(zip(self.arg_names, args))
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 127, in resume_in_run
    self.nargs = dict(zip(self.arg_names, args))
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 128, in resume_in_run
    if len(self.configs) > 1:
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 151, in resume_in_run
    config = self.configs[0]
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 152, in resume_in_run
    self.best_config = config
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 156, in resume_in_run
    ret = self.fn.run(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 727, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 665, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2128, in run
    super().run()
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 660, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 775, in call_method
    return self.clone(grid=grid).call_function(tx, args, kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 743, in call_function
    "kwargs": meta.as_proxy(),
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/variables/dicts.py", line 33, in as_proxy
    return {k: v.as_proxy() for k, v in self.items.items()}
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/variables/dicts.py", line 33, in <dictcomp>
    return {k: v.as_proxy() for k, v in self.items.items()}
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 90, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 274, in as_proxy
    raise NotImplementedError(str(self))
torch._dynamo.exc.InternalTorchDynamoError: UserDefinedObjectVariable(dtype)

from user code:
   File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 305, in run
    return self.fn.run(*args, **kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 2168, in showtraceback
    stb = self.InteractiveTB.structured_traceback(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1454, in structured_traceback
    return FormattedTB.structured_traceback(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1345, in structured_traceback
    return VerboseTB.structured_traceback(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1192, in structured_traceback
    formatted_exception = self.format_exception_as_a_whole(etype, evalue, etb, number_of_lines_of_context,
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1082, in format_exception_as_a_whole
    self.get_records(etb, number_of_lines_of_context, tb_offset) if etb else []
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1179, in get_records
    res = list(stack_data.FrameInfo.stack_data(etb, options=options))[tb_offset:]
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/stack_data/core.py", line 597, in stack_data
    yield from collapse_repeated(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/stack_data/utils.py", line 77, in collapse_repeated
    for is_highlighted, group in itertools.groupby(
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/site-packages/stack_data/utils.py", line 45, in highlight_unique
    counts = Counter(lst)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/collections/__init__.py", line 577, in __init__
    self.update(iterable, **kwds)
  File "/home/user/anaconda3/envs/hydra/lib/python3.10/collections/__init__.py", line 670, in update
    _count_elements(self, iterable)
TypeError: unhashable type: 'dict'