pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

mlir-npcomp intersects with torch-xla #2854

Open byronyi opened 3 years ago

byronyi commented 3 years ago

Background

ATen device capture is an attempt to produce MLIR by tracing a PyTorch program running on CPU. Its very first version, contributed by @stephenneuendorffer from Xilinx, was largely modeled after torch/xla.

Old API (pseudo-device, similar to torch/xla):

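import torch
import npcomp.frontends.pytorch as torch_mlir

dev = torch_mlir.mlir_device()          # pseudo-device, analogous to an XLA device
t0 = torch.randn((4, 4), device=dev)    # created directly on the pseudo-device
t1 = torch.randn((4, 4)).to(dev)        # created on CPU, then moved to the pseudo-device
t2 = t0 + t1                            # traced into the MLIR being built
t2_mlir = torch_mlir.get_mlir(t2)       # MLIR for the traced computation that produced t2
t2_cpu = t2.to('cpu')                   # copy the result back to CPU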

The device API is still recommended in the official tutorial for adding backend support.

@stellaraccident refactored the code to use the c10 dispatcher (greatly reducing the LOC from ~10k to <1k). Stella also changed the user-facing API from a pseudo-device to a Python context manager associated with a local dispatch key.

New API:

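import torch
import torch_mlir

lhs = torch.rand(2, 3)
rhs = torch.rand(3, 4)

mb = torch_mlir.ModuleBuilder()
# Ops executed inside the context are captured into a function named "mm"
# whose arguments correspond to lhs and rhs.
with mb.capture_function("mm", [lhs, rhs]) as f:
  result = torch.mm(lhs, rhs)
  f.returns([result])  # mark result as the function's return value

mb.module.operation.print()  # print the captured MLIR module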

Output:

module  {
  func @mm(%arg0: !numpy.ndarray<[2,3]:f32>, %arg1: !numpy.ndarray<[3,4]:f32>) -> !numpy.ndarray<[2,4]:f32> {
    %0 = torch.kernel_call "aten::mm" %arg0, %arg1 : (!numpy.ndarray<[2,3]:f32>, !numpy.ndarray<[3,4]:f32>) -> !numpy.ndarray<[2,4]:f32> {sigArgTypes = ["Tensor", "Tensor"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
    return %0 : !numpy.ndarray<[2,4]:f32>
  }
}

More examples can be found here: https://github.com/llvm/mlir-npcomp/tree/main/frontends/pytorch/test/acap_export

Under the hood, the c10 dispatcher picks up a local dispatch key associated with the mb.capture_function context manager, and a backend fallback function picks up the boxed kernel call on torch::jit::stack and produces generic MLIR (a 1:1 mapping to the boxed kernel call). It then re-dispatches to the CPU backend to get the shape and dtype of the ATen tensors.

However, this approach runs into some caveats with the c10 dispatcher, notably with convolution, copy_ and factory functions like arange, and we suspect that it is not really how the mode-based backend fallback mechanism is supposed to be used. It also records the shape and dtype too early for our purposes (more on that later, regarding dynamic shape support), but this could be changed.

Design goal

We are looking into MLIR for two reasons. First, XLA HLO IR has some nice properties, but it does not support dynamic shapes (except via padding); MHLO will probably fix that in the future. Second, in the meantime we would like to plug in backend-specific intrinsics for some custom ops (notably torchvision ops, or CTC loss in seq2seq models) when migrating existing PyTorch users, and we are looking for better alternatives to XLA custom calls.

Xilinx and (AFAIK) some of the custom training ASIC vendors are also moving towards MLIR for easier interoperability between frameworks and their software/hardware stacks. They would be satisfied with exporting MLIR somewhere in the frontend stack.
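Padding, mentioned above as the only way HLO handles dynamic shapes, is commonly combined with bucketing so that only a handful of static shapes ever get compiled. A minimal sketch, assuming a hypothetical set of bucket sizes; the true length is returned alongside the padded sequence so a mask can be built later:

```python
# Hypothetical bucket sizes; real choices depend on the workload.
BUCKETS = (16, 32, 64, 128)


def bucket_length(n, buckets=BUCKETS):
    """Smallest bucket that fits a sequence of length n."""
    for b in buckets:
        if n <= b:
            return b
    raise ValueError(f"length {n} exceeds the largest bucket {buckets[-1]}")


def pad_to_bucket(seq, pad_value=0, buckets=BUCKETS):
    """Pad seq up to its bucket size; also return the true length for masking."""
    target = bucket_length(len(seq), buckets)
    return list(seq) + [pad_value] * (target - len(seq)), len(seq)


# Six different lengths map onto only three static shapes.
lengths = [3, 17, 20, 31, 64, 5]
print(sorted({bucket_length(n) for n in lengths}))  # [16, 32, 64]
```

The trade-off is wasted compute on padding versus fewer recompilations, which is why native dynamic shape support in MHLO is attractive.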

Proposed changes

Option 1

Export the HLO graph from torch/xla and translate it into MHLO. This is not really what we are looking for to add dynamic shape and custom op support, as it does not enhance the expressiveness of the current torch/xla frontend.

Option 2

Add back the pseudo-device API in mlir-npcomp/acap, which would probably work around the caveats. This would be similar to a wrapper-based backend fallback, in contrast to the mode-based backend fallback. It retains the CLOC advantage while switching to what are (IMHO) the canonical extension points for backends. We would also like to revisit static shape and dtype inference vs. runtime tracing in order to support dynamic shapes.

Option 3

Migrate the torch/xla frontend to MLIR (with some overlap with option 2), e.g. by adding an XLA dialect and lowering directly to MHLO (probably with custom ops left untouched). This is basically a rewrite of the current XLA type dispatch using backend fallback, and of XLA IR into an MLIR dialect (probably using CHLO, which provides functionality similar to XlaBuilder). Different fallback strategies could be employed: either "copy tensors back to CPU and call CPU kernels", similar to what AtenXlaTypeDefault does today, or plugging in backend-specific intrinsics (which requires support on the runtime side).
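The two fallback strategies for Option 3 can be sketched as a per-op decision. The op tables here are hypothetical: `mhlo.dot` and `mhlo.add` are real MHLO op names, but `vendor.roi_align` is a made-up backend intrinsic, and a real lowering would operate on IR rather than on op-name strings.

```python
# Hypothetical op tables for illustration only.
MHLO_LOWERINGS = {"aten::mm": "mhlo.dot", "aten::add": "mhlo.add"}
BACKEND_INTRINSICS = {"torchvision::roi_align": "vendor.roi_align"}


def plan(ops, strategy="cpu"):
    """Decide how each captured op runs under a given fallback strategy."""
    steps = []
    for op in ops:
        if op in MHLO_LOWERINGS:
            steps.append(("lower", MHLO_LOWERINGS[op]))
        elif strategy == "intrinsic" and op in BACKEND_INTRINSICS:
            # requires runtime support for the vendor intrinsic
            steps.append(("intrinsic", BACKEND_INTRINSICS[op]))
        else:
            # copy inputs to CPU, run the CPU kernel, copy results back
            # (similar in spirit to AtenXlaTypeDefault)
            steps.append(("cpu_fallback", op))
    return steps


ops = ["aten::mm", "torchvision::roi_align", "aten::add"]
print(plan(ops))
print(plan(ops, strategy="intrinsic"))
```

Under the CPU strategy, each `cpu_fallback` step splits the compiled region in two, whereas an intrinsic can stay inside the device program.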

Discussion

Option 1 could be a quick win, doable in 2-4 weeks. I am leaning toward option 2 in the mid term (1-2 months?) and gradually moving to option 3 in the long term (6-12 months), as it does enhance the UX of torch/xla. I guess the XLA/TPU stack is also moving to MLIR so we are looking for some early feedbacks here.

PS: Many thanks to @silvasean for initiating and coordinating the discussion. The overall roadmap of mlir-npcomp can be found here: https://github.com/llvm/mlir-npcomp/blob/main/docs/roadmap.md

stellaraccident commented 3 years ago

+1 on option 2, which I don't think will be that hard (earlier on, I did make an attempt at it and had issues, but I think I may have been holding it wrong).

silvasean commented 3 years ago

+1 on option 2 as well.

ailzhang commented 3 years ago

Thanks for providing detailed context in the issue! I haven't gone through all the details yet; I will take another look on Monday! cc: @ezyang @wconstab @asuhan @bdhirsh for more visibility.

byronyi commented 3 years ago

I just found the symbolic shape DSL and structured kernel RFCs as I was thinking that dynamic shape support would need something like meta tensors where you can run the operator statically with type and shape but no real data or computation.

Not sure how long it would take to migrate most existing kernels into structured ones (or even into the symbolic shape DSL), but it looks very promising, especially the part about "allowing build of TS compiler independently without aten library".
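The "meta tensor" idea (running an operator with shape and dtype but no real data or computation) can be illustrated with a toy shape rule for `mm`. This is a hypothetical pure-Python stand-in; in recent PyTorch versions the built-in equivalent is the `"meta"` device (e.g. `torch.empty(2, 3, device="meta")`). Note that every dimension in the result stays concrete.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MetaTensor:
    """Shape and dtype only: no storage and no computation."""
    shape: tuple
    dtype: str


def meta_mm(a, b):
    """Shape/dtype rule for a 2-D matrix multiply, checked without any data."""
    if len(a.shape) != 2 or len(b.shape) != 2:
        raise ValueError("mm expects 2-D inputs")
    if a.shape[1] != b.shape[0]:
        raise ValueError(f"size mismatch: {a.shape} @ {b.shape}")
    if a.dtype != b.dtype:
        raise TypeError(f"dtype mismatch: {a.dtype} vs {b.dtype}")
    return MetaTensor((a.shape[0], b.shape[1]), a.dtype)


out = meta_mm(MetaTensor((2, 3), "f32"), MetaTensor((3, 4), "f32"))
print(out)  # MetaTensor(shape=(2, 4), dtype='f32')
```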

ezyang commented 3 years ago

So if I understand correctly, the menu of choices from Option 1 to Option 3 ranges from the most conservative to the most radical change to the torch_xla project. Although no one seems to be advocating for it in the short-to-mid term, Option 3 in particular seems like a really radical change for the torch_xla project, which has been HLO-first from the start. Because torch_xla is a real project that people are using to do actual training, we have to keep the plane flying as we change out the engine.

I guess the XLA/TPU stack is also moving to MLIR so we are looking for some early feedbacks here.

This seems to imply that there is organizational buy in from Google? But I'd feel better hearing from, e.g., @JackCaoG, that this is indeed globally Google's plan on record.

Something else that is a little awkward is that, independently of this thrust, there's been a recent effort headed by @asuhan and @wconstab to figure out how to factor torch_xla's codebase into an independent, functional lazy tracing mechanism that can be used in other contexts. We'd only been talking about it internally until recently, but I just put up our working design https://github.com/pytorch/rfcs/pull/18 that discusses some of the considerations. The short answer is that @asuhan has a split-out version of XLA that desugars to a simple ATen dialect IR, which can then be further lowered (currently, it only lowers to HLO). This rewrite, however, was done expressly with the intent of otherwise being as low-risk as possible (so, for example, it doesn't do any of the refactors that you've done in your fork, including removal of the ATen operator classes to reduce LOC, or switching to backend fallback as proposed in Option 2).

Speaking as a core PyTorch developer, here are my thoughts:

Also, cc @smessmer @ngimel; these two folks are trying to resolve a major bug in backend fallbacks afflicting master (perhaps that's what stopped @stellaraccident from successfully making this work last time).
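The lazy tracing approach discussed in the RFC (ops build an ATen-level IR that is only lowered and executed later) can be sketched as follows. `Node`, `LazyTensor`, `ir` and `materialize` are hypothetical names, Python scalars stand in for tensors, and `materialize` simply interprets the graph where a real implementation would lower it (e.g. to HLO) and run it on the device.

```python
class Node:
    """One recorded op in the lazy graph."""

    def __init__(self, op, inputs, value=None):
        self.op, self.inputs, self.value = op, inputs, value


class LazyTensor:
    def __init__(self, node):
        self.node = node

    @staticmethod
    def constant(x):
        return LazyTensor(Node("const", [], x))

    # Arithmetic only records nodes; nothing is computed here.
    def __add__(self, other):
        return LazyTensor(Node("aten::add", [self.node, other.node]))

    def __mul__(self, other):
        return LazyTensor(Node("aten::mul", [self.node, other.node]))

    def _topo(self):
        order, names = [], {}

        def visit(n):
            if id(n) in names:
                return
            for i in n.inputs:
                visit(i)
            names[id(n)] = f"%{len(order)}"
            order.append(n)

        visit(self.node)
        return order, names

    def ir(self):
        """Linearize the recorded graph into a simple ATen-level IR listing."""
        order, names = self._topo()
        lines = []
        for n in order:
            args = ", ".join(names[id(i)] for i in n.inputs) if n.inputs else repr(n.value)
            lines.append(f"{names[id(n)]} = {n.op}({args})")
        return lines

    def materialize(self):
        """Stand-in for lowering + execution: interpret the graph on the host."""
        order, _ = self._topo()
        vals = {}
        for n in order:
            xs = [vals[id(i)] for i in n.inputs]
            if n.op == "const":
                vals[id(n)] = n.value
            elif n.op == "aten::add":
                vals[id(n)] = xs[0] + xs[1]
            elif n.op == "aten::mul":
                vals[id(n)] = xs[0] * xs[1]
            else:
                raise NotImplementedError(n.op)
        return vals[id(self.node)]


a, b = LazyTensor.constant(2.0), LazyTensor.constant(3.0)
c = (a + b) * b  # builds a graph; no arithmetic has happened yet
print("\n".join(c.ir()))
print(c.materialize())  # 15.0
```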

ezyang commented 3 years ago

@byronyi

I just found the symbolic shape DSL and structured kernel RFCs as I was thinking that dynamic shape support would need something like meta tensors where you can run the operator statically with type and shape but no real data or computation.

Unfortunately I'm not sure how much this actually helps you. In particular, meta tensors still have exact shape/dtype information everywhere (I did want to support symbolic shapes, but it was too hard to figure out how to do that at the same time as the other design constraints), so you won't be able to get dynamic shapes without rerunning the network, at which point XLA's current strategy works perfectly fine.
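XLA's current strategy of compiling once per distinct set of static input shapes and caching the result can be sketched with a toy cache. `ShapeSpecializedCache` and `compile_fn` are hypothetical; the point is only that each new shape costs one compilation and repeated shapes hit the cache.

```python
class ShapeSpecializedCache:
    """Compile once per (graph, input shapes) key, as a static-shape JIT does."""

    def __init__(self, compile_fn):
        self.compile_fn = compile_fn  # (graph_key, shapes) -> executable
        self.cache = {}
        self.compilations = 0

    def run(self, graph_key, shapes, *inputs):
        key = (graph_key, tuple(shapes))
        if key not in self.cache:  # new shape: pay for a compilation
            self.compilations += 1
            self.cache[key] = self.compile_fn(graph_key, shapes)
        return self.cache[key](*inputs)


def compile_fn(graph_key, shapes):
    # Hypothetical "compiler": returns an executable for the given shapes.
    return lambda xs: sum(xs)


jit = ShapeSpecializedCache(compile_fn)
for batch in ([1, 2], [3, 4], [1, 2, 3]):
    jit.run("sum", [(len(batch),)], batch)
print(jit.compilations)  # 2: shapes (2,) and (3,)
```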

eellison commented 3 years ago

I think it's possible that JIT symbolic shape analysis could help you here, but there's a lot left to be designed...

shauheen commented 3 years ago

Thanks @ezyang, I will need some more time to read through the RFC and the other proposal but just to comment on this:

I guess the XLA/TPU stack is also moving to MLIR so we are looking for some early feedbacks here.

This seems to imply that there is organizational buy in from Google? But I'd feel better hearing from, e.g., @JackCaoG, that this is indeed globally Google's plan on record.

Want to clarify that MLIR will be used for optimizing many things including XLA however there are no current plans to replace XLA.

silvasean commented 3 years ago

If the goal is to avoid excessive recompilation for dynamic shapes, it probably makes more sense to capture multiple traces and opportunistically erase shapes to a common supertype if the traces are structurally the same; this should converge to a most de-refined supertype for all ops in the steady state. This could be done "under the hood".
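The trace-merging idea can be sketched in a few lines, under assumptions of our own: a trace is a list of `(op_name, operand_shapes)` pairs, `None` marks a dynamic dimension, and two traces count as "structurally the same" when they have the same op sequence, operand counts and operand ranks. The join is idempotent, so repeated joins settle at a fixed point.

```python
def join_dim(a, b):
    """Equal static dims stay static; anything else becomes dynamic (None)."""
    return a if a == b else None


def join_traces(t1, t2):
    """Generalize two traces to a common supertype, or return None if they
    differ structurally (op sequence, operand count, or operand rank)."""
    if [op for op, _ in t1] != [op for op, _ in t2]:
        return None
    merged = []
    for (op, shapes1), (_, shapes2) in zip(t1, t2):
        if len(shapes1) != len(shapes2):
            return None
        joined = []
        for s1, s2 in zip(shapes1, shapes2):
            if len(s1) != len(s2):
                return None
            joined.append(tuple(join_dim(a, b) for a, b in zip(s1, s2)))
        merged.append((op, joined))
    return merged


t1 = [("aten::mm", [(8, 16), (16, 4)])]
t2 = [("aten::mm", [(32, 16), (16, 4)])]
general = join_traces(t1, t2)
print(general)  # [('aten::mm', [(None, 16), (16, 4)])]
# Another batch size leaves the generalized trace unchanged (fixed point).
print(join_traces(general, [("aten::mm", [(5, 16), (16, 4)])]) == general)  # True
```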

byronyi commented 3 years ago

Want to clarify that MLIR will be used for optimizing many things including XLA however there are no current plans to replace XLA.

I was just skimming through TPU headers and found XRT variants of TpuCompile: https://github.com/tensorflow/tensorflow/commit/c31e582af5cc8fe4190bcb4743b6914a2c634f70

It also accepts an MLIR module directly: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/tpu/tpu_ops_c_api.h#L111-L120

So the TPU runtime still accepts XLA/HLO, and whether frontends other than TF migrate to MLIR should largely depend on whether we would like to share some of the optimization passes between these lowering paths (which is exactly what MLIR is primarily designed for, IMHO).

ailzhang commented 3 years ago

@byronyi @stellaraccident @silvasean For option 2, I wonder what do you expect to gain compared to your current mode-base approach? I took a quick look at your mode-based approach (which toggles the included keyset of TLS). Although it's slightly different from how xla carries keys on the tensor, it's still a way to tweak our dispatcher, and the two are not fundamentally different. I suspect the issues you're facing with conv etc. might not go away with option 2, so I'm curious what the main motivation for it is. (Following our device extension guide could be one reason ;) but we are also interested in learning about issues that are blocking you so that we can see if there's anything we can help with :D)
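Why carrying a key on the tensor and toggling the TLS include set are not fundamentally different can be seen in a toy version of dispatch key computation. This is a simplified model, not c10's code: the key names and their priority order are hypothetical, but the rule (union of tensor keys and TLS-included keys, minus TLS-excluded keys, highest priority wins) roughly mirrors how c10 combines them. Excluding its own key is also how a fallback re-dispatches to the next backend.

```python
# Hypothetical key names in ascending priority (c10 uses its DispatchKey enum).
PRIORITY = ["CPU", "XLA", "Capture"]


def compute_key(tensor_keys, tls_included=(), tls_excluded=()):
    """(union of tensor keys | TLS included) - TLS excluded; highest priority wins."""
    keys = set().union(*tensor_keys) | set(tls_included)
    keys -= set(tls_excluded)
    return max(keys, key=PRIORITY.index)


# Wrapper/pseudo-device style: the key travels on the tensors.
print(compute_key([{"XLA"}, {"XLA"}]))  # XLA
# Mode style: plain CPU tensors, but the key is switched on in TLS.
print(compute_key([{"CPU"}, {"CPU"}], tls_included={"Capture"}))  # Capture
# A fallback re-dispatches by excluding its own key.
print(compute_key([{"CPU"}], tls_included={"Capture"}, tls_excluded={"Capture"}))  # CPU
```

Either way the dispatcher ends up selecting the same kernel or fallback, which is consistent with the observation that switching styles alone may not make the conv/copy_/factory caveats go away.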

stellaraccident commented 3 years ago

For option 2, I wonder what do you expect to gain compared to your current mode-base approach?

This is my question as well. I've swapped a lot of the context out and would have a hard time articulating it, but ~6 months ago, I did read/trace through all of the PyTorch code in this area, thought I understood it, and did try a version that was more true to the device extension guide. As you note, this doesn't change a lot of the mechanics, and indeed didn't magically fix anything. But I also discount this experience because a) it isn't memorialized anywhere and I may have been holding it wrong even still, and b) there were a lot of other issues with the dispatcher at that time that appear to have been somewhat fixed in the intervening time.

I still suspect that there are "just bugs" here, either on the PyTorch or npcomp side. The areas that were particularly fiddly were the fallback handler (which was, surprisingly, a somewhat leaky abstraction in practice for doing what I actually wanted: a way to snoop on the dispatch stream with generality) and the bits on the npcomp side that try to track tensor equality to stitch the graph back together.

If I were putting more cycles into this, I would strip out all of the special cases from the npcomp tracer, see how things work now without them and then walk it back forward, fixing issues upstream vs patching around them. I'd also like someone who knows better to look at the way that the fallback handlers are being used and at least give an ack that it should work. The other examples of using the dispatcher that I could find all use explicit kernel registration, which at least seems to be the more trodden path.

ailzhang commented 3 years ago

@stellaraccident Overall we're interested in unblocking you if there are any dispatcher-related bugs/issues. :D So I'm personally interested in what your current blockers are with the mode-based approach before we choose among options 1/2/3 :D. @smessmer recently fixed some issues with backend fallback (details in https://dev-discuss.pytorch.org/t/backend-fallbacks/195) and it could enable xla to AtenXlaTypeDefault codegen I think. I plan to take a look there in 1-2 weeks to see if we can take advantage of it. Your approach, which uses backend fallback, might have hit some unsupported cases in backend fallback in the past; maybe it's worth another try after this fix?

sanjoy commented 3 years ago

We (TensorFlow/XLA GPU team) are incrementally refactoring XLA GPU to use MHLO and LMHLO (the latter being a bufferized version of MHLO), with the intent of supporting unbounded dynamic shapes and better custom calls (both of which are important for TensorFlow as well), amongst other things.

OOC, can you share why you need more flexibility in XLA custom calls?

The current "cut point" for this transition is here, where we translate the XLA HLO graph into LMHLO and go from there.

This is very much a work in progress, but in the spirit of "keep[ing] the plane flying as we change out the engine", we're trying to land all of the work in the default load bearing execution codepath for XLA GPU.

byronyi commented 3 years ago

OOC, can you share why you need more flexibility in XLA custom calls?

There are two kinds of custom calls as we see it: (1) standard XLA ops that get thunked to vendor libraries for better performance, and (2) custom ops implemented in CUDA that have not yet been standardized/upstreamed into core TF/PyTorch. Some cases in category (2) could be implemented as composite ops using framework primitives, but they were fused by hand, mostly for performance reasons. In the long run we expect to see fewer cases of (2) as the compiler stack gets more powerful. As we learnt from the great UX on the TF/XLA side, better interoperability between native CUDA ops and the XLA/GPU runtime would certainly help "keep[ing] the plane flying as we change out the engine".

Aside from the XLA/GPU story, we are also discussing long-term TF/PyTorch support strategy with other training accelerator vendors. Some have (incomplete) direct support for framework primitives, but find it difficult to keep track of the frequent changes on the framework side. MHLO (1) provides a narrow waist that covers most frontend ops and (2) leaves the possibility of tackling dynamic shapes in the future. It is also possible to bypass XLA lowering for some ops (e.g. ops implemented directly in proprietary libraries or hard-wired in HW) thanks to the "multi-levelness" of MLIR. Again, the benefits mostly come from the smooth migration path, and IMHO that would be a great engineering success by itself.

sanjoy commented 3 years ago

Thanks @byronyi; by custom call I meant (2). I consider (1) to be just an implementation detail. Have you looked into wrapping PT implementations in XLA's custom-calls?

It is also possible to bypass XLA lowering for some ops

This is convenient, but as you know it is a double-edged sword; if we make this too easy / streamlined then we may end up with too many handwritten lowerings, which prevents getting the full benefit of HLO.

byronyi commented 3 years ago

This is convenient, but as you know it is a double-edged sword; if we make this too easy / streamlined then we may end up with too many handwritten lowerings, which prevents getting the full benefit of HLO.

For XLA/GPU:

I agree; IMHO it is a trade-off: getting more users on board with the lazy tensor API (with perf perks) in the short term, while we work on optimizing the compiler stack to support dynamic shapes in the frontend. If MHLO gets better support for either (1) dynamic shapes or (2) custom calls in the near future, we probably won't spend too much time figuring out custom calls in HLO.

The line between being pragmatic and short-sighted (or user-pleasing?) is probably too thin here :p

For other DSAs:

We have also seen DSA vendors with slightly different opinions on which ops are better lowered via HLO and which are better lowered directly to their own stack. We would probably like to leave room for those ideas (at the end of the day, raw perf speaks), but our work is mainly focused on HLO/MHLO.

stale[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stephenneuendorffer commented 2 years ago

@byronyi I'm curious if you've made any further progress on this?

byronyi commented 2 years ago

@byronyi I'm curious if you've made any further progress on this?

We have a prototype that uses PyTorch "meta" tensors for shape inference, and then captures the PyTorch ATen kernel execution sequence into an MLIR graph of generic torch.kernel_call ops. It features an API that is compatible with Torch-XLA, and many more ops are supported compared to the prototype you did last year (w/ meta tensors, which refactor shape inference and computation into separate kernels).

We have not started work on a converter from ATen to MHLO, as there are still PyTorch-specific issues that truncate the graph too frequently (like https://github.com/pytorch/pytorch/issues/62320). Initially we thought it was an XLA-specific issue (one that might get solved once dynamic shape support in MHLO launches), but it turned out to be a PyTorch frontend issue. You can find more details in the issue.

Recently the PyTorch compiler team published a post on their plan for leaning into eager-first training and deployment (except for mobile and edge). They specifically mentioned that

We must reimagine what an AI Compiler looks like when deeply integrated with eager-mode execution, and explore more dynamic and flexible approaches and taking inspiration from successful JIT compilers in other domains (such as JavaScript).

To facilitate this shift, the primary focus for the TorchScript stack (whole program capture) will become Mobile/Edge.

Techniques to get partial graphs, such as Lazy Tensors and explicit fusion, can provide speedups — but vendors should not require on these techniques to have competitive performance.

All of these are bold bets. If the industry decides to take this path, these challenges need to be dealt with even before IR graph building. Right now it seems a safe choice to continue the work on lazy tensors (and lowering from TorchScript IR, thanks to @stellaraccident's reminder), but we are rethinking the role it plays in the PyTorch compiler stack if upstream decides to abandon whole program capture for PyTorch workloads in the cloud.

stellaraccident commented 2 years ago

PyTorch is either being bold/right or is entering the oscillation phase that TensorFlow went into years ago, producing a bunch of partially complete, incompatible interface points. Time will tell, I guess. They seem pretty responsive to the community, so I'm hopeful that feedback over time will refine the path as it evolves.

One thing I've not seen spelled out yet: they are definitely not high on the whole program aspects of TorchScript, but TorchIR seems to still be a viable tracing target. Even if things go through upheaval over ~years, if some things at that level stay around and overlap, we may have time to adapt.

It seems clear that whole program capture cases will continue to be a bit of a fringe thing, so being N-1 on approach isn't the worst thing in the world (i.e., favor stability) while things evolve.

Just some random thoughts.

stephenneuendorffer commented 2 years ago

Thanks for this. I think there are 2 separate questions here. First: What are the basic abstractions that we can use for device dispatch? Ideally these abstractions wouldn't change too much and would support operations that we know are important for today's devices (like operator fusion in a graph). Second: What is necessary with current and future devices to get the highest level of performance? I see this as more a question of the extent of optimization than anything else.

It seems that the issues you see are more about the second question (i.e. the size of the graph) than about the first (the basic plumbing for the MLIR representation). Do you think your prototype is sufficient for the first at this point?