[OPTIMIZER] Coalesce pass no longer takes a num-warps argument (#99)
Improved the design to avoid an inconsistent `num-warps` value between the pass and the parent module of the operation it processes.
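The pattern this change adopts: instead of each pass carrying its own num-warps option, the value is stored once as a `triton_gpu.num-warps` attribute on the module and looked up through a helper on the TritonGPU dialect, so a pass can never see a value that disagrees with the module it is rewriting. Below is a minimal C++ sketch of that lookup, assuming the MLIR headers are available; `getNumWarpsFromModule` is a hypothetical stand-in for the `TritonGPUDialect::getNumWarps` helper added in the diff.

#include "mlir/IR/BuiltinAttributes.h" // mlir::IntegerAttr
#include "mlir/IR/BuiltinOps.h"        // mlir::ModuleOp
#include "llvm/Support/ErrorHandling.h"

// Hypothetical free-function equivalent of TritonGPUDialect::getNumWarps.
static int getNumWarpsFromModule(mlir::ModuleOp mod) {
  // The front end is expected to tag the module, e.g.
  //   module attributes {"triton_gpu.num-warps" = 4 : i32} { ... }
  if (!mod->hasAttr("triton_gpu.num-warps"))
    llvm::report_fatal_error(
        "TritonGPU module should contain a triton_gpu.num-warps attribute");
  return mod->getAttr("triton_gpu.num-warps")
      .cast<mlir::IntegerAttr>()
      .getInt();
}

// Inside a pass or rewrite pattern, the value is then taken from the op's own
// parent module rather than from a per-pass option:
//   mlir::ModuleOp mod = op->getParentOfType<mlir::ModuleOp>();
//   int numWarps = getNumWarpsFromModule(mod);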
@@ -18,6 +18,18 @@ def TritonGPU_Dialect : Dialect {
     "triton::TritonDialect",
     "mlir::gpu::GPUDialect"
   ];
+
+  let extraClassDeclaration = [{
+    static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
+    static int getNumWarps(ModuleOp mod) {
+      if(!mod->hasAttr("triton_gpu.num-warps"))
+        llvm::report_fatal_error(
+            "TritonGPU module should contain a triton_gpu.num-warps attribute");
+      return mod->getAttr("triton_gpu.num-warps").cast<IntegerAttr>().getInt();
+    }
+  }];
+
 }
 
 #endif
@@ -33,12 +33,6 @@ def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
   let constructor = "mlir::createTritonGPUCoalescePass()";
 
   let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
-
-  let options = [
-    Option<"numWarps", "num-warps",
-           "int32_t", /*default*/"4",
-           "number of warps">
-  ];
 }
 
 def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
@@ -795,24 +795,6 @@ private:
   AxisInfoAnalysis &AxisAnalysisPass;
 };
 
-// Extract numWarps information from TritonGPU module, return 0 if failed.
-// This is a naive implementation, it assumes that all the blocked layout should
-// have the same numWarps setting in a module, it just find a blocked layout
-// encoding and return the warpsPerCTA field.
-int extractNumWarps(mlir::ModuleOp module) {
-  int numWarps{};
-  if (module->hasAttr(AttrNumWarpsName))
-    numWarps = module->getAttr(AttrNumWarpsName)
-                   .dyn_cast<IntegerAttr>()
-                   .getValue()
-                   .getZExtValue();
-  else
-    llvm::report_fatal_error(
-        "TritonGPU module should contain a triton_gpu.num-warps attribute");
-
-  return numWarps;
-}
-
 struct BroadcastOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::BroadcastOp> {
   using ConvertTritonGPUOpToLLVMPattern<
@@ -1319,7 +1301,7 @@ public:
 
     RewritePatternSet patterns(context);
 
-    int numWarps = extractNumWarps(mod);
+    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
 
     auto axisAnalysis = runAxisAnalysis(mod);
 
@@ -11,7 +11,8 @@ using namespace mlir::triton;
 
 struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
 
-  Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr) {
+  Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr,
+                                 int numWarps) {
     auto origType = ptr.getType().cast<RankedTensorType>();
     // Get the shape of the tensor.
     size_t rank = origType.getRank();
@@ -36,14 +37,13 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
     std::iota(dims.begin(), dims.end(), 0);
     // create encoding
     Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
-        &getContext(), origType.getShape(), sizePerThread, order,
-        this->numWarps);
+        &getContext(), origType.getShape(), sizePerThread, order, numWarps);
     return encoding;
   }
 
   std::function<Type(Type)> getTypeConverter(AxisInfoAnalysis &axisInfo,
-                                             Value ptr) {
-    Attribute encoding = getCoalescedEncoding(axisInfo, ptr);
+                                             Value ptr, int numWarps) {
+    Attribute encoding = getCoalescedEncoding(axisInfo, ptr, numWarps);
     return [encoding](Type _type) {
       RankedTensorType type = _type.cast<RankedTensorType>();
       return RankedTensorType::get(type.getShape(), type.getElementType(),
@@ -57,8 +57,11 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
     RankedTensorType ty = ptr.getType().template dyn_cast<RankedTensorType>();
     if (!ty)
       return;
+    auto mod = op->getParentOfType<ModuleOp>();
+    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
+
     AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
-    auto convertType = getTypeConverter(axisInfo, ptr);
+    auto convertType = getTypeConverter(axisInfo, ptr, numWarps);
     // convert operands
     SmallVector<Value, 4> newArgs;
     for (auto v : op->getOperands())
@@ -4,6 +4,8 @@
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
 
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+
 // CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
 // CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
@@ -44,3 +46,5 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
   tt.store %18, %19, %cst : tensor<64x64xf32, #blocked1>
   return
 }
+
+}