[OPTIMIZER] Added memory coalescing pass (#31)

Author: Philippe Tillet
Date: 2022-07-31 20:59:31 -07:00 (committed by GitHub)
Parent: d1593e6ca8
Commit: 3236642e8f

7 changed files with 180 additions and 4 deletions


@@ -6,6 +6,8 @@
namespace mlir {
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
std::unique_ptr<Pass> createTritonGPUCoalescePass();
std::unique_ptr<Pass> createTritonGPUCombineOpsPass();
std::unique_ptr<Pass> createTritonGPUVerifier();
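
Not part of the diff: a minimal sketch of how the constructor declared above could be used to run the new pass over a module with MLIR's PassManager; the helper name and surrounding setup are assumptions for illustration.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

// Hypothetical helper: run only the coalescing pass on a module.
mlir::LogicalResult runCoalesce(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createTritonGPUCoalescePass()); // declared in this header
  return pm.run(module); // rewrites tt.load / tt.store layouts in place
}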


@@ -23,6 +23,24 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
];
}
def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
let summary = "coalesce";
let description = [{
TODO
}];
let constructor = "mlir::createTritonGPUCoalescePass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
let options = [
Option<"numWarps", "num-warps",
"int32_t", /*default*/"4",
"number of warps">
];
}
def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
let summary = "combine triton gpu ops";


@@ -94,7 +94,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
AxisInfo curr;
// This preserves the input axes (e.g., cast):
if (llvm::isa<arith::ExtSIOp, arith::ExtUIOp, arith::TruncIOp,
- triton::PtrToIntOp, triton::IntToPtrOp>(op))
+ triton::PtrToIntOp, triton::IntToPtrOp,
+ triton::gpu::ConvertLayoutOp>(op))
curr = operands[0]->getValue();
// Constant ranges
if (triton::MakeRangeOp make_range =


@@ -3,6 +3,7 @@ mlir_tablegen(TritonGPUCombine.inc -gen-rewriters)
add_public_tablegen_target(TritonGPUCombineIncGen)
add_mlir_dialect_library(TritonGPUTransforms
Coalesce.cpp
Combine.cpp
Pipeline.cpp
Verifier.cpp


@@ -0,0 +1,108 @@
#include "triton/Analysis/AxisInfo.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <numeric>
using namespace mlir;
using namespace mlir::triton;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr) {
auto origType = ptr.getType().cast<RankedTensorType>();
// Get the shape of the tensor.
size_t rank = origType.getRank();
AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
// Layout order in decreasing order of contiguity
SmallVector<unsigned, 4> order(rank);
std::iota(order.begin(), order.end(), 0);
auto contiguity = info.getContiguity();
std::sort(order.begin(), order.end(), [&](unsigned x, unsigned y) {
return contiguity[x] > contiguity[y];
});
// Thread tile size depends on memory alignment
SmallVector<unsigned, 4> sizePerThread(rank, 1);
PointerType ptrType = origType.getElementType().cast<PointerType>();
unsigned numBits = ptrType.getPointeeType().getIntOrFloatBitWidth();
unsigned maxMultiple = info.getDivisibility(order[0]);
unsigned maxContig = info.getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
unsigned perThread = std::min(alignment, 128 / numBits);
sizePerThread[order[0]] = perThread;
// create encoding
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
&getContext(), origType.getShape(), sizePerThread, order,
this->numWarps);
return encoding;
}
std::function<Type(Type)> getTypeConverter(AxisInfoAnalysis &axisInfo,
Value ptr) {
Attribute encoding = getCoalescedEncoding(axisInfo, ptr);
return [encoding](Type _type) {
RankedTensorType type = _type.cast<RankedTensorType>();
return RankedTensorType::get(type.getShape(), type.getElementType(),
encoding);
};
}
template <class T>
void coalesceOp(AxisInfoAnalysis &axisInfo, Operation *op, Value ptr,
OpBuilder builder) {
RankedTensorType ty = ptr.getType().template dyn_cast<RankedTensorType>();
if (!ty)
return;
AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
auto convertType = getTypeConverter(axisInfo, ptr);
// convert operands
SmallVector<Value, 4> newArgs;
for (auto v : op->getOperands())
newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), convertType(v.getType()), v));
// convert output types
SmallVector<Type, 4> newTypes;
for (auto t : op->getResultTypes())
newTypes.push_back(convertType(t));
// construct new op with the new encoding
Operation *newOp =
builder.create<T>(op->getLoc(), newTypes, newArgs, op->getAttrs());
// cast the results back to the original layout
for (size_t i = 0; i < op->getNumResults(); i++) {
auto newResult = builder.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), op->getResult(i).getType(), newOp->getResult(i));
op->getResult(i).replaceAllUsesWith(newResult);
}
op->erase();
}
void runOnOperation() override {
Operation *op = getOperation();
// Run axis info analysis
AxisInfoAnalysis axisInfo(&getContext());
axisInfo.run(op);
OpBuilder builder(op);
// For each memory op that has a layout L1:
// 1. Create a coalesced memory layout L2 of the pointer operands
// 2. Convert all operands from layout L1 to layout L2
// 3. Create a new memory op that consumes these operands and
// produces a tensor with layout L2
// 4. Convert the output of this new memory op back to L1
// 5. Replace all the uses of the original memory op by the new one
op->walk([&](Operation *curr) {
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPoint(curr);
if (auto load = dyn_cast<triton::LoadOp>(curr))
coalesceOp<triton::LoadOp>(axisInfo, curr, load.ptr(), builder);
if (auto store = dyn_cast<triton::StoreOp>(curr))
coalesceOp<triton::StoreOp>(axisInfo, curr, store.ptr(), builder);
});
}
};
std::unique_ptr<Pass> mlir::createTritonGPUCoalescePass() {
return std::make_unique<CoalescePass>();
}
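
To make the sizePerThread computation in getCoalescedEncoding concrete, here is the arithmetic for the f32 transpose test further down; the divisibility/contiguity values are read off that test, and treating them as what AxisInfo reports is my assumption.

#include <algorithm>
#include <cassert>

int main() {
  unsigned numBits = 32;                                    // f32 pointee type
  unsigned maxMultiple = 16;                                // tt.divisibility = 16 on the pointer args
  unsigned maxContig = 64;                                  // tt.make_range 0..64 along the fast dim
  unsigned alignment = std::min(maxMultiple, maxContig);    // = 16
  unsigned perThread = std::min(alignment, 128 / numBits);  // = min(16, 4) = 4
  assert(perThread == 4); // matches sizePerThread = [1, 4] / [4, 1] in the CHECK lines below
  return 0;
}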


@@ -5,8 +5,8 @@ import triton.language as tl
# triton kernel
@triton.jit
- def kernel(X, stride_xm, stride_xn,
- Z, stride_zm, stride_zn,
+ def kernel(X, stride_xm,
+ Z, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
@@ -15,5 +15,5 @@ def kernel(X, stride_xm, stride_xn,
tl.store(Zs, tl.load(Xs))
- ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
+ ret = triton.compile(kernel, "*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
print(ret)


@@ -0,0 +1,46 @@
// RUN: triton-opt %s -split-input-file -tritongpu-coalesce -canonicalize -tritongpu-verifier | FileCheck %s
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
// CHECK: [[load_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[row_layout]]>
// CHECK: [[load_other:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[row_layout]]>
// CHECK: [[load_val:%.*]] = tt.load [[load_ptr]], [[load_mask]], [[load_other]] {{.*}} : tensor<64x64xf32, [[row_layout]]>
// CHECK: [[store_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
// CHECK: [[store_val:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]>
// CHECK: [[store_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]>
// CHECK: tt.store [[store_ptr]], [[store_val]], [[store_mask]]
func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg1: i32 {tt.divisibility = 16 : i32},
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg3: i32 {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
%1 = tt.view %0 : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1>
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.getelementptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%6 = tt.view %0 : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%10 = tt.getelementptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%11 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%12 = tt.getelementptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%13 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
%14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2>
%15 = tt.broadcast %12 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1>
tt.store %18, %19, %cst, : tensor<64x64xf32, #blocked1>
return
}
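
A sanity check on why the expected [[row_layout]] above is coalesced (my reading of the CHECK line, not stated in the test): with sizePerThread = [1, 4], threadsPerWarp = [2, 16], and order = [1, 0], each of the 16 threads along the fast axis owns 4 consecutive f32 values, giving one 128-bit access per thread and 256 contiguous bytes per warp row.

#include <cassert>

int main() {
  unsigned threadsAlongFastAxis = 16; // threadsPerWarp[order[0]] in row_layout
  unsigned elemsPerThread = 4;        // sizePerThread[order[0]] in row_layout
  unsigned bytesPerElem = 4;          // f32
  unsigned bytesPerThread = elemsPerThread * bytesPerElem;          // 16 bytes = one 128-bit access
  unsigned bytesPerWarpRow = threadsAlongFastAxis * bytesPerThread; // 256 contiguous bytes
  assert(bytesPerThread == 16 && bytesPerWarpRow == 256);
  return 0;
}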