[OPTIMIZER] Added memory coalescing pass (#31)
@@ -6,6 +6,8 @@
 namespace mlir {
 std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);

+std::unique_ptr<Pass> createTritonGPUCoalescePass();
+
 std::unique_ptr<Pass> createTritonGPUCombineOpsPass();

 std::unique_ptr<Pass> createTritonGPUVerifier();
@@ -23,6 +23,24 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
   ];
 }

+def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
+  let summary = "coalesce";
+
+  let description = [{
+    TODO
+  }];
+
+  let constructor = "mlir::createTritonGPUCoalescePass()";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
+
+  let options = [
+    Option<"numWarps", "num-warps",
+           "int32_t", /*default*/"4",
+           "number of warps">
+  ];
+}
+
 def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
   let summary = "combine triton gpu ops";
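
The tablegen entry above generates the pass class and registers the "tritongpu-coalesce" flag with its num-warps option; construction goes through the createTritonGPUCoalescePass() declared in the header hunk. A minimal sketch of scheduling the pass from C++, assuming a standard mlir::PassManager set-up (the pipeline function is illustrative, not part of this diff):

#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

// Hypothetical pipeline hook: run coalescing on a module that is already
// in the TritonGPU dialect.
void addCoalescePass(mlir::PassManager &pm) {
  pm.addPass(mlir::createTritonGPUCoalescePass());
}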
@@ -94,7 +94,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
   AxisInfo curr;
   // This preserves the input axes (e.g., cast):
   if (llvm::isa<arith::ExtSIOp, arith::ExtUIOp, arith::TruncIOp,
-                triton::PtrToIntOp, triton::IntToPtrOp>(op))
+                triton::PtrToIntOp, triton::IntToPtrOp,
+                triton::gpu::ConvertLayoutOp>(op))
     curr = operands[0]->getValue();
   // Constant ranges
   if (triton::MakeRangeOp make_range =
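
The new triton::gpu::ConvertLayoutOp case matters for the pass below: a layout conversion redistributes elements across threads but leaves the logical tensor untouched, so contiguity and divisibility should carry through it. A sketch of the downstream query this enables, using the AxisInfo API that appears in Coalesce.cpp below (the helper name is illustrative):

#include "triton/Analysis/AxisInfo.h"

using namespace mlir;
using namespace mlir::triton;

// Contiguity of `ptr` along dimension `dim`; with the hunk above this stays
// precise even when `ptr` is produced by a triton_gpu.convert_layout.
unsigned contiguityAlong(AxisInfoAnalysis &axisInfo, Value ptr, unsigned dim) {
  AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
  return info.getContiguity(dim);
}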
@@ -3,6 +3,7 @@ mlir_tablegen(TritonGPUCombine.inc -gen-rewriters)
 add_public_tablegen_target(TritonGPUCombineIncGen)

 add_mlir_dialect_library(TritonGPUTransforms
+  Coalesce.cpp
   Combine.cpp
   Pipeline.cpp
   Verifier.cpp
lib/Dialect/TritonGPU/Transforms/Coalesce.cpp (new file, 108 lines)
@@ -0,0 +1,108 @@
#include "triton/Analysis/AxisInfo.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <numeric>

using namespace mlir;
using namespace mlir::triton;

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {

  Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr) {
    auto origType = ptr.getType().cast<RankedTensorType>();
    // Get the rank of the tensor.
    size_t rank = origType.getRank();
    AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
    // Layout order: dimensions sorted in decreasing order of contiguity
    SmallVector<unsigned, 4> order(rank);
    std::iota(order.begin(), order.end(), 0);
    auto contiguity = info.getContiguity();
    std::sort(order.begin(), order.end(), [&](unsigned x, unsigned y) {
      return contiguity[x] > contiguity[y];
    });
    // Thread tile size depends on memory alignment
    SmallVector<unsigned, 4> sizePerThread(rank, 1);
    PointerType ptrType = origType.getElementType().cast<PointerType>();
    unsigned numBits = ptrType.getPointeeType().getIntOrFloatBitWidth();
    unsigned maxMultiple = info.getDivisibility(order[0]);
    unsigned maxContig = info.getContiguity(order[0]);
    unsigned alignment = std::min(maxMultiple, maxContig);
    unsigned perThread = std::min(alignment, 128 / numBits);
    sizePerThread[order[0]] = perThread;
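    // [editor's note] Worked example, not part of this commit: for f32
    // pointers numBits = 32, so at most 128 / numBits = 4 elements per
    // thread. With divisibility 16 and contiguity 64 along the fastest dim,
    // alignment = min(16, 64) = 16 and perThread = min(16, 4) = 4, i.e.
    // each thread issues one 128-bit (4 x f32) access along order[0].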
    // Create the coalesced encoding
    Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
        &getContext(), origType.getShape(), sizePerThread, order,
        this->numWarps);
    return encoding;
  }

  std::function<Type(Type)> getTypeConverter(AxisInfoAnalysis &axisInfo,
                                             Value ptr) {
    Attribute encoding = getCoalescedEncoding(axisInfo, ptr);
    return [encoding](Type _type) {
      RankedTensorType type = _type.cast<RankedTensorType>();
      return RankedTensorType::get(type.getShape(), type.getElementType(),
                                   encoding);
    };
  }
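
  // [editor's note] Illustrative use, not part of this commit:
  //   auto convert = getTypeConverter(axisInfo, ptr);
  //   Type coalesced = convert(ptr.getType());
  // The returned lambda rebinds only the encoding; shape and element type
  // are preserved.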

  template <class T>
  void coalesceOp(AxisInfoAnalysis &axisInfo, Operation *op, Value ptr,
                  OpBuilder builder) {
    RankedTensorType ty = ptr.getType().template dyn_cast<RankedTensorType>();
    if (!ty)
      return;
    AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
    auto convertType = getTypeConverter(axisInfo, ptr);
    // convert operands
    SmallVector<Value, 4> newArgs;
    for (auto v : op->getOperands())
      newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
          op->getLoc(), convertType(v.getType()), v));
    // convert output types
    SmallVector<Type, 4> newTypes;
    for (auto t : op->getResultTypes())
      newTypes.push_back(convertType(t));
    // construct new op with the new encoding
    Operation *newOp =
        builder.create<T>(op->getLoc(), newTypes, newArgs, op->getAttrs());
    // cast the results back to the original layout
    for (size_t i = 0; i < op->getNumResults(); i++) {
      auto newResult = builder.create<triton::gpu::ConvertLayoutOp>(
          op->getLoc(), op->getResult(i).getType(), newOp->getResult(i));
      op->getResult(i).replaceAllUsesWith(newResult);
    }
    op->erase();
  }
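
  // [editor's note] Net effect in the IR, sketched (cf. the FileCheck test
  // below; the %names and #L1/#L2 layouts are placeholders):
  //   %p = triton_gpu.convert_layout %ptr : (tensor<...,#L1>) -> tensor<...,#L2>
  //   %v = tt.load %p, %mask, %other {...} : tensor<...,#L2>
  //   %r = triton_gpu.convert_layout %v : (tensor<...,#L2>) -> tensor<...,#L1>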

  void runOnOperation() override {
    Operation *op = getOperation();
    // Run axis info analysis
    AxisInfoAnalysis axisInfo(&getContext());
    axisInfo.run(op);
    OpBuilder builder(op);

    // For each memory op that has a layout L1:
    // 1. Create a coalesced memory layout L2 of the pointer operands
    // 2. Convert all operands from layout L1 to layout L2
    // 3. Create a new memory op that consumes these operands and
    //    produces a tensor with layout L2
    // 4. Convert the output of this new memory op back to L1
    // 5. Replace all the uses of the original memory op by the new one
    op->walk([&](Operation *curr) {
      OpBuilder::InsertionGuard g(builder);
      builder.setInsertionPoint(curr);
      if (auto load = dyn_cast<triton::LoadOp>(curr))
        coalesceOp<triton::LoadOp>(axisInfo, curr, load.ptr(), builder);
      if (auto store = dyn_cast<triton::StoreOp>(curr))
        coalesceOp<triton::StoreOp>(axisInfo, curr, store.ptr(), builder);
    });
  }
};

std::unique_ptr<Pass> mlir::createTritonGPUCoalescePass() {
  return std::make_unique<CoalescePass>();
}
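
To exercise the pass in isolation, the FileCheck test below drives it through triton-opt; the same invocation works on any TritonGPU-dialect input (the file name here is illustrative):

triton-opt transpose.mlir -tritongpu-coalesce -canonicalize -tritongpu-verifier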

@@ -5,8 +5,8 @@ import triton.language as tl

 # triton kernel
 @triton.jit
-def kernel(X, stride_xm, stride_xn,
-           Z, stride_zm, stride_zn,
+def kernel(X, stride_xm,
+           Z, stride_zn,
            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     off_m = tl.arange(0, BLOCK_M)
     off_n = tl.arange(0, BLOCK_N)
@@ -15,5 +15,5 @@ def kernel(X, stride_xm, stride_xn,
     tl.store(Zs, tl.load(Xs))


-ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
+ret = triton.compile(kernel, "*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
 print(ret)
test/TritonGPU/coalesce.mlir (new file, 46 lines)
@@ -0,0 +1,46 @@
// RUN: triton-opt %s -split-input-file -tritongpu-coalesce -canonicalize -tritongpu-verifier | FileCheck %s

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>

// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
// CHECK: [[load_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[row_layout]]>
// CHECK: [[load_other:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[row_layout]]>
// CHECK: [[load_val:%.*]] = tt.load [[load_ptr]], [[load_mask]], [[load_other]] {{.*}} : tensor<64x64xf32, [[row_layout]]>
// CHECK: [[store_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
// CHECK: [[store_val:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]>
// CHECK: [[store_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]>
// CHECK: tt.store [[store_ptr]], [[store_val]], [[store_mask]]
func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
                %arg1: i32 {tt.divisibility = 16 : i32},
                %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
                %arg3: i32 {tt.divisibility = 16 : i32}) {
  %cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
  %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
  %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
  %1 = tt.view %0 : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1>
  %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
  %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
  %4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
  %5 = tt.getelementptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
  %6 = tt.view %0 : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
  %7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
  %8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
  %9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
  %10 = tt.getelementptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
  %11 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
  %12 = tt.getelementptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
  %13 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
  %14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2>
  %15 = tt.broadcast %12 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
  %16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
  %17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
  %18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
  %19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1>
  tt.store %18, %19, %cst : tensor<64x64xf32, #blocked1>
  return
}