aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTobias Gysi <gysit@google.com>2021-04-20 11:26:44 +0000
committerTobias Gysi <gysit@google.com>2021-04-20 11:55:44 +0000
commitb9715156ff909fb38725893afb1d18709cb7f1bd (patch)
treebf51e59cdf731d0cec3a7d4b43605aee1b9f5445
parent[DAG] SelectionDAG.cpp - breakup if-else chains where each block returns. NFCI. (diff)
downloadllvm-project-b9715156ff909fb38725893afb1d18709cb7f1bd.tar.gz
llvm-project-b9715156ff909fb38725893afb1d18709cb7f1bd.tar.bz2
llvm-project-b9715156ff909fb38725893afb1d18709cb7f1bd.zip
[mlir][linalg] lower index operations during linalg to vector lowering.
The patch extends the vectorization pass to lower linalg index operations to vector code. It allocates constant 1d vectors that enumerate the indexes along the iteration dimensions and broadcasts/transposes these 1d vectors to the iteration space. Differential Revision: https://reviews.llvm.org/D100373
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td10
-rw-r--r--mlir/include/mlir/IR/Builders.h1
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp23
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp3
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp58
-rw-r--r--mlir/lib/IR/Builders.cpp6
-rw-r--r--mlir/test/Dialect/Linalg/vectorization.mlir60
7 files changed, 134 insertions, 27 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 6c3e86c6d2f1..0512b351650e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -1242,11 +1242,21 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/// appear in the operands.
SmallVector<Value, 4> createFlatListOfOperandDims(OpBuilder &, Location);
+ /// Return the flat list of all operands' static dimension sizes in the
+ /// order they appear in the operands. All operand dimension sizes have to
+ /// be statically known.
+ SmallVector<int64_t, 4> createFlatListOfOperandStaticDims();
+
/// Create the loop ranges to materialize the computation over the current
/// operands. This is done by applying `getShapesToLoopsMap` to
/// `createFlatListOfOperandDims`.
SmallVector<Range, 4> createLoopRanges(OpBuilder &b, Location loc);
+ /// Compute the static loop sizes necessary to vectorize the computation.
+ /// This is done by applying `getShapesToLoopsMap` to
+ /// `createFlatListOfOperandStaticDims`.
+ SmallVector<int64_t, 4> computeStaticLoopSizes();
+
/// Returns all the operands past the inputs, output_buffers and
/// init_tensors operands. Asserts that these operands are value types to
/// allow transformations like tiling to just use the values when cloning
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index f8b119cf962a..1e0863c7a7a4 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -124,6 +124,7 @@ public:
DenseIntElementsAttr getBoolVectorAttr(ArrayRef<bool> values);
DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values);
DenseIntElementsAttr getI64VectorAttr(ArrayRef<int64_t> values);
+ DenseIntElementsAttr getIndexVectorAttr(ArrayRef<int64_t> values);
/// Tensor-typed DenseIntElementsAttr getters. `values` can be empty.
/// These are generally preferable for representing general lists of integers
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f1bf22cf3d69..1c45467afbb4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -193,6 +193,16 @@ SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
return res;
}
+SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
+ SmallVector<int64_t, 4> res;
+ for (Value v : getShapedOperands()) {
+ ShapedType t = v.getType().template cast<ShapedType>();
+ assert(t.hasStaticShape() && "expected operands to have static shapes");
+ llvm::append_range(res, t.getShape());
+ }
+ return res;
+}
+
SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
AffineMap map = getLoopsToShapesMap();
unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
@@ -211,6 +221,19 @@ SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
return res;
}
+SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
+ AffineMap map = getLoopsToShapesMap();
+ unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
+ SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
+ SmallVector<int64_t, 4> res(numDims, 0);
+ for (unsigned idx = 0; idx < numRes; ++idx) {
+ auto result = map.getResult(idx);
+ if (auto d = result.dyn_cast<AffineDimExpr>())
+ res[d.getPosition()] = allShapeSizes[idx];
+ }
+ return res;
+}
+
/// Visitor to check if any of the given set of positions from AffineDimExprs
/// are used within an AffineExpr.
struct HasAffineDimExprVisitor
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 55402a737cbb..2e8b1580c5c7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -462,8 +462,7 @@ mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
- // TODO: remove hasIndexSemantics check once index ops are supported.
- if (!linalgOp || linalgOp.hasIndexSemantics())
+ if (!linalgOp)
return failure();
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c14a3b3628ba..14ef418ed591 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -166,6 +166,42 @@ vectorizeLinalgYield(OpBuilder &builder, Operation *op,
return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}
+/// Helper function to vectorize the index operations of a `linalgOp`. Return
+/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
+/// should map the produced operations. This function is meant to be used as a
+/// CustomVectorizationHook.
+static VectorizationResult
+vectorizeLinalgIndex(OpBuilder &builder, Operation *op, LinalgOp linalgOp) {
+ IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
+ if (!indexOp)
+ return VectorizationResult{VectorizationStatus::Failure, nullptr};
+ auto loc = indexOp.getLoc();
+ // Compute the static loop sizes of the index op.
+ auto targetShape = linalgOp.computeStaticLoopSizes();
+ // Compute a one-dimensional index vector for the index op dimension.
+ SmallVector<int64_t> constantSeq(
+ llvm::seq<int64_t>(0, targetShape[indexOp.dim()]));
+ ConstantOp constantOp =
+ builder.create<ConstantOp>(loc, builder.getIndexVectorAttr(constantSeq));
+ // Return the one-dimensional index vector if it lives in the trailing
+ // dimension of the iteration space since the vectorization algorithm in this
+ // case can handle the broadcast.
+ if (indexOp.dim() == targetShape.size() - 1)
+ return VectorizationResult{VectorizationStatus::NewOp, constantOp};
+ // Otherwise permute the targetShape to move the index dimension last,
+ // broadcast the one-dimensional index vector to the permuted shape, and
+ // finally transpose the broadcasted index vector to undo the permutation.
+ std::swap(targetShape[indexOp.dim()], targetShape.back());
+ auto broadCastOp = builder.create<vector::BroadcastOp>(
+ loc, VectorType::get(targetShape, builder.getIndexType()), constantOp);
+ SmallVector<int64_t> transposition(
+ llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
+ std::swap(transposition.back(), transposition[indexOp.dim()]);
+ auto transposeOp =
+ builder.create<vector::TransposeOp>(loc, broadCastOp, transposition);
+ return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
+}
+
/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`. Vectorization occurs as follows:
/// 1. Try to apply any of the `customVectorizationHooks` and return its
@@ -245,7 +281,7 @@ static bool hasOnlyScalarElementwiseOp(Region &r) {
if (!llvm::hasSingleElement(r))
return false;
for (Operation &op : r.front()) {
- if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
+ if (!(isa<ConstantOp, linalg::YieldOp, linalg::IndexOp>(op) ||
OpTrait::hasElementwiseMappableTraits(&op)) ||
llvm::any_of(op.getResultTypes(),
[](Type type) { return !type.isIntOrIndexOrFloat(); }))
@@ -293,7 +329,9 @@ static AffineMap getTransferReadMap(LinalgOp linalgOp, unsigned argIndex) {
/// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
/// load).
/// TODO: Reuse opportunities for RAR dependencies.
-/// 4. Register CustomVectorizationHook for YieldOp to capture the results.
+/// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
+/// 4b. Register CustomVectorizationHook for IndexOp to access the iteration
+/// indices.
/// 5. Iteratively call vectorizeOneOp on the region operations.
LogicalResult vectorizeAsLinalgGeneric(
OpBuilder &builder, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
@@ -333,16 +371,23 @@ LogicalResult vectorizeAsLinalgGeneric(
bvm.map(vectorArg, vectorRead);
}
- // 4. Register CustomVectorizationHook for yieldOp.
+ auto hooks = llvm::to_vector<4>(customVectorizationHooks);
+ // 4a. Register CustomVectorizationHook for yieldOp.
CustomVectorizationHook vectorizeYield =
[&](Operation *op,
const BlockAndValueMapping &bvm) -> VectorizationResult {
return vectorizeLinalgYield(builder, op, bvm, linalgOp, newResults);
};
- // Append the vectorizeYield hook.
- auto hooks = llvm::to_vector<4>(customVectorizationHooks);
hooks.push_back(vectorizeYield);
+ // 4b. Register CustomVectorizationHook for indexOp.
+ CustomVectorizationHook vectorizeIndex =
+ [&](Operation *op,
+ const BlockAndValueMapping &bvm) -> VectorizationResult {
+ return vectorizeLinalgIndex(builder, op, linalgOp);
+ };
+ hooks.push_back(vectorizeIndex);
+
// 5. Iteratively call `vectorizeOneOp` to each op in the slice.
for (Operation &op : block.getOperations()) {
VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
@@ -401,9 +446,6 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
for (Type outputTensorType : linalgOp.getOutputTensorTypes())
if (!outputTensorType.cast<ShapedType>().hasStaticShape())
return failure();
- // TODO: remove once index ops are supported.
- if (linalgOp.hasIndexSemantics())
- return failure();
if (isElementwise(op))
return success();
return success(isaContractionOpInterface(linalgOp));
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index d1ab3795d00e..4f8aa9e82075 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -120,6 +120,12 @@ DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
values);
}
+DenseIntElementsAttr Builder::getIndexVectorAttr(ArrayRef<int64_t> values) {
+ return DenseIntElementsAttr::get(
+ VectorType::get(static_cast<int64_t>(values.size()), getIndexType()),
+ values);
+}
+
DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) {
return DenseIntElementsAttr::get(
RankedTensorType::get(static_cast<int64_t>(values.size()),
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index faaadcf94a5c..c18bf5b5cd8b 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -174,6 +174,49 @@ func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
// -----
+// CHECK-LABEL: func @test_vectorize_trailing_index
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<1x2x4x8xindex>)
+func @test_vectorize_trailing_index(%arg0: memref<1x2x4x8xindex>) {
+ // CHECK-DAG: %[[CST0:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+ // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+ linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ outs(%arg0: memref<1x2x4x8xindex>) {
+ ^bb0(%arg1: index):
+ // CHECK: %[[BCST:.*]] = vector.broadcast %[[CST0]] : vector<8xindex> to vector<1x2x4x8xindex>
+ // CHECK: vector.transfer_write %[[BCST]], %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {{.*}} : vector<1x2x4x8xindex>, memref<1x2x4x8xindex>
+ %0 = linalg.index 3 : index
+ linalg.yield %0 : index
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_inner_index
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<1x2x4x8xindex>)
+func @test_vectorize_inner_index(%arg0: memref<1x2x4x8xindex>) {
+ // CHECK-DAG: %[[CST0:.*]] = constant dense<[0, 1]> : vector<2xindex>
+ // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+ linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ outs(%arg0: memref<1x2x4x8xindex>) {
+ ^bb0(%arg1: index):
+ // CHECK: %[[BCST:.*]] = vector.broadcast %[[CST0]] : vector<2xindex> to vector<1x8x4x2xindex>
+ // CHECK: %[[TRAN:.*]] = vector.transpose %[[BCST]], [0, 3, 2, 1] : vector<1x8x4x2xindex> to vector<1x2x4x8xindex>
+ // CHECK: vector.transfer_write %[[TRAN]], %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {{.*}} : vector<1x2x4x8xindex>, memref<1x2x4x8xindex>
+ %0 = linalg.index 1 : index
+ linalg.yield %0 : index
+ }
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @generic_vectorize
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x256xf32>, %[[ARG1:.*]]: memref<4x256xf32>,
// CHECK-SAME: %[[ARG2:.*]]: memref<256xf32>, %[[ARG3:.*]]: f32)
@@ -252,7 +295,6 @@ func @generic_vectorize(%arg0: memref<4x256xf32>,
return
}
-
// -----
// CHECK-LABEL: func @generic_vectorize_tensor
@@ -469,19 +511,3 @@ func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
} : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
return %0 : tensor<6x?x?x?xf32>
}
-
-// -----
-
-// CHECK-LABEL: @index_op
-// CHECK: linalg.generic
-func @index_op(%arg0: memref<4x8xindex>) {
- linalg.generic {
- indexing_maps = [affine_map<(i, j) -> (i, j)>],
- iterator_types = ["parallel", "parallel"]}
- outs(%arg0 : memref<4x8xindex>) {
- ^bb0(%arg1: index): // no predecessors
- %0 = linalg.index 1 : index
- linalg.yield %0 : index
- }
- return
-}