tensorflow / mlir-hlo


mlir in tensorflow training? #23

Open Cuttstage opened 2 years ago

Cuttstage commented 2 years ago

Hi, is it possible to use MLIR to generate LLVM IR during model training, e.g. with GPU support? I cannot find any code in TensorFlow that uses MLIR and then hands the result back to the TensorFlow executor. Does that mean MLIR is only useful for inference? I wonder whether TensorFlow could lower some ops to IR, process them, and then merge the result back into the TensorFlow execution.

joker-eph commented 2 years ago

MLIR is used in Grappler now, which is plugged into the executor. MLIR is also used more and more to implement the XLA CPU/GPU compiler (which emits LLVM), and it is used for both inference and training (it powers JAX for example).
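As a minimal sketch of the XLA training path described above, assuming a recent JAX release is installed: a jitted gradient step is compiled by XLA, and `train_step.lower(...).as_text()` prints the MLIR module JAX hands to XLA. Depending on the JAX version, that text is in the StableHLO dialect (MHLO's successor) or in MHLO itself. The toy linear model, `loss_fn`, `train_step`, and `LEARNING_RATE` are hypothetical illustrations, not code from this thread.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical hyperparameter for this toy example


def loss_fn(params, x, y):
    """Mean squared error of a tiny linear model: pred = x @ w + b."""
    w, b = params
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)


@jax.jit  # XLA compiles the forward pass, backward pass, and update together
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain SGD update applied leaf-by-leaf to the (w, b) tuple.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# Toy data: y = 3x + 0.5
x = jnp.linspace(-1.0, 1.0, 32).reshape(-1, 1)
y = 3.0 * x + 0.5
params = (jnp.zeros((1, 1)), jnp.zeros((1,)))

# Print the start of the MLIR module that JAX hands to XLA for one step.
ir_text = train_step.lower(params, x, y).as_text()
print(ir_text[:300])

losses = []
for _ in range(200):
    params, loss = train_step(params, x, y)  # loss is computed before the update
    losses.append(float(loss))
print("first loss:", losses[0], "last loss:", losses[-1])
```

The same compiled step runs on CPU, GPU, or TPU depending on the backend JAX finds; nothing in the code is device-specific.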

You'd have to be more specific about what you're trying to achieve, though. MLIR has been used so far for fairly low-level pieces of infrastructure inside the TensorFlow/XLA ecosystem.

joker-eph commented 2 years ago

Seems like you have some misconfigured auto-reply here :)

Cuttstage commented 2 years ago

Thanks for your reply. I have found some code related to TFG in the latest TensorFlow repo. I am new to MLIR and hope to use this feature in our model training. :)

stellaraccident commented 2 years ago

MHLO (which is what this repo contains) is the native IR of JAX, which is used heavily for training (on CPU, GPU, and TPU via XLA). However, "training" can mean many things. For example, here is a prototype of a new API we are working on for saving off an entire JAX training program (in this case, a simple MNIST model) so it can be run offline via IREE on single CPU/GPU systems (which happen to include a large swath of mobile and embedded devices):

This is but one example of a training setup. If you are looking for distributed training, that is a more advanced topic. Also, the above is just a prototype: we are looking to finish it for everyone to use in the coming months.
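The idea of capturing a whole training program, rather than a single op, can be illustrated with plain JAX. This is a hypothetical sketch, not the prototype API mentioned in the comment (which is not shown here): it wraps an entire loop of gradient steps in one `jax.jit` function via `jax.lax.fori_loop`, lowers it to MLIR text, and writes that text to a file. The names `train`, `NUM_STEPS`, and `train_loop.mlir` are illustrative assumptions, and running such a file offline with IREE would need IREE's own tooling, which this sketch does not cover.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

NUM_STEPS = 100       # hypothetical number of training steps
LEARNING_RATE = 0.1   # hypothetical SGD learning rate


def loss_fn(params, x, y):
    """Mean squared error of a tiny linear model."""
    w, b = params
    return jnp.mean((x @ w + b - y) ** 2)


@jax.jit
def train(params, x, y):
    """Runs the whole training loop inside one compiled program."""
    def body(_, p):
        grads = jax.grad(loss_fn)(p, x, y)
        return jax.tree_util.tree_map(lambda a, g: a - LEARNING_RATE * g, p, grads)

    return jax.lax.fori_loop(0, NUM_STEPS, body, params)


x = jnp.linspace(-1.0, 1.0, 32).reshape(-1, 1)
y = 3.0 * x + 0.5
params = (jnp.zeros((1, 1)), jnp.zeros((1,)))

# Lower the entire loop (not just one step) to MLIR text and save it.
ir_text = train.lower(params, x, y).as_text()
out_path = os.path.join(tempfile.mkdtemp(), "train_loop.mlir")
with open(out_path, "w") as f:
    f.write(ir_text)

# Execute the same program in-process to confirm it actually trains.
trained = train(params, x, y)
final_loss = float(loss_fn(trained, x, y))
print("saved", out_path, "final loss:", final_loss)
```

Keeping the loop inside the compiled program means the saved module describes the full training procedure, which is the property an offline runtime would need; the in-process call is only a sanity check.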

Cuttstage commented 2 years ago

Thanks a lot. We will give it a try. But right now our production model runs on TensorFlow, and migrating it from TensorFlow to JAX would be a lot of work.