microsoft / Megatron-DeepSpeed

Ongoing research training transformer language models at scale, including: BERT & GPT-2

[bug]: `ipex` install breaks non `xpu` devices #435

Open saforem2 opened 3 months ago

saforem2 commented 3 months ago

It looks like this line:

https://github.com/microsoft/Megatron-DeepSpeed/blob/61350c55478fba29ecf40940a629a3e7ce008a05/megatron/model/__init__.py#L4

from #431 breaks things on non-Intel systems.

A simple (not yet tested) fix for a device-agnostic approach could be something like:

```python
import torch

try:
    # Importing ipex registers the "xpu" backend with torch as a side effect;
    # on non-Intel systems the import fails and is silently skipped.
    import intel_extension_for_pytorch as ipex  # noqa: F401
except Exception:
    pass

# Guard the xpu / mps attribute lookups so builds of torch without those
# backends don't raise AttributeError (the crash this issue is about).
DEVICE = (
    "cuda" if torch.cuda.is_available() else (
        "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else (
            "mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else (
                "cpu"
            )
        )
    )
)

if DEVICE == "cuda":
    from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
    from apex.normalization import MixedFusedRMSNorm as RMSNorm
else:
    if DEVICE == "xpu" and hasattr(torch.xpu, "IpexRmsNorm"):
        from .fused_rmsnorm import RMSNorm
    else:
        from .rmsnorm import RMSNorm
    from torch.nn import LayerNorm
# ...
```

which I believe should work
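For what it's worth, the same selection logic can be packaged as a standalone helper, which keeps the backend probing in one place and is easy to test on any machine. This is just a sketch of the idea above; the name `get_backend` is hypothetical, not anything in Megatron-DeepSpeed:

```python
import torch

try:
    # ipex registers the "xpu" backend with torch when imported;
    # on non-Intel systems this import simply fails and is skipped.
    import intel_extension_for_pytorch as ipex  # noqa: F401
except Exception:
    pass


def get_backend() -> str:
    """Return the best available accelerator backend name.

    Hypothetical helper illustrating the device-agnostic selection
    proposed above; falls back to "cpu" when no accelerator is found.
    Every attribute access is guarded so torch builds without the
    xpu or mps backends don't raise AttributeError.
    """
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```

The module-level import code could then do `DEVICE = get_backend()` and branch on the result, which avoids repeating the guarded probes at every import site.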

vadam5 commented 3 months ago

I've run into the same issue recently. Can someone please fix this bug?

clintg6 commented 2 months ago

+1