microsoft / Megatron-DeepSpeed

Ongoing research training transformer language models at scale, including: BERT & GPT-2
Other
1.9k stars, 345 forks

[XPU] Support fused_rms_norm on XPU device #431

Closed: ys950902 closed this pull request 3 months ago

ys950902 commented 4 months ago

On CUDA, fused_rms_norm is enabled by default; this PR adds fused_rms_norm support on XPU devices.
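For context, RMSNorm rescales each hidden vector by the reciprocal of its root mean square and then multiplies it elementwise by a learned weight. The snippet below is a minimal, unfused pure-Python reference of that computation, not the Megatron-DeepSpeed or IPEX kernel; the function name rms_norm and the eps default of 1e-6 are assumptions for illustration. A fused kernel is expected to produce the same result while doing the reduction and scaling inside a single device kernel.

```python
import math
from typing import List, Sequence


def rms_norm(x: Sequence[float], weight: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Hypothetical unfused RMSNorm reference over one hidden vector.

    Computes y_i = x_i / sqrt(mean(x^2) + eps) * weight_i.
    """
    if len(x) != len(weight):
        raise ValueError("x and weight must have the same length")
    # Mean of squares over the hidden dimension (no mean subtraction, unlike LayerNorm).
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # Scale each element by 1/RMS, then apply the learned per-element weight.
    return [v * inv_rms * w for v, w in zip(x, weight)]
```

With unit weights and eps=0.0, the output has a mean square of 1, which is a quick way to sanity-check a fused implementation against this reference.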

rogerxfeng8 commented 3 months ago

@tjruwase, could you help review?

polisettyvarma commented 3 months ago

@rogerxfeng8 @tjruwase, shouldn't the ipex import be inside a try/except block?

tjruwase commented 3 months ago

@polisettyvarma, good catch. I think the ipex import could also be placed inside an if get_accelerator().device_name() == 'xpu' block, similar to the CUDA case.

@rogerxfeng8, can you please help fix this? Thanks!
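The two suggestions in this thread (a try/except guard and a device check) can be combined. The sketch below is a hypothetical helper, not the code merged in the follow-up PR: it assumes the IPEX package is importable as intel_extension_for_pytorch and takes the device name as a parameter. In the repository, that value would presumably come from get_accelerator().device_name(); a None result signals that the caller should fall back to the unfused RMSNorm path.

```python
import importlib
from types import ModuleType
from typing import Optional


def maybe_import_ipex(device_name: str) -> Optional[ModuleType]:
    """Hypothetical helper: import IPEX only when running on an XPU device.

    Returns the imported module, or None if the device is not XPU or the
    package is not installed.
    """
    if device_name != "xpu":
        # Non-XPU devices (e.g. 'cuda') never touch IPEX.
        return None
    try:
        return importlib.import_module("intel_extension_for_pytorch")
    except ImportError:
        # IPEX is missing: degrade gracefully instead of failing at import time.
        return None
```

Passing the device name in, rather than calling get_accelerator() inside the helper, keeps it testable without DeepSpeed installed; the guarded import also means CUDA-only environments never need IPEX.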

ys950902 commented 3 months ago


Thanks for your suggestions; I have made the change in https://github.com/microsoft/Megatron-DeepSpeed/pull/436.