[Triton-MLIR] Add compute capability (#902)
Add compute capability from the Python frontend to the backend.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
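The capability travels as a single integer encoded as major * 10 + minor (so sm_80 arrives as 80), which is what the `< 80` / `< 90` comparisons in the diff below test against. As an illustrative sketch of where such a value can come from, and not the path Triton's Python frontend actually takes, the CUDA runtime can produce it directly:

```cpp
#include <cuda_runtime.h>
#include <cstdio>

// Illustrative only: query a device and encode its compute capability as
// major * 10 + minor, the integer form the pass below consumes.
// Error checking on the CUDA calls is omitted for brevity.
int getComputeCapability(int device) {
  int major = 0, minor = 0;
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
  return major * 10 + minor;
}

int main() {
  printf("compute capability: %d\n", getComputeCapability(0)); // e.g. 80 on A100
}
```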
@@ -15,6 +15,7 @@
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
 
 #include <memory>
 
@@ -560,11 +561,37 @@ public:
 // -----------------------------------------------------------------------------
 //
 // -----------------------------------------------------------------------------
+static int computeCapabilityToMMAVersion(int computeCapability) {
+  if (computeCapability < 80) {
+    return 1;
+  } else if (computeCapability < 90) {
+    return 2;
+  } else {
+    assert(false && "computeCapability > 90 not supported");
+    return 0;
+  }
+}
+
+static SmallVector<int64_t, 2>
+mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
+                         int numWarps) {
+  if (version == 1) {
+    return {16, 16};
+  } else if (version == 2) {
+    return {16, 8};
+  } else {
+    assert(false && "version not supported");
+    return {0, 0};
+  }
+}
 
 class BlockedToMMA : public mlir::RewritePattern {
+  int computeCapability;
+
 public:
-  BlockedToMMA(mlir::MLIRContext *context)
-      : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context) {}
+  BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
+      : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
+        computeCapability(computeCapability) {}
 
   static SmallVector<unsigned, 2>
   getWarpsPerTile(const ArrayRef<int64_t> &shape, int version, int numWarps) {
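The two new helpers pin down the policy: pre-sm_80 parts get MMA v1 with a 16x16 per-warp tile, sm_80..sm_89 get v2 with 16x8, and anything newer trips the assert. Note that mmaVersionToShapePerWarp accepts shape and numWarps but currently ignores both. A minimal self-check of that mapping, with SmallVector swapped for std::array and the unused parameters dropped so the sketch compiles without LLVM:

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// The helpers from the diff above, lightly adapted. Pre-sm_80 maps to
// MMA v1 (16x16 tiles per warp), sm_80..sm_89 to v2 (16x8); newer asserts.
static int computeCapabilityToMMAVersion(int computeCapability) {
  if (computeCapability < 80)
    return 1;
  if (computeCapability < 90)
    return 2;
  assert(false && "computeCapability > 90 not supported");
  return 0;
}

static std::array<int64_t, 2> mmaVersionToShapePerWarp(int version) {
  if (version == 1)
    return {16, 16};
  if (version == 2)
    return {16, 8};
  assert(false && "version not supported");
  return {0, 0};
}

int main() {
  assert(computeCapabilityToMMAVersion(70) == 1); // V100 -> MMA v1
  assert(computeCapabilityToMMAVersion(75) == 1); // T4   -> MMA v1
  assert(computeCapabilityToMMAVersion(80) == 2); // A100 -> MMA v2
  assert(computeCapabilityToMMAVersion(86) == 2); // A10  -> MMA v2
  assert((mmaVersionToShapePerWarp(1) == std::array<int64_t, 2>{16, 16}));
  assert((mmaVersionToShapePerWarp(2) == std::array<int64_t, 2>{16, 8}));
}
```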
@@ -572,7 +599,8 @@ public:
     // TODO: Handle one warp per row for fused matmuls
     // TODO: unsigned -> int64_t to keep things uniform
     SmallVector<unsigned, 2> ret = {1, 1};
-    SmallVector<int64_t, 2> shapePerWarp = {16, 8};
+    SmallVector<int64_t, 2> shapePerWarp =
+        mmaVersionToShapePerWarp(version, shape, numWarps);
     bool changed = false;
     // TODO (@daadaada): double-check.
     // original logic in
@@ -615,11 +643,12 @@ public:
     auto retShape = oldRetType.getShape();
     auto mod = op->getParentOfType<mlir::ModuleOp>();
     int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
-    auto newRetType =
-        RankedTensorType::get(retShape, oldRetType.getElementType(),
-                              triton::gpu::MmaEncodingAttr::get(
-                                  oldRetType.getContext(), 2,
-                                  getWarpsPerTile(retShape, 2, numWarps)));
+    int version = computeCapabilityToMMAVersion(computeCapability);
+    auto newRetType = RankedTensorType::get(
+        retShape, oldRetType.getElementType(),
+        triton::gpu::MmaEncodingAttr::get(
+            oldRetType.getContext(), version,
+            getWarpsPerTile(retShape, version, numWarps)));
     // convert accumulator
     auto oldAcc = dotOp.getOperand(2);
     auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -656,6 +685,10 @@ public:
 class TritonGPUCombineOpsPass
     : public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
 public:
+  TritonGPUCombineOpsPass() = default;
+  TritonGPUCombineOpsPass(int computeCapability) {
+    this->computeCapability = computeCapability;
+  }
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp m = getOperation();
@@ -667,7 +700,7 @@ public:
     patterns.add<RematerializeBackward>(context);
     patterns.add<RematerializeForward>(context);
     patterns.add<MoveConvertOutOfLoop>(context);
-    patterns.add<BlockedToMMA>(context);
+    patterns.add<BlockedToMMA>(context, computeCapability);
 
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
       signalPassFailure();
@@ -675,6 +708,7 @@ public:
   }
 };
 
-std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
-  return std::make_unique<TritonGPUCombineOpsPass>();
+std::unique_ptr<Pass>
+mlir::createTritonGPUCombineOpsPass(int computeCapability) {
+  return std::make_unique<TritonGPUCombineOpsPass>(computeCapability);
 }
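Callers now have to supply the capability when constructing the pass. A sketch of what pipeline setup might look like, assuming the new overload is declared in Passes.h alongside the pass; the helper name and the value 80 are hypothetical stand-ins for whatever the frontend reports:

```cpp
#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

// Illustrative: thread the frontend-reported compute capability into the
// combine pass when assembling the TritonGPU pipeline.
void addCombinePass(mlir::PassManager &pm, int computeCapability) {
  pm.addPass(mlir::createTritonGPUCombineOpsPass(computeCapability));
}
// Usage: addCombinePass(pm, 80); // sm_80 -> MMA v2 inside BlockedToMMA
```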