intel / intel-extension-for-pytorch

A Python package for extending the official PyTorch that can easily obtain performance boost on Intel platform
Apache License 2.0

Questions about the eager mode #676

Closed: zhangwm-pt closed this issue 1 week ago

zhangwm-pt commented 1 month ago

Describe the issue

According to the official description of Intel Extension for PyTorch (IPEX), it supports both Eager mode and Graph mode. As I understand it, in Eager mode the operators in PyTorch are replaced with IPEX's ATen operator implementations, so we only need to register a new operator implementation in IPEX to get acceleration. However, the current tutorials seem to show Eager mode doing only some frontend graph-level optimizations, and I haven't seen how the underlying implementation of an operator is replaced. Or do I have to write this optimized implementation into PyTorch's ATen operators rather than in IPEX? Can someone help clarify this? It would help me add a more optimized implementation for a specific operator in IPEX.

jgong5 commented 1 month ago

Adding operators to IPEX can leverage the same custom op registration mechanism from PyTorch. There are a couple of similar examples in IPEX. You may refer to things like this:

https://github.com/intel/intel-extension-for-pytorch/blob/5cc852eb850f8b42ac2d0fcca5709e6076e8042d/csrc/cpu/aten/PagedAttention.cpp#L56

And more examples are available under https://github.com/intel/intel-extension-for-pytorch/tree/main/csrc/cpu/aten

May I know what particular ops you plan to support?
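
For illustration, here is a minimal Python-side sketch of that registration mechanism using `torch.library` (the linked examples do the same thing in C++ via `TORCH_LIBRARY`/`TORCH_LIBRARY_IMPL`); the `demo_ext` namespace and the `scale_add` op are hypothetical, made up just for this sketch:

```python
import torch

# Hypothetical sketch: define a new custom op and give it a CPU kernel.
# This is the Python counterpart of the C++ registration in the examples above.
torch.library.define("demo_ext::scale_add",
                     "(Tensor a, Tensor b, float alpha) -> Tensor")

@torch.library.impl("demo_ext::scale_add", "cpu")
def scale_add_cpu(a, b, alpha):
    # Stand-in kernel; a real IPEX op would dispatch to an optimized C++ kernel.
    return a + alpha * b

x, y = torch.randn(3), torch.randn(3)
out = torch.ops.demo_ext.scale_add(x, y, 0.5)  # callable like any ATen op
```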

zhangwm-pt commented 1 month ago

> Adding operators to IPEX can leverage the same custom op registration mechanism from PyTorch. There are a couple of similar examples in IPEX. You may refer to things like this:
>
> https://github.com/intel/intel-extension-for-pytorch/blob/5cc852eb850f8b42ac2d0fcca5709e6076e8042d/csrc/cpu/aten/PagedAttention.cpp#L56
>
> And more examples are available under https://github.com/intel/intel-extension-for-pytorch/tree/main/csrc/cpu/aten
>
> May I know what particular ops you plan to support?

Thank you for your reply. Indeed, I can add new custom operators to IPEX under the csrc/cpu/aten directory. However, my actual question is: if I execute a model in IPEX Eager mode, then, as I understand it, the implementation of certain PyTorch operators should be replaced with new accelerated implementations from IPEX. But looking at the IPEX Eager mode code, I haven't found where IPEX replaces the native implementation of PyTorch's ops with its own.

```python
import torch
import intel_extension_for_pytorch as ipex

model = ...  # some eval-mode model

# eager mode
model = ipex.optimize(model)

with torch.no_grad():
    model(data)  # data: a sample input
```
ipex.optimize only optimizes the graph (fusion, layout, etc.); I didn't see where the operator implementations get replaced.

jgong5 commented 1 month ago

If you are talking about how IPEX overrides certain PyTorch ops, it is not done by ipex.optimize but still by the op registration mechanism we were discussing. As long as IPEX registers an op with the same schema, the corresponding op from PyTorch is overridden. Here is an example:

https://github.com/intel/intel-extension-for-pytorch/blob/5cc852eb850f8b42ac2d0fcca5709e6076e8042d/csrc/cpu/aten/LayerNorm.cpp#L203

Please note that only a few ops are overridden by IPEX. We optimize most ATen ops directly in PyTorch upstream.
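
For a rough picture of how a same-schema registration wins, here is a hypothetical Python-side sketch with `torch.library` (IPEX itself registers in C++ via `TORCH_LIBRARY_IMPL`, as in the LayerNorm.cpp link above); the kernel body is a naive stand-in, not IPEX's optimized code:

```python
import torch

# Hypothetical sketch: register a replacement CPU kernel for an existing
# ATen op. Because the aten::layer_norm schema is unchanged, the dispatcher
# routes plain torch calls to this kernel instead of the built-in one.
lib = torch.library.Library("aten", "IMPL")

def my_layer_norm(input, normalized_shape, weight=None, bias=None,
                  eps=1e-5, cudnn_enable=True):
    dims = tuple(range(input.dim() - len(normalized_shape), input.dim()))
    mean = input.mean(dims, keepdim=True)
    var = input.var(dims, unbiased=False, keepdim=True)
    out = (input - mean) / torch.sqrt(var + eps)  # naive normalization
    if weight is not None:
        out = out * weight
    if bias is not None:
        out = out + bias
    return out

lib.impl("layer_norm", my_layer_norm, "CPU")

x = torch.randn(2, 8)
torch.nn.functional.layer_norm(x, (8,))  # now dispatches to my_layer_norm
```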

zhangwm-pt commented 1 month ago

> If you are talking about how IPEX overrides certain PyTorch ops, it is not done by ipex.optimize but still by the op registration mechanism we were discussing. As long as IPEX registers an op with the same schema, the corresponding op from PyTorch is overridden. Here is an example:
>
> https://github.com/intel/intel-extension-for-pytorch/blob/5cc852eb850f8b42ac2d0fcca5709e6076e8042d/csrc/cpu/aten/LayerNorm.cpp#L203
>
> Please note that only a few ops are overridden by IPEX. We optimize most ATen ops directly in PyTorch upstream.

Can you explain in more detail why this registration mechanism can replace PyTorch's implementation? As I understand it, IPEX depends on libtorch.so, but libtorch.so does not depend on IPEX. If a model is originally created through libtorch.so, why can the implementation in IPEX replace the one in libtorch.so, given that the model is created through torch, not IPEX? I am planning to introduce FlashAttention-3 in IPEX. If you can provide some explanation or additional resources, I would greatly appreciate it.

jgong5 commented 1 month ago

> Can you explain in more detail why this registration mechanism can replace PyTorch's implementation? As I understand it, IPEX depends on libtorch.so, but libtorch.so does not depend on IPEX. If a model is originally created through libtorch.so, why can the implementation in IPEX replace the one in libtorch.so, given that the model is created through torch, not IPEX? I am planning to introduce FlashAttention-3 in IPEX. If you can provide some explanation or additional resources, I would greatly appreciate it.

The op dispatch mechanism in PyTorch looks up the registration table for the implementation registered via the code I pointed you to. If the op schema is the same, the original implementation is overridden by the new one. The registration table is process-global, not specific to a particular .so file.
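
To make the process-global point concrete, here is a hypothetical sketch (again using `torch.library` from Python rather than IPEX's C++ macros): once registered, the override applies to every caller in the process, including code that never imports the registering module.

```python
import torch

# Hypothetical illustration: the dispatcher's registration table is shared
# by the whole process, so this override affects every caller of aten::relu,
# regardless of which module performed the registration.
lib = torch.library.Library("aten", "IMPL")

def verbose_relu(self):
    print("custom CPU relu kernel called")
    return torch.clamp(self, min=0)  # avoid calling relu itself (would recurse)

lib.impl("relu", verbose_relu, "CPU")

x = torch.randn(4)
torch.relu(x)                # prints: custom CPU relu kernel called
torch.nn.functional.relu(x)  # same kernel, even though F.relu never saw `lib`
```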

Good to know you plan to contribute. The flash attention implementation in IPEX is here fyi: https://github.com/intel/intel-extension-for-pytorch/blob/main/csrc/cpu/aten/FlashAttention.cpp

zhangwm-pt commented 1 month ago

> > Can you explain in more detail why this registration mechanism can replace PyTorch's implementation? As I understand it, IPEX depends on libtorch.so, but libtorch.so does not depend on IPEX. If a model is originally created through libtorch.so, why can the implementation in IPEX replace the one in libtorch.so, given that the model is created through torch, not IPEX? I am planning to introduce FlashAttention-3 in IPEX. If you can provide some explanation or additional resources, I would greatly appreciate it.
>
> The op dispatch mechanism in PyTorch looks up the registration table for the implementation registered via the code I pointed you to. If the op schema is the same, the original implementation is overridden by the new one. The registration table is process-global, not specific to a particular .so file.
>
> Good to know you plan to contribute. The flash attention implementation in IPEX is here fyi: https://github.com/intel/intel-extension-for-pytorch/blob/main/csrc/cpu/aten/FlashAttention.cpp

I got it! Thanks very much!

jgong5 commented 1 month ago

cc @liangan1

huiyan2021 commented 1 week ago

Closing this issue, feel free to reopen if needed.