huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Error in all_reduce when GPT2 200B inferencing with dynamo and multi GPU #27991

Closed: jcai04 closed this issue 9 months ago

jcai04 commented 11 months ago

System Info

version:

Who can help?

@SunMarc

Reproduction

Code snippet:

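import os

import torch
from torch import nn

# GPT2Attention and GPT2MLP appear to be custom modules defined elsewhere in the
# reporter's model file (not shown); their signatures differ from the transformers classes.

class GPT2Block(nn.Module):
    # Parallel attention/MLP block split across model-parallel ranks.

    def __init__(self, config, window_size):
        super().__init__()
        # Number of model-parallel ranks, as set by the launcher (e.g. torchrun).
        self.mp_size = int(os.getenv("WORLD_SIZE", "1"))
        self.hidden_size = config.hidden_size
        self.ln_1 = nn.LayerNorm(self.hidden_size, eps=1e-5)
        self.attn = GPT2Attention(config, window_size)
        self.mlp = GPT2MLP(config)

    def forward(self, hidden_states, attention_mask, past_kv, kv_len, wpe=None):
        residual = hidden_states

        # Attention and MLP both read the same normalized input.
        hidden_states = self.ln_1(hidden_states)
        attn_output, _ = self.attn(hidden_states, attention_mask, past_kv, kv_len, wpe)
        mlp_output = self.mlp(hidden_states)

        # Each rank holds only a partial result at this point.
        layer_out = attn_output + mlp_output

        if self.mp_size > 1:
            # In-place SUM of the partial results across all ranks.
            # This is the call dynamo rejects under fullgraph=True.
            torch.distributed.all_reduce(layer_out)

        layer_out = layer_out + residual
        return layer_out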
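A single all_reduce after attn_output + mlp_output is enough because a cross-rank sum is linear: reducing the sum of two partial results gives the same answer as reducing each one separately. The toy, pure-Python simulation below illustrates this, assuming a Megatron-style split in which each rank holds one slice of the hidden (contraction) dimension. The helper names, weights and inputs are hypothetical and are not taken from the reporter's model.

```python
# Toy simulation of a tensor-parallel block (hypothetical helpers, pure Python).

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, b):
    """Elementwise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def all_reduce_sum(partials):
    """Mimic all_reduce with op=SUM: every rank ends up with the full sum."""
    total = partials[0]
    for p in partials[1:]:
        total = add(total, p)
    return [[row[:] for row in total] for _ in partials]  # one copy per rank


def tensor_parallel_block(x, w_attn, w_mlp, world_size):
    """Each rank multiplies its slice of the hidden dim, then one all-reduce."""
    k = len(w_attn)
    assert k % world_size == 0, "hidden size must divide evenly across ranks"
    step = k // world_size
    partials = []
    for rank in range(world_size):
        lo, hi = rank * step, (rank + 1) * step
        x_r = [row[lo:hi] for row in x]  # this rank's activation slice
        # Sum attention and MLP partials locally, like attn_output + mlp_output.
        partials.append(add(matmul(x_r, w_attn[lo:hi]), matmul(x_r, w_mlp[lo:hi])))
    return all_reduce_sum(partials)


x = [[1, 2, 3, 4], [5, 6, 7, 8]]
w_attn = [[1, 0], [0, 1], [1, 1], [2, -1]]
w_mlp = [[0, 1], [1, 0], [-1, 2], [1, 1]]

full = add(matmul(x, w_attn), matmul(x, w_mlp))  # unsharded reference
reduced = tensor_parallel_block(x, w_attn, w_mlp, world_size=2)
print(full)  # [[15, 12], [35, 32]]
print(all(r == full for r in reduced))  # True
```

Reducing once after the local addition halves the number of collectives compared with reducing attention and MLP outputs separately, which is presumably why the block is written this way.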

Error messages:

Traceback (most recent call last):
  File "./ut_test/seperate_200b.py", line 393, in <module>
    out_v2 = inference_engine(inputs)
  File "./ut_test/seperate_200b.py", line 250, in inference_engine
    context_output = context_infer(BI_model, curr_input)
  File "./ut_test/seperate_200b.py", line 199, in context_infer
    outputs = model(**one_input)
  File "/home/gpt2_200b/models/gpt2_200b_ptb.py", line 799, in forward
    hidden_states = self.transformer(input_tensor, input_mask, past_key, past_key_values, kv_len, query_len)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 458, in transform
    tracer.run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
    super().run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/nn_module.py", line 331, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1155, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/functions.py", line 307, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1155, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/functions.py", line 307, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2232, in inline_call
    InliningInstructionTranslator.check_inlineable(func)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 2191, in check_inlineable
    unimplemented(f"inlining disallowed: {func.get_function()}")
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/exc.py", line 172, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: inlining disallowed: <function all_reduce at 0x7fab7fff78b0>

from user code:
  File "/home/gpt2_200b/models/gpt2_200b_ptb.py", line 508, in forward
    hidden_states = block(hidden_states, attention_mask, past_kv[idx], kv_len, self.wpe)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/gpt2_200b/models/gpt2_200b_ptb.py", line 460, in forward
    torch.distributed.all_reduce(layer_out)

torch.compile setting:

torch.compile(self.transformer, dynamic=True, fullgraph=True)  # default backend: inductor

Expected behavior

We expect to be able to run inference with dynamo. Inference succeeds when we set "fullgraph=False" in torch.compile, but the same code fails when we set "fullgraph=True".
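The difference between the two settings can be illustrated with a deliberately simplified model of graph breaks. With fullgraph=False, dynamo compiles the traceable code around an unsupported call and runs that call eagerly; with fullgraph=True, any graph break is an error, which is what the Unsupported exception in the traceback reports. The sketch below is a toy: the partition function, the op names and the local Unsupported class are hypothetical, and real dynamo works on Python bytecode and nested frames rather than a flat list of ops.

```python
class Unsupported(Exception):
    """Toy stand-in for torch._dynamo.exc.Unsupported (not the real class)."""


def partition(ops, unsupported, fullgraph):
    """Split a flat list of op names into compiled graphs and eager fallbacks.

    Consecutive traceable ops are grouped into one graph. An untraceable op
    forces a graph break and runs eagerly, unless fullgraph=True, in which
    case the break is reported as an error instead.
    """
    plan = []
    current = []
    for op in ops:
        if op in unsupported:
            if fullgraph:
                raise Unsupported(f"inlining disallowed: {op}")
            if current:
                plan.append(("graph", current))
                current = []
            plan.append(("eager", [op]))
        else:
            current.append(op)
    if current:
        plan.append(("graph", current))
    return plan


# Hypothetical op sequence mirroring GPT2Block.forward.
block_ops = ["ln_1", "attn", "mlp", "add", "all_reduce", "add_residual"]

print(partition(block_ops, {"all_reduce"}, fullgraph=False))
# [('graph', ['ln_1', 'attn', 'mlp', 'add']), ('eager', ['all_reduce']), ('graph', ['add_residual'])]

try:
    partition(block_ops, {"all_reduce"}, fullgraph=True)
except Unsupported as exc:
    print(exc)  # inlining disallowed: all_reduce
```

In this model, fullgraph=False "works" only because the collective is split out of the compiled graph, so each block produces extra graph boundaries; fullgraph=True surfaces the same break as a hard error.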

SunMarc commented 11 months ago

Hi @jcai04, this does not seem to be an issue with transformers. Please open this issue in the PyTorch repo.

github-actions[bot] commented 10 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.