llvm / torch-mlir

The Torch-MLIR project aims to provide first class support from the PyTorch ecosystem to the MLIR ecosystem.

NativeBatchNorm op is being decomposed before TorchToTosa lowering #618

Closed. anupgangwar closed this issue 2 years ago.

anupgangwar commented 2 years ago

From the e2e test framework, I see that the generated MLIR no longer contains the batch_norm op for BatchNorm*_basic tests. The op appears to have been decomposed further (as per its definition). This appears to be a side effect of PR #594. Is there a way to not have the decomposition in the TOSA backend path?
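For context, the decomposition in question expands batch norm into its constituent element-wise operations. Below is a minimal PyTorch sketch of the inference-time arithmetic, purely for illustration; it is not the actual upstream decomposition pattern, and the function and variable names are made up:

```python
import torch

def batch_norm_decomposed(x, weight, bias, running_mean, running_var, eps=1e-5):
    # x: [N, C, H, W]; weight/bias/running stats: [C]
    c = x.size(1)
    # Reshape the per-channel [C] tensors to [1, C, 1, 1] so they broadcast.
    mean = running_mean.reshape(1, c, 1, 1)
    var = running_var.reshape(1, c, 1, 1)
    w = weight.reshape(1, c, 1, 1)
    b = bias.reshape(1, c, 1, 1)
    # The batch_norm op expressed via primitive ops: sub, add, sqrt, div, mul.
    return (x - mean) / torch.sqrt(var + eps) * w + b

x = torch.randn(2, 3, 4, 4)
w, b = torch.ones(3), torch.zeros(3)
mean, var = torch.zeros(3), torch.ones(3)
out = batch_norm_decomposed(x, w, b, mean, var)
ref = torch.nn.functional.batch_norm(x, mean, var, w, b, eps=1e-5)
print(torch.allclose(out, ref, atol=1e-6))  # expected: True
```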

Tagging @cathyzhyi @silvasean @sjarus.

silvasean commented 2 years ago

Why do you want a batch norm op? It seems simpler to lower the individual decomposed ops, which form a smaller set.

We could design a way to avoid decomposing it, but I would like to understand the motivation.

anupgangwar commented 2 years ago

For the near term, there are a few decomposed ops which are not supported in TorchToTosa yet. There are also some open questions around the .size op which are under discussion.

If this decomposition can be avoided in the TOSA backend path for now, it will get ResNet18 passing again. Once the full set of ops is supported, we can look at whether to enable the decomposition again.

silvasean commented 2 years ago

For this specific case, I would prefer to fix this forward, as suppressing the decomposition temporarily without a failing upstream test to motivate it doesn't seem desirable. This type of thing will eventually be caught by CI and fixed on the spot once ResNet passes on TOSA upstream :)

sjarus commented 2 years ago

Stepping in to assist @anupgangwar here after an internal conversation - there's some confusion here related to the previous advice @cathyzhyi had posted in #594. The suggestion there was understood as modifying the lowering of aten.batch_norm to that of aten.native_batch_norm as an op. This is doable according to Anup.

However, the decomposition isn't actually to aten.native_batch_norm as an op, but to the mechanical primitives of batch_norm. Maybe that's the intent of native_batch_norm, but both aten.batch_norm and aten.native_batch_norm are listed in GeneratedAtenOps.td, so we expected to see a single aten.native_batch_norm op in the IR.

silvasean commented 2 years ago

Is there some specific op in the decomposition that is hard to support? It should just be basic stuff like sqrt, add, mul, divide, sum, etc., which seem like ops we want to support anyway. It feels like the right long-term direction for the TOSA backend to support those ops, so since there is no immediate test breakage upstream, it's hard to justify an upstream change to suppress that behavior. If this is a problem for you locally, you can temporarily comment out that pattern in DecomposeComplexOps.

GeneratedAtenOps.td just reflects the contents of the registry, and provides no specific guarantee that such an op is present at the backend interface. In this case, I think that backends should generally prefer the fully decomposed ops (for example, TOSA does not have a BATCH_NORM operator, so there is no value in keeping it undecomposed AFAICT).

sjarus commented 2 years ago

Oh, I agree all the basic tensor-level computational ops are fine - that's what the TOSA decomposition also looks like anyway. None of that is a problem. If the decomposition consisted entirely of those, we wouldn't notice and nothing would fail, since they are already present. We have decomposed batch_norm elsewhere, including in TensorFlow.

The real issue here is that the decomposition does a few tricky things around generating subtensor-level constructs, e.g. the outputs of aten.size.int are ints.

TOSA, being tensor-level, is semantically suited to sit at a level higher than subtensor ops - this particular decomposition takes us far enough down in layering that we are compelled to figure out how to work our way back up.
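To make the layering concern concrete, here is a rough Python analogue of the subtensor-level portion of the decomposed IR. The specifics are illustrative only and do not reflect the exact generated ops:

```python
import torch

x = torch.randn(2, 3, 4, 4)   # [N, C, H, W]
weight = torch.ones(3)        # per-channel scale, [C]

# Analogue of aten.size.int: the result is a plain int, not a tensor.
num_features = x.size(1)

# The decomposition threads such ints through integer comparisons that feed
# runtime shape checks, i.e. constructs below the tensor level that have no
# direct TOSA counterpart.
assert weight.size(0) == num_features, "weight must have one element per channel"
```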

This is not a problem for LinAlg, since it is not a tensor-level dialect in the way TOSA is. So this decomposition is beneficial from the LinAlg path's perspective, but not necessarily helpful for TOSA. It's possible to take a network that is TOSA-expressible and decompose it to a level of abstraction where expressing it in TOSA becomes hard or even impossible.

But that's the current situation - the decomposition forces us down from an abstraction where we had a working implementation to one where accomplishing the same thing is not quite straightforward.

Alternatively, we could ignore such ops, the result being a heterogeneous IR which would require calling torch-to-linalg in addition to tosa-to-linalg. There's a separate ongoing effort to enable this nicely through tosa.custom, but getting there would likely leave ResNet18 broken for the TOSA backend for a while (Anup is away on vacation soon).

anupgangwar commented 2 years ago

Just to add to that, we did have ResNet18 and the BatchNorm tests passing before PR #594 disabled them (removed from xfail_sets.py). So, in a way, the CI would have caught it :).

silvasean commented 2 years ago

Just to add to that, we did have ResNet18 and the BatchNorm tests passing before PR #594 disabled them (removed from xfail_sets.py). So, in a way, the CI would have caught it :).

Oh, that's not good. Effectively #594 just arbitrarily disabled TOSA tests. I have asked that it be reverted!!!

anupgangwar commented 2 years ago

Thanks @silvasean.

silvasean commented 2 years ago

Alternatively, we could ignore such ops, the result being a heterogeneous IR which would require calling torch-to-linalg in addition to tosa-to-linalg. There's a separate ongoing effort to enable this nicely through tosa.custom, but getting there would likely leave ResNet18 broken for the TOSA backend for a while (Anup is away on vacation soon).

Looking at the IR, it seems like the decomposition is using the view-like op to reshape [N] to [1,N,1,1], which seems like a reasonable thing to pattern match and lower to TOSA (that's effectively what we do in TorchToLinalg too). There are also error checks inserted (torch.runtime.assert), which I think TOSA should have some sort of answer to, even if it means deleting them. (Since TOSA doesn't currently define the error semantics, I think that just deleting torch.runtime.assert and the ops only used by those is equivalent to the current state of TOSA.)
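To illustrate the "delete the asserts and whatever only feeds them" idea, here is a toy Python sketch over a made-up op representation. The Op class is invented for illustration and the op names are only indicative; this is not torch-mlir code:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Op:
    name: str
    operands: List["Op"] = field(default_factory=list)
    has_external_use: bool = False   # e.g. the value is returned from the function

def strip_runtime_asserts(ops):
    """Drop assert ops, then iteratively drop ops whose results are no longer used."""
    live = [op for op in ops if op.name != "torch.runtime.assert"]
    changed = True
    while changed:
        changed = False
        for op in list(live):
            still_used = any(op in other.operands for other in live)
            # Anything that no longer has a user and is not an external result
            # only existed to feed an assert, so it can be dropped as well.
            if not still_used and not op.has_external_use:
                live.remove(op)
                changed = True
    return live

# Tiny example: the size/compare chain disappears along with the assert,
# while the actual batch-norm computation is untouched.
size = Op("torch.aten.size.int")
cmp_ = Op("torch.aten.eq.int", [size])
chk = Op("torch.runtime.assert", [cmp_])
out = Op("torch.aten.mul.Tensor", [], has_external_use=True)
print([op.name for op in strip_runtime_asserts([size, cmp_, chk, out])])
# -> ['torch.aten.mul.Tensor']
```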

silvasean commented 2 years ago

I'm going to keep this issue open for now as a place where @Shukla-Gaurav and @sjarus @anupgangwar can discuss a path forward on this.

sjarus commented 2 years ago

I think aten.size.int -> int was the biggest impediment @anupgangwar faced. It's tricky to work with from TOSA's layering point of view.

BatchNorm can still be done fine in TOSA using analogous primitives; Anup's previous batch_norm legalization to TOSA is in #543, the specific generated sequence being here: https://github.com/llvm/torch-mlir/pull/543/files#diff-32507a97012015d11da21d6a5c072e35c87fbd60480114a48bf4f0fc93da8d0eR495

cathyzhyi commented 2 years ago

Hey @sjarus, may I ask why aten.size.int -> int is the problem? It can be decomposed to tensor ops and arithmetic ops. There is nothing specific to linalg.

Another thing is that, as @silvasean mentioned, the result of aten.size.int -> int is only used for runtime assertions. Maybe TOSA could just remove the runtime assertions, and the ops only used by the assertions would then be cleaned up.

sjarus commented 2 years ago

Hey @sjarus, may I ask why aten.size.int -> int is the problem? It can be decomposed to tensor ops and arithmetic ops. There is nothing specific to linalg.

It could be done, but it amounts to a layering problem: going down to the subtensor level, forcing our way back up, and then going down again. There's an alternative...

Another thing is that, as @silvasean mentioned, the result of aten.size.int -> int is only used for runtime assertions. Maybe TOSA could just remove the runtime assertions, and the ops only used by the assertions would then be cleaned up.

... yes, this could be done. Alternatively, A) general runtime assertions could be inserted in TOSA-to-LinAlg instead, or B) they could be added as part of dynamic/runtime capability in the existing TorchToTosa batch_norm lowering, thereby preserving useful capability while being layered better, in my view. Either of these seems preferable to the current state.

Could the runtime assertions be made controllable so the TOSA backend can independently manage this as suggested above?

silvasean commented 2 years ago

Could the runtime assertions be made controllable so the TOSA backend can independently manage this as suggested above?

No. It's not semantically correct to lower the op without such assertions, so we don't want that in the main Torch lowering flow. Backends should be able to ignore the assertions though, so perhaps we need to design a better op than a loose assert op (maybe something where the assertion code is in a region that can be trivially discarded).

sjarus commented 2 years ago

so perhaps we need to design a better op than a loose assert op (maybe something where the assertion code is in a region that can be trivially discarded).

Yeah, that sounds like it could be a reasonable option too! Anything that lets us trivially identify assertions and either discard them or replace them with more layer-appropriate ones would work fine.

Shukla-Gaurav commented 2 years ago

@silvasean @anupgangwar @sjarus The changes have been reverted in #619.

cathyzhyi commented 2 years ago

@Shukla-Gaurav Can you close this one and open another issue describing what's left to be done?

silvasean commented 2 years ago

ResNet is working on TOSA, so I think we found a solution to the immediate problem here.