[Triton-MLIR] Add compute capability (#902)

Add the compute capability from the Python frontend to the backend.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Author: ben-zhang-609
Date: 2022-11-23 03:08:23 +08:00
Committed by: GitHub
Parent: 2afebcd79b
Commit: 07786dc932
7 changed files with 123 additions and 83 deletions
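The capability threaded through here is the usual integer encoding of the SM version (e.g. 80 for sm_80), matching the thresholds in `computeCapabilityToMMAVersion` below. As a rough sketch of where such a value can come from (this helper and its name are illustrative and not part of the change; the real query lives in the Python frontend):

```cpp
#include <cuda_runtime.h>

// Illustrative only: derive the integer compute capability the combine
// pass expects, e.g. a device reporting (8, 0) yields 80.
static int queryComputeCapability(int device) {
  cudaDeviceProp props;
  cudaGetDeviceProperties(&props, device);
  return props.major * 10 + props.minor;
}
```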


@@ -15,6 +15,7 @@
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include <memory>
@@ -560,11 +561,37 @@ public:
 // -----------------------------------------------------------------------------
 //
 // -----------------------------------------------------------------------------
+static int computeCapabilityToMMAVersion(int computeCapability) {
+  if (computeCapability < 80) {
+    return 1;
+  } else if (computeCapability < 90) {
+    return 2;
+  } else {
+    assert(false && "computeCapability > 90 not supported");
+    return 0;
+  }
+}
+
+static SmallVector<int64_t, 2>
+mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
+                         int numWarps) {
+  if (version == 1) {
+    return {16, 16};
+  } else if (version == 2) {
+    return {16, 8};
+  } else {
+    assert(false && "version not supported");
+    return {0, 0};
+  }
+}
+
 class BlockedToMMA : public mlir::RewritePattern {
+  int computeCapability;
+
 public:
-  BlockedToMMA(mlir::MLIRContext *context)
-      : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context) {}
+  BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
+      : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
+        computeCapability(computeCapability) {}
 
   static SmallVector<unsigned, 2>
   getWarpsPerTile(const ArrayRef<int64_t> &shape, int version, int numWarps) {
@@ -572,7 +599,8 @@ public:
     // TODO: Handle one warp per row for fused matmuls
     // TODO: unsigned -> int64_t to keep things uniform
     SmallVector<unsigned, 2> ret = {1, 1};
-    SmallVector<int64_t, 2> shapePerWarp = {16, 8};
+    SmallVector<int64_t, 2> shapePerWarp =
+        mmaVersionToShapePerWarp(version, shape, numWarps);
     bool changed = false;
     // TODO (@daadaada): double-check.
     // original logic in
@@ -615,11 +643,12 @@ public:
     auto retShape = oldRetType.getShape();
     auto mod = op->getParentOfType<mlir::ModuleOp>();
     int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
-    auto newRetType =
-        RankedTensorType::get(retShape, oldRetType.getElementType(),
-                              triton::gpu::MmaEncodingAttr::get(
-                                  oldRetType.getContext(), 2,
-                                  getWarpsPerTile(retShape, 2, numWarps)));
+    int version = computeCapabilityToMMAVersion(computeCapability);
+    auto newRetType = RankedTensorType::get(
+        retShape, oldRetType.getElementType(),
+        triton::gpu::MmaEncodingAttr::get(
+            oldRetType.getContext(), version,
+            getWarpsPerTile(retShape, version, numWarps)));
     // convert accumulator
     auto oldAcc = dotOp.getOperand(2);
     auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -656,6 +685,10 @@ public:
 class TritonGPUCombineOpsPass
     : public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
 public:
+  TritonGPUCombineOpsPass() = default;
+  TritonGPUCombineOpsPass(int computeCapability) {
+    this->computeCapability = computeCapability;
+  }
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp m = getOperation();
@@ -667,7 +700,7 @@ public:
     patterns.add<RematerializeBackward>(context);
     patterns.add<RematerializeForward>(context);
     patterns.add<MoveConvertOutOfLoop>(context);
-    patterns.add<BlockedToMMA>(context);
+    patterns.add<BlockedToMMA>(context, computeCapability);
 
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
       signalPassFailure();
@@ -675,6 +708,7 @@ public:
   }
 };
 
-std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
-  return std::make_unique<TritonGPUCombineOpsPass>();
+std::unique_ptr<Pass>
+mlir::createTritonGPUCombineOpsPass(int computeCapability) {
+  return std::make_unique<TritonGPUCombineOpsPass>(computeCapability);
 }
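With the new overload, whatever builds the TritonGPU pass pipeline can forward the capability explicitly. A minimal sketch of a caller (the wrapper function is illustrative; only `createTritonGPUCombineOpsPass(int)` comes from this change, and its declaration is assumed to live in the updated `Passes.h`):

```cpp
#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

// Illustrative only: register the combine pass with the capability
// reported by the frontend (e.g. 80 for sm_80).
void addCombinePass(mlir::PassManager &pm, int computeCapability) {
  pm.addPass(mlir::createTritonGPUCombineOpsPass(computeCapability));
}
```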