microsoft / microxcaling

PyTorch emulation library for Microscaling (MX)-compatible data formats
MIT License

Example FFN hangs in multi-GPU scenario (works on single GPU) #10

Status: closed (lessw2020 closed this issue 7 months ago)

lessw2020 commented 7 months ago

This appears to share the same root cause as the hanging unit tests, but for simplicity I am opening a separate issue.

Repro steps:
1. In examples/ffn_mx.py, add sys.path.append('..') so that the mx module can be imported.
2. From the examples directory, run: bash run_mx6.sh

On a single GPU, the script finishes and prints "Done!". On a multi-GPU machine, it hangs.
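The sys.path.append('..') tweak from the repro steps resolves '..' against the current working directory at import time, so it only finds mx when the script is launched from inside examples/. A minimal sketch of a cwd-independent variant follows; add_repo_root_to_path is a hypothetical helper, not part of microxcaling, and it assumes the layout shown in the trace (mx/ sits next to examples/ in the repository root).

```python
import os
import sys


def add_repo_root_to_path(script_path):
    """Hypothetical helper: put the repository root (the parent of the
    script's directory) on sys.path so `import mx` works regardless of
    the directory the script is launched from."""
    script_dir = os.path.dirname(os.path.abspath(script_path))
    repo_root = os.path.dirname(script_dir)
    if repo_root not in sys.path:  # avoid duplicate entries on re-import
        sys.path.append(repo_root)
    return repo_root


# Assumed usage at the top of examples/ffn_mx.py, before importing mx:
#     add_repo_root_to_path(__file__)
```

Anchoring on __file__ rather than the working directory means the example behaves the same whether it is started by run_mx6.sh or from elsewhere.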

Stack trace from hang:

(pytorch) ubuntu@ip-172-31-66-198:~/microxcaling/examples$ bash run_mx6.sh --full-trace
^CTraceback (most recent call last):
  File "/home/ubuntu/microxcaling/examples/ffn_mx.py", line 74, in <module>
    y = mlp(x)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/microxcaling/examples/ffn_mx.py", line 39, in forward
    norm_outputs = self.layernorm(inputs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/microxcaling/examples/../mx/layernorm.py", line 91, in forward
    return LayerNormFunction.apply(
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/autograd/function.py", line 551, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/ubuntu/microxcaling/examples/../mx/layernorm.py", line 22, in forward
    x = vec_quantize(x, mx_specs=mx_specs)
  File "/home/ubuntu/microxcaling/examples/../mx/vector_ops.py", line 35, in vec_quantize
    return quantize_elemwise_op(input, mx_specs=mx_specs,
  File "/home/ubuntu/microxcaling/examples/../mx/elemwise_ops.py", line 253, in quantize_elemwise_op
    A = _quantize_bfloat(A, bfloat=mx_specs['bfloat'], round=round,
  File "/home/ubuntu/microxcaling/examples/../mx/elemwise_ops.py", line 206, in _quantize_bfloat
    return _quantize_elemwise_core(
  File "/home/ubuntu/microxcaling/examples/../mx/elemwise_ops.py", line 118, in _quantize_elemwise_core
    from . import custom_extensions
  File "/home/ubuntu/microxcaling/examples/../mx/custom_extensions.py", line 19, in <module>
    funcs = load(name="custom_extensions", sources=sources)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/utils/cpp_extension.py", line 1308, in load
    return _jit_compile(
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/utils/cpp_extension.py", line 1724, in _jit_compile
    baton.wait()
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/utils/file_baton.py", line 42, in wait
    time.sleep(self.wait_seconds)
KeyboardInterrupt

Note that I waited over 5 minutes in case this was just a long compile, and reproduced the hang multiple times.
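The trace ends inside torch.utils.cpp_extension.load: _jit_compile calls baton.wait(), which sleeps in a loop. This is the waiting side of a lock-file handshake that lets only one process JIT-compile a given extension while the others poll until the lock file disappears. Below is a minimal stdlib sketch of that pattern, assuming it mirrors torch.utils.file_baton.FileBaton (atomic creation with O_CREAT | O_EXCL, then polling until the file is removed). LockFileBaton, find_lock_files and demo_stale_lock are hypothetical names, and the timeout argument exists only for the demo. It also assumes PyTorch's usual layout of a file named lock in each extension's build directory, under ~/.cache/torch_extensions unless TORCH_EXTENSIONS_DIR is set. The sketch shows why a lock file that is never removed, for example one left by an interrupted build, keeps waiters sleeping indefinitely. Whether that is the cause here is not established in this thread, which was closed as a duplicate of #8.

```python
import os
import tempfile
import time


class LockFileBaton:
    """Hypothetical sketch of a lock-file baton (modeled on, not copied
    from, torch.utils.file_baton.FileBaton)."""

    def __init__(self, lock_file_path, wait_seconds=0.1):
        self.lock_file_path = lock_file_path
        self.wait_seconds = wait_seconds
        self.fd = None

    def try_acquire(self):
        # O_CREAT | O_EXCL makes creation atomic: exactly one process wins.
        try:
            self.fd = os.open(self.lock_file_path, os.O_CREAT | os.O_EXCL)
            return True
        except FileExistsError:
            return False

    def wait(self, timeout=None):
        # Poll until the winner deletes the lock file. With timeout=None a
        # lock that is never removed makes this loop forever, which shows
        # up as a stack trace ending in time.sleep().
        start = time.monotonic()
        while os.path.exists(self.lock_file_path):
            if timeout is not None and time.monotonic() - start > timeout:
                return False
            time.sleep(self.wait_seconds)
        return True

    def release(self):
        # Close our handle and delete the lock so waiters can proceed.
        if self.fd is not None:
            os.close(self.fd)
            self.fd = None
        os.remove(self.lock_file_path)


def find_lock_files(root):
    """List files named 'lock' under root (assumed build-dir layout)."""
    hits = []
    for dirpath, _dirnames, filenames in os.walk(root):
        if "lock" in filenames:
            hits.append(os.path.join(dirpath, "lock"))
    return sorted(hits)


def demo_stale_lock(timeout=0.2):
    """Builder takes the lock; waiter is stuck until the lock is removed."""
    with tempfile.TemporaryDirectory() as build_dir:
        lock_path = os.path.join(build_dir, "lock")
        builder = LockFileBaton(lock_path)
        waiter = LockFileBaton(lock_path)
        assert builder.try_acquire()          # first process wins the race
        assert not waiter.try_acquire()       # second process must wait
        found = find_lock_files(build_dir)    # lock is visible on disk
        stuck = waiter.wait(timeout=timeout)  # False: lock still present
        builder.release()                     # build done, lock deleted
        done = waiter.wait(timeout=timeout)   # True: waiter can proceed
        return len(found), stuck, done


if __name__ == "__main__":
    print(demo_stale_lock())  # (1, False, True)
```

If a leftover lock file does turn out to be the culprit, removing it while no other build is running typically lets load() proceed. find_lock_files pointed at the extensions cache directory is one way to check for such a file.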

rizhao-msft commented 7 months ago

Closing as a duplicate of #8.