llvm / llvm-project

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies.
http://llvm.org
Other

[MLIR][analysis] Lattice: Fix automatic delegation of meet to lattice value classes #82620

Closed: andidr closed this 5 months ago

andidr commented 8 months ago

The class Lattice should automatically delegate calls to its meet operation to the associated lattice value class whenever that class provides a static function named meet. This delegation currently fails for two reasons:

  1. Lattice::has_meet checks whether the lattice value class has a member function meet that can be called without arguments, although it should check for a static member function meet (which takes two values).

  2. The function template Lattice::meet<VT>(), which implements the default meet operation directly in the lattice, is always enabled and takes precedence over the delegating function template Lattice::meet<VT, std::integral_constant<bool, true>>().
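Both failure modes can be reproduced in a small, self-contained sketch outside of MLIR. It assumes C++17 and uses hypothetical names throughout: `detector`/`is_detected` are a hand-rolled stand-in for `llvm::is_detected`, and `StaticMeetValue` and `oldStyleMeet` only mimic the shape of the original `Lattice` code. With a value type that has only a static two-argument `meet`, the old trait finds nothing; and even when the condition holds, the old overload pair still resolves to the no-op version.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for llvm::is_detected (C++17 detection idiom).
template <typename, template <typename> class Op, typename T>
struct detector : std::false_type {};
template <template <typename> class Op, typename T>
struct detector<std::void_t<Op<T>>, Op, T> : std::true_type {};
template <template <typename> class Op, typename T>
using is_detected = detector<void, Op, T>;

// Old trait: requires an instance call `meet()` with no arguments.
template <typename T>
using OldHasMeet = decltype(std::declval<T>().meet());
// New trait: only requires that `&T::meet` is well-formed.
template <typename T>
using NewHasMeet = decltype(&T::meet);

// Hypothetical value type with only a static two-argument meet
// (bitwise OR models set union).
struct StaticMeetValue {
  unsigned bits = 0;
  static StaticMeetValue meet(const StaticMeetValue &lhs,
                              const StaticMeetValue &rhs) {
    return StaticMeetValue{lhs.bits | rhs.bits};
  }
};

// Reason 1: the old trait misses the static meet; the new one finds it.
static_assert(!is_detected<OldHasMeet, StaticMeetValue>::value,
              "old trait does not detect a static two-argument meet");
static_assert(is_detected<NewHasMeet, StaticMeetValue>::value,
              "new trait detects a static meet");

// Reason 2: the old overload shape. The second template parameter of the
// first overload is a non-type parameter of type enable_if_t<...> with no
// default, so it can never be deduced from a call; that overload is never
// viable and every call resolves to the unconstrained no-op overload.
template <typename VT, std::enable_if_t<is_detected<NewHasMeet, VT>::value>>
int oldStyleMeet(const VT &) { return 1; } // would delegate
template <typename VT>
int oldStyleMeet(const VT &) { return 0; } // no-op
```

Here `oldStyleMeet(StaticMeetValue{})` returns 0 even though `StaticMeetValue` does provide a `meet`, which mirrors the symptom described in reason 2.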

This change fixes the automatic delegation of a lattice's meet operation to the lattice value class when that class provides a static meet function. It does so by conditionally enabling either the delegating or the non-delegating function template, and by changing Lattice::has_meet so that it checks for a static meet member function in the lattice value type.
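A minimal sketch of the fixed dispatch, under the same assumptions (C++17; `MiniLattice`, `UnionValue`, and `JoinOnlyValue` are hypothetical, and `ChangeResult` is a local enum rather than MLIR's): each `meet` overload carries a complementary `std::enable_if_t<...> * = nullptr` parameter, so exactly one overload is viable for any value type. Unlike the real `Lattice::meet`, it omits the monotonicity assertions.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for llvm::is_detected.
template <typename, template <typename> class Op, typename T>
struct detector : std::false_type {};
template <template <typename> class Op, typename T>
struct detector<std::void_t<Op<T>>, Op, T> : std::true_type {};
template <template <typename> class Op, typename T>
using is_detected = detector<void, Op, T>;

// Local stand-in for mlir::ChangeResult.
enum class ChangeResult { NoChange, Change };

// Hypothetical lattice wrapper shaped like the fixed Lattice::meet.
template <typename ValueT>
class MiniLattice {
public:
  explicit MiniLattice(ValueT v) : value(std::move(v)) {}
  const ValueT &getValue() const { return value; }

  template <typename T>
  using has_meet = decltype(&T::meet);
  template <typename T>
  using lattice_has_meet = is_detected<has_meet, T>;

  // Delegating overload: enabled only if VT has a static member `meet`.
  template <typename VT,
            std::enable_if_t<lattice_has_meet<VT>::value> * = nullptr>
  ChangeResult meet(const VT &rhs) {
    ValueT newValue = ValueT::meet(value, rhs);
    if (newValue == value)
      return ChangeResult::NoChange;
    value = newValue;
    return ChangeResult::Change;
  }

  // Fallback overload: enabled only if VT has no `meet`; a no-op.
  template <typename VT,
            std::enable_if_t<!lattice_has_meet<VT>::value> * = nullptr>
  ChangeResult meet(const VT &) {
    return ChangeResult::NoChange;
  }

private:
  ValueT value;
};

// Value with a static meet: bitwise OR models set union.
struct UnionValue {
  unsigned bits = 0;
  static UnionValue meet(const UnionValue &lhs, const UnionValue &rhs) {
    return UnionValue{lhs.bits | rhs.bits};
  }
  bool operator==(const UnionValue &other) const { return bits == other.bits; }
};

// Value without meet, e.g. a forward-only lattice value that has only a join.
struct JoinOnlyValue {
  int x = 0;
  bool operator==(const JoinOnlyValue &other) const { return x == other.x; }
};
```

Calling `meet` on a `MiniLattice<UnionValue>` delegates to `UnionValue::meet` and reports a change only when the bit set grows, while a `MiniLattice<JoinOnlyValue>` keeps its value and reports `NoChange`.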

The test in TestSparseBackwardDataFlowAnalysis.cpp is changed so that the meet function is no longer provided directly by the WrittenTo lattice, but by the Lattice base class, in order to trigger delegation to a lattice value class.
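The value-class side of the restructured test can be sketched in isolation. This is a hypothetical analogue of `WrittenToLatticeValue`, not the test's actual code: `std::set<std::string>` replaces `SetVector<StringAttr>` to avoid depending on MLIR, so iteration order is sorted rather than insertion order, and `ChangeResult` is a local enum. Because `meet` is a set union, the result absorbs both inputs, which is the property checked by `assert(ValueT::meet(newValue, value) == newValue && ...)` in `Lattice::meet`.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>

// Local stand-in for mlir::ChangeResult.
enum class ChangeResult { NoChange, Change };

// Hypothetical analogue of WrittenToLatticeValue.
struct WrittenToValue {
  std::set<std::string> writes;

  bool operator==(const WrittenToValue &other) const {
    return writes == other.writes;
  }

  // Adds resources and reports whether the set grew.
  ChangeResult addWrites(const std::set<std::string> &more) {
    std::size_t sizeBefore = writes.size();
    writes.insert(more.begin(), more.end());
    return sizeBefore == writes.size() ? ChangeResult::NoChange
                                       : ChangeResult::Change;
  }

  // Static meet, the shape Lattice<ValueT> looks for: the union of both sets.
  static WrittenToValue meet(const WrittenToValue &lhs,
                             const WrittenToValue &rhs) {
    WrittenToValue res = lhs; // copy, so rhs.writes is never aliased
    (void)res.addWrites(rhs.writes);
    return res;
  }

  // Used when resetting a lattice to its exit state.
  void clear() { writes.clear(); }
};
```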

llvmbot commented 8 months ago

@llvm/pr-subscribers-mlir

Author: Andi Drebes (andidr)

---
Full diff: https://github.com/llvm/llvm-project/pull/82620.diff

2 Files Affected:

- (modified) mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h (+5-3)
- (modified) mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp (+42-18)

``````````diff
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index b65ac8bb1dec27..7aadd5409cc695 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -132,14 +132,15 @@ class Lattice : public AbstractSparseLattice {
   /// analysis, lattices will only have a `join`, no `meet`, but we want to use
   /// the same `Lattice` class for both directions.
   template
-  using has_meet = decltype(std::declval().meet());
+  using has_meet = decltype(&T::meet);
   template
   using lattice_has_meet = llvm::is_detected;
 
   /// Meet (intersect) the information contained in the 'rhs' value with this
   /// lattice. Returns if the state of the current lattice changed. If the
   /// lattice elements don't have a `meet` method, this is a no-op (see below.)
-  template ::value>>
+  template ::value> * = nullptr>
   ChangeResult meet(const VT &rhs) {
     ValueT newValue = ValueT::meet(value, rhs);
     assert(ValueT::meet(newValue, value) == newValue &&
@@ -155,7 +156,8 @@ class Lattice : public AbstractSparseLattice {
     return ChangeResult::Change;
   }
 
-  template
+  template ::value> * = nullptr>
   ChangeResult meet(const VT &rhs) {
     return ChangeResult::NoChange;
   }
diff --git a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
index e1c60f06a6b5eb..6b35d4e2c0d8af 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
@@ -18,18 +18,27 @@ using namespace mlir::dataflow;
 
 namespace {
 
-/// This lattice represents, for a given value, the set of memory resources that
-/// this value, or anything derived from this value, is potentially written to.
-struct WrittenTo : public AbstractSparseLattice {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
-  using AbstractSparseLattice::AbstractSparseLattice;
+/// Lattice value storing the a set of memory resources that something
+/// is written to.
+struct WrittenToLatticeValue {
+  bool operator==(const WrittenToLatticeValue &other) {
+    return this->writes == other.writes;
+  }
 
-  void print(raw_ostream &os) const override {
-    os << "[";
-    llvm::interleave(
-        writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
-    os << "]";
+  static WrittenToLatticeValue meet(const WrittenToLatticeValue &lhs,
+                                    const WrittenToLatticeValue &rhs) {
+    WrittenToLatticeValue res = lhs;
+    (void)res.addWrites(rhs.writes);
+
+    return res;
   }
+
+  static WrittenToLatticeValue join(const WrittenToLatticeValue &lhs,
+                                    const WrittenToLatticeValue &rhs) {
+    // Should not be triggered by this test, but required by `Lattice`
+    assert(false);
+  }
+
   ChangeResult addWrites(const SetVector &writes) {
     int sizeBefore = this->writes.size();
     this->writes.insert(writes.begin(), writes.end());
@@ -37,14 +46,26 @@ struct WrittenTo : public AbstractSparseLattice {
     return sizeBefore == sizeAfter ? ChangeResult::NoChange
                                    : ChangeResult::Change;
   }
-  ChangeResult meet(const AbstractSparseLattice &other) override {
-    const auto *rhs = reinterpret_cast(&other);
-    return addWrites(rhs->writes);
+
+  void print(raw_ostream &os) const {
+    os << "[";
+    llvm::interleave(
+        writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
+    os << "]";
   }
 
+  void clear() { writes.clear(); }
+
   SetVector writes;
 };
 
+/// This lattice represents, for a given value, the set of memory resources that
+/// this value, or anything derived from this value, is potentially written to.
+struct WrittenTo : public Lattice {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
+  using Lattice::Lattice;
+};
+
 /// An analysis that, by going backwards along the dataflow graph, annotates
 /// each value with all the memory resources it (or anything derived from it)
 /// is eventually written to.
@@ -65,7 +86,9 @@ class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis {
   void visitExternalCall(CallOpInterface call, ArrayRef operands,
                          ArrayRef results) override;
 
-  void setToExitState(WrittenTo *lattice) override { lattice->writes.clear(); }
+  void setToExitState(WrittenTo *lattice) override {
+    lattice->getValue().clear();
+  }
 
 private:
   bool assumeFuncWrites;
@@ -77,7 +100,8 @@ void WrittenToAnalysis::visitOperation(Operation *op,
   if (auto store = dyn_cast(op)) {
     SetVector newWrites;
     newWrites.insert(op->getAttrOfType("tag_name"));
-    propagateIfChanged(operands[0], operands[0]->addWrites(newWrites));
+    propagateIfChanged(operands[0],
+                       operands[0]->getValue().addWrites(newWrites));
     return;
   } // By default, every result of an op depends on every operand.
   for (const WrittenTo *r : results) {
@@ -95,7 +119,7 @@ void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
   newWrites.insert(
       StringAttr::get(operand.getOwner()->getContext(),
                       "brancharg" + Twine(operand.getOperandNumber())));
-  propagateIfChanged(lattice, lattice->addWrites(newWrites));
+  propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
 }
 
 void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
@@ -105,7 +129,7 @@ void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
   newWrites.insert(
       StringAttr::get(operand.getOwner()->getContext(),
                       "callarg" + Twine(operand.getOperandNumber())));
-  propagateIfChanged(lattice, lattice->addWrites(newWrites));
+  propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
 }
 
 void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
@@ -124,7 +148,7 @@ void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
               call.getOperation()->getName().getStringRef());
     }
     newWrites.insert(name);
-    propagateIfChanged(lattice, lattice->addWrites(newWrites));
+    propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
   }
 }
``````````
andidr commented 6 months ago

CC @matthiaskramm as the original contributor of has_meet

matthiaskramm commented 6 months ago

Looks good!

Not a blocker, but after this change, we don't actually have any tests that verify that inheriting directly from AbstractSparseLattice (instead of from Lattice<X>) works as expected?
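For contrast with the delegating style, the "direct" style the reviewer refers to (a lattice that inherits from the abstract base and overrides `meet` itself, as the old `WrittenTo` did) can be sketched as follows. `AbstractLatticeBase` and `DirectWrittenTo` are hypothetical stand-ins for `AbstractSparseLattice` and the old test lattice; the base's default `meet` is assumed here to be a no-op. The sketch uses `static_cast` where the old test used `reinterpret_cast`, since `static_cast` is the conventional way to downcast within a class hierarchy.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>

// Local stand-in for mlir::ChangeResult.
enum class ChangeResult { NoChange, Change };

// Hypothetical stand-in for AbstractSparseLattice; default meet assumed no-op.
struct AbstractLatticeBase {
  virtual ~AbstractLatticeBase() = default;
  virtual ChangeResult meet(const AbstractLatticeBase &) {
    return ChangeResult::NoChange;
  }
};

// "Direct" style: meet is overridden on the lattice itself.
struct DirectWrittenTo : AbstractLatticeBase {
  std::set<std::string> writes;

  ChangeResult meet(const AbstractLatticeBase &other) override {
    // Assumes the caller only meets lattices of the same dynamic type.
    const auto &rhs = static_cast<const DirectWrittenTo &>(other);
    if (&rhs == this) // meeting with itself never changes anything
      return ChangeResult::NoChange;
    std::size_t sizeBefore = writes.size();
    writes.insert(rhs.writes.begin(), rhs.writes.end());
    return sizeBefore == writes.size() ? ChangeResult::NoChange
                                       : ChangeResult::Change;
  }
};
```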

andidr commented 6 months ago

@matthiaskramm Thanks for the review! Indeed, once this change is merged, no test remains in which meet is provided directly by the lattice. However, duplicating the test to restore the original variant seems overly bulky, and parametrizing it would make the test overly convoluted. That said, if you prefer either of these solutions to the current state, I'd be happy to provide an implementation and amend the PR.

andidr commented 6 months ago

@matthiaskramm Any thoughts about the options for the tests? If you are fine with the current state, maybe someone with write access could go ahead and merge the changes? Thanks!

matthiaskramm commented 6 months ago

I'm OK with the current state. From my side, this is fine to merge in.

andidr commented 5 months ago

Can someone with commit rights merge this (or comment if anything needs to be changed)? Maybe @ftynse?

andidr commented 5 months ago

Given that most of the contributions in mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h are from @Mogball, perhaps @Mogball could consider merging? Thanks!

ftynse commented 5 months ago

Apologies for the delay.