#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"

#include <memory>

using namespace mlir;

namespace {

#include "TritonGPUCombine.inc"

// -----------------------------------------------------------------------------
//
// -----------------------------------------------------------------------------

// convert(blocked, dot_operand) ->
//   convert(blocked, mma) + convert(mma, dot_operand)
// if this value is itself the result of a dot operation.
// This is a heuristic to accommodate some patterns seen in fused attention
// kernels.
// TODO: replace this by something more generic, i.e. layout-aware CSE
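//
// Illustrative sketch of the rewrite (types and encodings are abbreviated and
// are not meant to match the printed IR exactly):
//   %y = convert_layout %x : #blocked -> #dot_operand<opIdx=0, parent=#mma>
// becomes
//   %t = convert_layout %x : #blocked -> #mma
//   %y = convert_layout %t : #mma -> #dot_operand<opIdx=0, parent=#mma>
// The intent (per the TODO above) is that the intermediate #mma value can be
// de-duplicated with an existing #mma value produced by an earlier dot.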
class DecomposeDotOperand : public mlir::RewritePattern {
public:
  explicit DecomposeDotOperand(mlir::MLIRContext *context)
      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                             1, context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
      return mlir::failure();
    auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
    auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
    auto dstType = convert.getType().cast<RankedTensorType>();
    if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
        dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
      auto dstDotOperand =
          dstType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
      auto dstParent = dstDotOperand.getParent();
      if (dstDotOperand.getOpIdx() == 1 ||
          !dstParent.isa<triton::gpu::MmaEncodingAttr>())
        return mlir::failure();
      auto dstParentMma = dstParent.cast<triton::gpu::MmaEncodingAttr>();
      if (dstParentMma.getVersion() == 1 ||
          dstParentMma.getWarpsPerCTA()[1] > 1)
        return mlir::failure();
      SetVector<Operation *> bwdSlices;
      mlir::getBackwardSlice(convert.getResult(), &bwdSlices);
      if (llvm::find_if(bwdSlices, [](Operation *op) {
            return isa<triton::DotOp>(op);
          }) == bwdSlices.end())
        return mlir::failure();

      auto tmpType = RankedTensorType::get(
          dstType.getShape(), dstType.getElementType(), dstParentMma);
      auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
          convert.getLoc(), tmpType, convert.getOperand());
      auto newConvert = rewriter.create<triton::gpu::ConvertLayoutOp>(
          convert.getLoc(), dstType, tmp);
      rewriter.replaceOp(op, {newConvert});
      return mlir::success();
    }
    return mlir::failure();
  }
};

// Layout conversions can't deduce their return type automatically.
// IIUC they are therefore not handled by DRR right now
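//
// The pattern below folds or cancels layout conversions in the following
// cases (each case is annotated again at its implementation site):
//   - cvt to the same layout                -> erase the cvt
//   - cvt(alloc_tensor(x), type2)           -> alloc_tensor(x, type2)
//   - cvt(insert_slice_async(x), type2)     -> insert_slice_async(cvt(x, type2))
//   - cvt(extract_slice(x), type2)          -> extract_slice(cvt(x, type2))
//   - cvt(cvt(x, type1), type2)             -> cvt(x, type2)
//   - cvt(splat(x)) / cvt(make_range(...))  -> splat / make_range in the new layout
//   - cvt(constant splat)                   -> constant splat in the new layout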
class SimplifyConversion : public mlir::RewritePattern {
public:
  explicit SimplifyConversion(mlir::MLIRContext *context)
      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                             4, context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
      return mlir::failure();
    auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
    // we don't handle conversions to DotOperandEncodingAttr
    // this is a heuristic to accommodate fused attention
    auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
    auto dstType = convert.getType().cast<RankedTensorType>();
    if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>() &&
        srcType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
      return mlir::failure();
    // convert to the same layout -- we can delete
    if (op->getResultTypes() == op->getOperandTypes()) {
      rewriter.replaceOp(op, op->getOperands());
      return mlir::success();
    }
    Operation *arg = op->getOperand(0).getDefiningOp();
    // block argument
    if (!arg)
      return mlir::failure();
    // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
    auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
    if (alloc_tensor) {
      if (!isSharedEncoding(op->getResult(0))) {
        return mlir::failure();
      }
      rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
          op, op->getResult(0).getType());
      return mlir::success();
    }
    // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
    auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
    if (insert_slice) {
      if (!isSharedEncoding(op->getResult(0))) {
        return mlir::failure();
      }
      auto newType = op->getResult(0).getType().cast<RankedTensorType>();
      // Ensure that the new insert_slice op is placed in the same place as the
      // old insert_slice op. Otherwise, the new insert_slice op may be placed
      // after the async_wait op, which is not allowed.
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPoint(insert_slice);
      auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
          op->getLoc(), newType, insert_slice.dst());
      rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
          op, newType, insert_slice.src(), newArg.getResult(),
          insert_slice.index(), insert_slice.mask(), insert_slice.other(),
          insert_slice.cache(), insert_slice.evict(), insert_slice.isVolatile(),
          insert_slice.axis());
      return mlir::success();
    }
    // cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
    auto extract_slice = dyn_cast<tensor::ExtractSliceOp>(arg);
    if (extract_slice) {
      if (!isSharedEncoding(op->getResult(0))) {
        return mlir::failure();
      }
      auto origType = extract_slice.source().getType().cast<RankedTensorType>();
      auto newType = RankedTensorType::get(
          origType.getShape(), origType.getElementType(),
          op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
      auto origResType = op->getResult(0).getType().cast<RankedTensorType>();
      auto resType = RankedTensorType::get(
          origResType.getShape(), origResType.getElementType(),
          extract_slice.getType().cast<RankedTensorType>().getEncoding());
      // Ensure that the new extract_slice op is placed in the same place as the
      // old extract_slice op. Otherwise, the new extract_slice op may be placed
      // after the async_wait op, which is not allowed.
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPoint(extract_slice);
      auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
          op->getLoc(), newType, extract_slice.source());
      rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
          op, resType, newArg.getResult(), extract_slice.offsets(),
          extract_slice.sizes(), extract_slice.strides(),
          extract_slice.static_offsets(), extract_slice.static_sizes(),
          extract_slice.static_strides());
      return mlir::success();
    }

    // cvt(cvt(x, type1), type2) -> cvt(x, type2)
    if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
      if (arg->getOperand(0).getDefiningOp() &&
          !isSharedEncoding(arg->getOperand(0)) &&
          isSharedEncoding(convert.getOperand()) &&
          !isSharedEncoding(convert.getResult())) {
        return mlir::failure();
      }
      if (isSharedEncoding(convert.getOperand()) &&
          isSharedEncoding(convert.getResult())) {
        return mlir::failure();
      }
      auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
      auto srcShared =
          srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
      if (srcShared && srcShared.getVec() > 1)
        return mlir::failure();
      rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
          op, op->getResultTypes().front(), arg->getOperand(0));
      return mlir::success();
    }
    // cvt(type1, splat(type2, x)) -> splat(type1, x)
    if (auto splat = llvm::dyn_cast<triton::SplatOp>(arg)) {
      rewriter.replaceOpWithNewOp<triton::SplatOp>(op, op->getResultTypes(),
                                                   splat.src());
      return mlir::success();
    }
    // cvt(type1, make_range(type2, x)) -> make_range(type1, x)
    if (auto range = llvm::dyn_cast<triton::MakeRangeOp>(arg)) {
      rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
          op, op->getResultTypes(), range.start(), range.end());
      return mlir::success();
    }
    // cvt(type, constant) -> constant
    if (auto cst = llvm::dyn_cast<arith::ConstantOp>(arg))
      if (auto ret = cst.getValue().dyn_cast<SplatElementsAttr>()) {
        auto newRet = SplatElementsAttr::get(op->getResultTypes().front(),
                                             ret.getSplatValue<Attribute>());
        rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newRet);
        return mlir::success();
      }
    return mlir::failure();
  }
};

// -----------------------------------------------------------------------------
//
// -----------------------------------------------------------------------------

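// Given the target encoding of an operation's result, infer the encoding its
// operands must carry so that rematerializing the op produces that result
// encoding. For example, the operand of tt.expand_dims(axis = d) must carry a
// slice of the target encoding along d, while the operand of tt.reduce must
// carry the parent of the target slice encoding. Fails when the target
// encoding cannot be inverted (e.g. reducing into a non-slice encoding).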
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
                             Attribute &ret) {
  ret = targetEncoding;
  if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
    ret = triton::gpu::SliceEncodingAttr::get(
        op->getContext(), expand_dims.axis(), targetEncoding);
  }
  if (auto reduce = dyn_cast<triton::ReduceOp>(op)) {
    auto sliceEncoding =
        targetEncoding.dyn_cast<triton::gpu::SliceEncodingAttr>();
    if (!sliceEncoding)
      return failure();
    ret = sliceEncoding.getParent();
  }
  return success();
}
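
// Heuristic used by the backward-rematerialization pattern below: ops that
// touch memory (loads/stores/atomics/async copies), dots, and control-flow
// ops are considered too expensive to duplicate just to remove a layout
// conversion.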
inline bool expensive_to_remat(Operation *op) {
  if (!op)
    return true;
  if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
          triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp,
          triton::AtomicRMWOp, triton::AtomicCASOp, triton::DotOp>(op))
    return true;
  if (isa<scf::YieldOp, scf::ForOp>(op))
    return true;
  return false;
}
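
// Clone `op` through `mapping` and patch its result type so that it carries
// the encoding of the (already remapped) first operand; when the op
// implements InferTypeOpInterface, the inferred result type takes precedence.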
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
                              BlockAndValueMapping &mapping) {
  Operation *newOp = rewriter.clone(*op, mapping);
  auto origType = op->getResult(0).getType().cast<RankedTensorType>();
  auto newType = RankedTensorType::get(
      origType.getShape(), origType.getElementType(),
      newOp->getOperand(0).getType().cast<RankedTensorType>().getEncoding());
  newOp->getResult(0).setType(newType);
  auto typeInfer = dyn_cast<InferTypeOpInterface>(newOp);
  if (typeInfer) {
    SmallVector<Type, 1> newTypes;
    auto success = typeInfer.inferReturnTypes(
        newOp->getContext(), newOp->getLoc(), newOp->getOperands(),
        newOp->getAttrDictionary(), newOp->getRegions(), newTypes);
    if (succeeded(success))
      newOp->getResult(0).setType(newTypes.front());
  }
  return newOp;
}

// Layout conversions are expensive. They require going through
// shared memory, which is orders of magnitude slower than
// other non-i/o operations in the dialect.
// It therefore makes sense to remove them whenever possible,
// even if it means rematerializing all values whose definitions
// are reachable from it without passing through any memory operation.
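//
// Sketch of the algorithm as implemented below: starting from the conversion,
// walk the backward slice within the same block, inverting the target
// encoding across each operation. Every conversion that disappears decrements
// `numCvts`, every operand that would need a new non-foldable conversion
// increments it; the rewrite is applied only when the net count is not
// positive, i.e. rematerialization does not add conversions.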
class RematerializeBackward : public mlir::RewritePattern {
public:
  explicit RematerializeBackward(mlir::MLIRContext *context)
      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                             2, context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *cvt,
                  mlir::PatternRewriter &rewriter) const override {
    if (!llvm::isa<triton::gpu::ConvertLayoutOp>(cvt))
      return mlir::failure();
    // we don't touch block arguments
    Operation *op = cvt->getOperand(0).getDefiningOp();
    if (!op)
      return mlir::failure();
    // we don't want to rematerialize any conversion to/from shared
    if (isSharedEncoding(cvt->getResults()[0]) ||
        isSharedEncoding(cvt->getOperand(0)))
      return mlir::failure();
    // we don't handle conversions to DotOperandEncodingAttr
    // this is a heuristic to accommodate fused attention
    auto targetType = cvt->getResultTypes()[0].cast<RankedTensorType>();
    if (targetType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
      return mlir::failure();
    // DFS
    SetVector<Operation *> processed;
    SetVector<Attribute> layout;
    llvm::MapVector<Value, Attribute> toConvert;
    std::vector<std::pair<Operation *, Attribute>> queue;
    queue.emplace_back(cvt, targetType.getEncoding());
    int numCvts = 1;
    while (!queue.empty()) {
      Operation *currOp;
      Attribute currLayout;
      std::tie(currOp, currLayout) = queue.back();
      queue.pop_back();
      // If the current operation is expensive to rematerialize,
      // we stop everything
      if (expensive_to_remat(currOp))
        break;
      // a conversion will be removed here (i.e. transferred to operands)
      numCvts -= 1;
      // done processing
      processed.insert(currOp);
      layout.insert(currLayout);
      // add all operands to the queue
      for (Value argI : currOp->getOperands()) {
        Attribute newEncoding;
        // cannot invert the current encoding for this operand
        // we stop everything
        if (failed(invertEncoding(currLayout, currOp, newEncoding)))
          return mlir::failure();
        if (toConvert.count(argI) && toConvert[argI] != newEncoding)
          return mlir::failure();
        Operation *opArgI = argI.getDefiningOp();
        toConvert.insert({argI, newEncoding});
        if (!opArgI || processed.contains(opArgI) ||
            (opArgI->getBlock() != cvt->getBlock()))
          continue;
        // if the conversion can be folded into opArgI then
        // we don't count this conversion as expensive
        if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
                triton::MakeRangeOp, triton::SplatOp>(*opArgI))
          continue;
        // we add one expensive conversion for the current operand
        numCvts += 1;
        queue.emplace_back(opArgI, newEncoding);
      }
    }
    // if rematerialization would add more conversions than it removes
    // then we don't do it
    if (numCvts > 0)
      return mlir::failure();

    SmallVector<Value, 4> sortedValues;
    SetVector<Operation *> tmp;
    for (auto &item : toConvert) {
      Value v = item.first;
      if (v.getDefiningOp())
        tmp.insert(v.getDefiningOp());
      else
        sortedValues.push_back(v);
    }
    tmp = mlir::topologicalSort(tmp);
    for (Operation *op : tmp)
      sortedValues.push_back(op->getResult(0));

    BlockAndValueMapping mapping;
    for (Value currOperand : sortedValues) {
      // unpack information
      Attribute targetLayout = toConvert.lookup(currOperand);
      // rematerialize the operand if necessary
      Operation *currOperation = currOperand.getDefiningOp();
      if (processed.contains(currOperation)) {
        currOperation = cloneWithInferType(rewriter, currOperation, mapping);
        currOperand = currOperation->getResult(0);
      }
      // compute target type for the layout cast
      auto currType = currOperand.getType().cast<RankedTensorType>();
      auto newType = RankedTensorType::get(
          currType.getShape(), currType.getElementType(), targetLayout);
      auto newOperand = rewriter.create<triton::gpu::ConvertLayoutOp>(
          currOperand.getLoc(), newType, currOperand);
      if (currOperation)
        newOperand->moveAfter(currOperation);
      mapping.map(currOperand, newOperand);
    }
    rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0)));
    return mlir::success();
  }
};

// -----------------------------------------------------------------------------
//
// -----------------------------------------------------------------------------

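// Hoists a layout conversion applied to a loop-carried value out of an
// scf.for loop: the loop is cloned with the converted type as the iter-arg
// type, the init arg and the yielded value are converted outside/inside the
// loop accordingly, and remaining in-loop uses of the old layout go through a
// fallback conversion that later patterns are expected to clean up.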
class MoveConvertOutOfLoop : public mlir::RewritePattern {
public:
  explicit MoveConvertOutOfLoop(mlir::MLIRContext *context)
      : mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {}

  SmallVector<Value, 4>
  rematerializeForLoop(mlir::PatternRewriter &rewriter, scf::ForOp &forOp,
                       size_t i, RankedTensorType newType,
                       triton::gpu::ConvertLayoutOp origConversion) const {
    // Rewrite init argument
    Type origType = forOp.getInitArgs()[i].getType();
    SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
    newInitArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
        newInitArgs[i].getLoc(), newType, newInitArgs[i]);
    // Clone for loop
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newInitArgs);
    newForOp->moveBefore(forOp);
    rewriter.setInsertionPointToStart(newForOp.getBody());
    BlockAndValueMapping mapping;
    for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
      mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
    mapping.map(origConversion.getResult(), newForOp.getRegionIterArgs()[i]);
    // the iter arg of interest may have other uses than the conversion
    // we're hoisting out of the loop. If that's the case we will
    // need to add extra conversions for all uses... which is only useful
    // if these extra conversions can be removed by another pattern
    auto oldArg = forOp.getRegionIterArgs()[i];
    auto newArg = newForOp.getRegionIterArgs()[i];
    auto newArgFallback = rewriter.create<triton::gpu::ConvertLayoutOp>(
        newForOp.getLoc(), origType, newArg);

    mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
    for (Operation &op : forOp.getBody()->without_terminator()) {
      if (&op == (Operation *)(&origConversion))
        continue;
      Operation *newOp = rewriter.clone(op, mapping);
      if (find(oldArg.getUsers(), &op) != oldArg.getUsers().end())
        newOp->replaceUsesOfWith(newArg, newArgFallback);
    }

    // create yield, inserting conversions if necessary
    auto yieldOp = forOp.getBody()->getTerminator();
    SmallVector<Value, 4> newYieldArgs;
    for (Value arg : yieldOp->getOperands())
      newYieldArgs.push_back(mapping.lookup(arg));
    newYieldArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
        yieldOp->getLoc(), newType, newYieldArgs[i]);
    rewriter.create<scf::YieldOp>(forOp.getLoc(), newYieldArgs);

    // replace
    SmallVector<Value, 4> newResults = newForOp->getResults();
    newResults[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
        rewriter.getUnknownLoc(), origType, newForOp->getResult(i));
    newResults[i].getDefiningOp()->moveAfter(newForOp);
    return newResults;
  }

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto forOp = cast<scf::ForOp>(op);
    auto iterArgs = forOp.getRegionIterArgs();
    for (const auto &iterArg : llvm::enumerate(iterArgs)) {
      // if (iterArg.index() != 1)
      //   continue;
      // skip non-tensor types
      if (!iterArg.value().getType().isa<RankedTensorType>())
        continue;
      // we only move `iterArg` out of the loop if
      //   - there is only a single conversion use
      //   - moving this conversion out of the loop will not generate
      //     any extra non-removable conversion
      auto users = iterArg.value().getUsers();
      // check first condition
      SetVector<Type> cvtTargetTypes;
      for (auto user : users) {
        if (isa<triton::gpu::ConvertLayoutOp>(user)) {
          auto newType =
              user->getResults()[0].getType().cast<RankedTensorType>();
          auto oldType = user->getOperand(0).getType().cast<RankedTensorType>();
          if (oldType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
              newType.getEncoding()
                  .isa<triton::gpu::DotOperandEncodingAttr>()) {
            continue;
          }
          if (newType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
            if (newType.getEncoding()
                    .cast<triton::gpu::SharedEncodingAttr>()
                    .getVec() == 1)
              continue;
          }
          cvtTargetTypes.insert(newType);
        }
      }
      if (cvtTargetTypes.size() != 1)
        continue;
      // TODO: check second condition
      for (auto user : users) {
        if (isa<triton::gpu::ConvertLayoutOp>(user))
          continue;
      }
      // check
      // llvm::outs() << "replacing " << iterArg.index() << "\n";
      for (auto op : iterArg.value().getUsers()) {
        auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op);
        if (!cvt)
          continue;
        auto targetType = op->getResultTypes()[0].cast<RankedTensorType>();
        auto newFor = rematerializeForLoop(rewriter, forOp, iterArg.index(),
                                           targetType, cvt);
        rewriter.replaceOp(forOp, newFor);
        return success();
      }
    }
    return failure();
  }
};

// -----------------------------------------------------------------------------
//
// -----------------------------------------------------------------------------

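// Pushes a layout conversion forward through a single layout-preserving user
// inside an scf.for loop, i.e. op(cvt(x), ...) -> cvt(op(x, cvt(...))), so
// that the conversion eventually reaches the yield and can be hoisted out of
// the loop by MoveConvertOutOfLoop.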
class RematerializeForward : public mlir::RewritePattern {
public:
  explicit RematerializeForward(mlir::MLIRContext *context)
      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                             2, context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *_cvtOp,
                  mlir::PatternRewriter &rewriter) const override {
    auto cvt = cast<triton::gpu::ConvertLayoutOp>(_cvtOp);
    auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
    if (!forOp)
      return mlir::failure();
    auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };

    SetVector<Operation *> cvtSlices;
    auto filter = [&](Operation *op) {
      return isInLoop(op) &&
             !isa<triton::LoadOp, triton::StoreOp, triton::AtomicRMWOp,
                  triton::AtomicCASOp>(op) &&
             !isa<triton::DotOp>(op) && !isa<scf::YieldOp>(op) &&
             !isa<triton::gpu::ConvertLayoutOp>(op);
    };
    mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
    if (cvtSlices.empty())
      return failure();

    for (Operation *op : cvtSlices) {
      if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
          !op->hasTrait<mlir::OpTrait::SameOperandsAndResultType>())
        return failure();
      for (Value arg : op->getOperands()) {
        Operation *argOp = arg.getDefiningOp();
        if (argOp && (argOp != cvt) &&
            !isa<arith::ConstantOp, triton::SplatOp>(argOp)) {
          return failure();
        }
      }
    }

    // otherwise, we push the conversion forward
    // since we'll be able to move it out of
    // the loop once it reaches the yield op
    // op(cvt(arg_0), arg_1, ..., arg_n)
    // -> cvt(op(arg_0, cvt(arg_1), ..., cvt(arg_n)))
    BlockAndValueMapping mapping;
    auto op = cvtSlices.front();
    for (Value arg : op->getOperands()) {
      if (arg.getDefiningOp() == cvt)
        mapping.map(arg, cvt.getOperand());
      else {
        auto cvtI = rewriter.create<triton::gpu::ConvertLayoutOp>(
            arg.getLoc(), cvt.getOperand().getType(), arg);
        mapping.map(arg, cvtI);
      }
    }
    Operation *newOp = rewriter.clone(*op, mapping);
    newOp->getResult(0).setType(cvt.getOperand().getType());
    auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
        newOp->getLoc(), cvt.getResult().getType(), newOp->getResult(0));
    rewriter.replaceOp(op, newCvt->getResults());
    return success();
  }
};

// -----------------------------------------------------------------------------
//
// -----------------------------------------------------------------------------

namespace {
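
// Illustrative mapping implied by the thresholds below: compute capabilities
// below 80 (e.g. sm_70/sm_75) use MMA v1, capabilities in [80, 90)
// (e.g. sm_80/sm_86/sm_89) use MMA v2, and sm_90+ is rejected.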
int computeCapabilityToMMAVersion(int computeCapability) {
  if (computeCapability < 80) {
    return 1;
  } else if (computeCapability < 90) {
    return 2;
  } else {
    assert(false && "computeCapability >= 90 not supported");
    return 0;
  }
}

SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
  if (version == 1)
    return {16, 16};
  else if (version == 2)
    return {16, 8};
  else {
    assert(false && "version not supported");
    return {0, 0};
  }
}
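
// Worked example for the V1 heuristic below (values derived from the code,
// not from a measured kernel): shape = [128, 64], numWarps = 4,
// shapePerWarp = {16, 16} -> the first iteration grows ret to {2, 1} and then
// to {2, 2}; since 2 * 2 == numWarps, the result is warpsPerCTA = {2, 2}.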
SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
                                        int numWarps) {
  SmallVector<unsigned, 2> ret = {1, 1};
  SmallVector<int64_t, 2> shapePerWarp =
      mmaVersionToShapePerWarp(1 /*version*/);
  bool changed = false;
  do {
    changed = false;
    int pre = ret[0];
    if (ret[0] * ret[1] < numWarps) {
      ret[0] = std::clamp<unsigned>(ret[0] * 2, 1, shape[0] / shapePerWarp[0]);
      changed = pre != ret[0];
    }
    if (ret[0] * ret[1] < numWarps) {
      pre = ret[1];
      ret[1] = std::clamp<unsigned>(ret[1] * 2, 1, shape[1] / shapePerWarp[1]);
      changed = pre != ret[1];
    }
  } while (changed);
  return ret;
}
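
// Worked example for the V2 heuristic below (values derived from the code,
// not from a measured kernel): shape = [128, 128], numWarps = 4,
// shapePerWarp = {16, 8} -> first 128/16/1 >= 128/16/1, so ret becomes
// {2, 1}; then 128/16/2 < 128/16/1, so ret becomes {2, 2} and the loop stops
// at 4 warps. Chains of dots bypass the loop and use {numWarps, 1}.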
SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
                                        const ArrayRef<int64_t> shape,
                                        int numWarps) {
  SetVector<Operation *> slices;
  mlir::getForwardSlice(dotOp.getResult(), &slices);
  if (llvm::find_if(slices, [](Operation *op) {
        return isa<triton::DotOp>(op);
      }) != slices.end())
    return {(unsigned)numWarps, 1};

  SmallVector<unsigned, 2> ret = {1, 1};
  SmallVector<int64_t, 2> shapePerWarp = {16, 8};
  bool changed = false;
  // TODO (@daadaada): double-check.
  // original logic in
  // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
  // seems buggy for shape = [32, 16] ?
  do {
    changed = false;
    if (ret[0] * ret[1] >= numWarps)
      break;
    if (shape[0] / shapePerWarp[0] / ret[0] >=
        shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
      if (ret[0] < shape[0] / shapePerWarp[0]) {
        ret[0] *= 2;
      } else
        ret[1] *= 2;
    } else {
      ret[1] *= 2;
    }
  } while (true);
  return ret;
}

} // namespace
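
// Rewrites a #blocked -> #shared conversion whose only user is a transpose.
// The conversion is re-targeted to a #shared encoding that keeps the source
// (blocked) order and the transpose is recreated on top of it; the intent
// (an assumption inferred from the code, not documented here) is to let the
// write into shared memory follow the source order.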
class OptimizeBlockedToShared : public mlir::RewritePattern {
public:
  explicit OptimizeBlockedToShared(mlir::MLIRContext *context)
      : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
                       context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
    auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
    auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
    auto srcBlockedLayout =
        srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
    auto dstSharedLayout =
        dstType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
    if (!srcBlockedLayout || !dstSharedLayout)
      return failure();
    if (srcBlockedLayout.getOrder() == dstSharedLayout.getOrder())
      return failure();
    // For now only works if single use is transpose
    // TODO: rematerialize #shared uses
    auto users = op->getUsers();
    if (std::distance(users.begin(), users.end()) != 1 ||
        !isa<triton::TransOp>(*users.begin()))
      return failure();

    auto tmpShared = triton::gpu::SharedEncodingAttr::get(
        op->getContext(), dstSharedLayout.getVec(),
        dstSharedLayout.getPerPhase(), dstSharedLayout.getMaxPhase(),
        srcBlockedLayout.getOrder());
    auto tmpType = RankedTensorType::get(srcType.getShape(),
                                         srcType.getElementType(), tmpShared);
    auto tmpCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
        op->getLoc(), tmpType, cvt.getOperand());

    auto newDstType = RankedTensorType::get(
        users.begin()->getResultTypes()[0].cast<RankedTensorType>().getShape(),
        srcType.getElementType(), dstSharedLayout);

    auto newTrans = rewriter.create<triton::TransOp>(op->getLoc(), newDstType,
                                                     tmpCvt.getResult());

    rewriter.replaceOp(*users.begin(), newTrans.getResult());
    return success();
  }
};
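
// For MMA v1 dot operands that carry an `isMMAv1Row` flag, re-creates the
// conversion with the flag flipped whenever it disagrees with the
// fastest-varying dimension of the source layout's order.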
class OptimizeConvertToDotOperand : public mlir::RewritePattern {
public:
  explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
      : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
                       context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
    auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
    auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
    // order
    ArrayRef<unsigned> order;
    if (auto srcBlockedLayout =
            srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
      order = srcBlockedLayout.getOrder();
    else if (auto srcSharedLayout =
                 srcType.getEncoding()
                     .dyn_cast<triton::gpu::SharedEncodingAttr>())
      order = srcSharedLayout.getOrder();
    else
      return failure();
    // dot operand output
    auto dstDotOperandLayout =
        dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
    if (!dstDotOperandLayout)
      return failure();
    if (!dstDotOperandLayout.getIsMMAv1Row())
      return failure();
    bool isMMAv1Row =
        dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
    if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
      return failure();
    auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
    auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
        op->getContext(), dstDotOperandLayout.getOpIdx(),
        dstDotOperandLayout.getParent(), newIsRow);
    auto newDstType = RankedTensorType::get(
        dstType.getShape(), dstType.getElementType(), newDstEncoding);
    auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
        op->getLoc(), newDstType, cvt.getOperand());
    rewriter.replaceOp(op, newCvt.getResult());
    return success();
  }
};
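
// Converts tt.dot ops whose result still uses a non-MMA encoding to use an
// MmaEncodingAttr chosen from the compute capability: the operands are
// converted to the matching dot-operand encodings, the accumulator to the MMA
// layout, and the result is converted back to the original layout.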
class BlockedToMMA : public mlir::RewritePattern {
  int computeCapability;

public:
  BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
      : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
        computeCapability(computeCapability) {}

  static SmallVector<unsigned, 2> getWarpsPerTile(triton::DotOp dotOp,
                                                  const ArrayRef<int64_t> shape,
                                                  int version, int numWarps) {
    switch (version) {
    case 1:
      return warpsPerTileV1(shape, numWarps);
    case 2:
      return warpsPerTileV2(dotOp, shape, numWarps);
    default:
      assert(false && "not supported version");
      return {0, 0};
    }
  }

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto dotOp = cast<triton::DotOp>(op);
    // TODO: Check data-types and SM compatibility
    auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
    if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
      return failure();

    int version = computeCapabilityToMMAVersion(computeCapability);

    // for FMA, should retain the blocked layout.
    if (!supportMMA(dotOp, version))
      return failure();

    // get MMA encoding for the given number of warps
    auto retShape = oldRetType.getShape();
    auto mod = op->getParentOfType<mlir::ModuleOp>();
    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);

    auto newRetType = RankedTensorType::get(
        retShape, oldRetType.getElementType(),
        triton::gpu::MmaEncodingAttr::get(
            oldRetType.getContext(), version,
            getWarpsPerTile(dotOp, retShape, version, numWarps)));
    // convert accumulator
    auto oldAcc = dotOp.getOperand(2);
    auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
        oldAcc.getLoc(), newRetType, oldAcc);
    Value a = dotOp.a();
    Value b = dotOp.b();
    auto oldAType = a.getType().cast<RankedTensorType>();
    auto oldBType = b.getType().cast<RankedTensorType>();
    auto oldAOrder = oldAType.getEncoding()
                         .cast<triton::gpu::DotOperandEncodingAttr>()
                         .getParent()
                         .cast<triton::gpu::BlockedEncodingAttr>()
                         .getOrder();
    auto oldBOrder = oldBType.getEncoding()
                         .cast<triton::gpu::DotOperandEncodingAttr>()
                         .getParent()
                         .cast<triton::gpu::BlockedEncodingAttr>()
                         .getOrder();
    Attribute isMMAv1RowA;
    Attribute isMMAv1RowB;
    if (version == 1) {
      isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
      isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
    }

    auto newAType = RankedTensorType::get(
        oldAType.getShape(), oldAType.getElementType(),
        triton::gpu::DotOperandEncodingAttr::get(
            oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
    auto newBType = RankedTensorType::get(
        oldBType.getShape(), oldBType.getElementType(),
        triton::gpu::DotOperandEncodingAttr::get(
            oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));

    a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
    b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
    auto newDot = rewriter.create<triton::DotOp>(dotOp.getLoc(), newRetType, a,
                                                 b, newAcc, dotOp.allowTF32());

    rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
        op, oldRetType, newDot.getResult());
    return success();
  }
};
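
// After the patterns above have rewritten conversions around scf.for ops, the
// types of loop init args and region iter args can get out of sync. This
// pattern re-creates such loops so that the region argument types match the
// init argument types again.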
class FixupLoop : public mlir::RewritePattern {
public:
  explicit FixupLoop(mlir::MLIRContext *context)
      : mlir::RewritePattern(scf::ForOp::getOperationName(), 2, context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto forOp = cast<scf::ForOp>(op);

    // Rewrite init argument
    SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
    bool shouldRematerialize = false;
    for (size_t i = 0; i < newInitArgs.size(); i++) {
      auto initArg = newInitArgs[i];
      auto regionArg = forOp.getRegionIterArgs()[i];
      if (initArg.getType() != regionArg.getType()) {
        shouldRematerialize = true;
        break;
      }
    }
    if (!shouldRematerialize)
      return failure();

    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newInitArgs);
    newForOp->moveBefore(forOp);
    rewriter.setInsertionPointToStart(newForOp.getBody());
    BlockAndValueMapping mapping;
    for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
      mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);

    for (Operation &op : forOp.getBody()->getOperations())
      rewriter.clone(op, mapping);
    rewriter.replaceOp(forOp, newForOp.getResults());
    return success();
  }
};

} // namespace

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

class TritonGPUCombineOpsPass
    : public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
public:
  TritonGPUCombineOpsPass() = default;
  TritonGPUCombineOpsPass(int computeCapability) {
    this->computeCapability = computeCapability;
  }
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp m = getOperation();

    mlir::RewritePatternSet patterns(context);

    patterns.add<OptimizeBlockedToShared>(context);
    patterns.add<OptimizeConvertToDotOperand>(context);
    patterns.add<SimplifyConversion>(context);
    patterns.add<DecomposeDotOperand>(context);
    patterns.add<RematerializeBackward>(context);
    patterns.add<RematerializeForward>(context);
    patterns.add<MoveConvertOutOfLoop>(context);
    patterns.add<BlockedToMMA>(context, computeCapability);

    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
      signalPassFailure();
    }

    // llvm::outs() << m << "\n";
    mlir::RewritePatternSet loopFixup(context);
    loopFixup.add<FixupLoop>(context);
    if (applyPatternsAndFoldGreedily(m, std::move(loopFixup)).failed()) {
      signalPassFailure();
    }
  }
};
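
// Minimal usage sketch (the surrounding pipeline setup is an assumption and
// not part of this file):
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::createTritonGPUCombineOpsPass(/*computeCapability=*/80));
//   if (failed(pm.run(module)))
//     signalFailure();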
std::unique_ptr<Pass>
mlir::createTritonGPUCombineOpsPass(int computeCapability) {
  return std::make_unique<TritonGPUCombineOpsPass>(computeCapability);
}