[Triton-MLIR] add_volta_warpsPerTile (#907)

This commit is contained in:
ben-zhang-609
2022-11-24 09:44:29 +08:00
committed by GitHub
parent 8925c2cd11
commit b688f7b7b8

View File

@@ -561,6 +561,7 @@ public:
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
// //
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
namespace {
static int computeCapabilityToMMAVersion(int computeCapability) { static int computeCapabilityToMMAVersion(int computeCapability) {
if (computeCapability < 80) { if (computeCapability < 80) {
return 1; return 1;
@@ -575,16 +576,68 @@ static int computeCapabilityToMMAVersion(int computeCapability) {
static SmallVector<int64_t, 2> static SmallVector<int64_t, 2>
mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape, mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
int numWarps) { int numWarps) {
if (version == 1) { if (version == 1)
return {16, 16}; return {16, 16};
} else if (version == 2) { else if (version == 2)
return {16, 8}; return {16, 8};
} else { else {
assert(false && "version not supported"); assert(false && "version not supported");
return {0, 0}; return {0, 0};
} }
} }
template <int version>
SmallVector<unsigned, 2> warpsPerTile(const ArrayRef<int64_t> shape,
int numWarps);
template <>
SmallVector<unsigned, 2> warpsPerTile<1>(const ArrayRef<int64_t> shape,
int numWarps) {
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp =
mmaVersionToShapePerWarp(1, shape, numWarps);
bool changed = false;
do {
changed = false;
if (ret[0] * ret[1] < numWarps) {
ret[0] = std::clamp<unsigned>(ret[0] * 2, 1, shape[0] / shapePerWarp[0]);
changed = true;
}
if (ret[0] * ret[1] < numWarps) {
ret[1] = std::clamp<unsigned>(ret[1] * 2, 1, shape[1] / shapePerWarp[1]);
changed = true;
}
} while (changed);
return ret;
}
template <>
SmallVector<unsigned, 2> warpsPerTile<2>(const ArrayRef<int64_t> shape,
int numWarps) {
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp =
mmaVersionToShapePerWarp(2, shape, numWarps);
// TODO (@daadaada): double-check.
// original logic in
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
// seems buggy for shape = [32, 16] ?
do {
if (ret[0] * ret[1] >= numWarps)
break;
if (shape[0] / shapePerWarp[0] / ret[0] >=
shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
if (ret[0] < shape[0] / shapePerWarp[0]) {
ret[0] *= 2;
} else
ret[1] *= 2;
} else {
ret[1] *= 2;
}
} while (true);
return ret;
}
} // namespace
class BlockedToMMA : public mlir::RewritePattern { class BlockedToMMA : public mlir::RewritePattern {
int computeCapability; int computeCapability;
@@ -593,34 +646,17 @@ public:
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context), : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
computeCapability(computeCapability) {} computeCapability(computeCapability) {}
static SmallVector<unsigned, 2> static SmallVector<unsigned, 2> getWarpsPerTile(const ArrayRef<int64_t> shape,
getWarpsPerTile(const ArrayRef<int64_t> &shape, int version, int numWarps) { int version, int numWarps) {
assert(version == 2); switch (version) {
// TODO: Handle one warp per row for fused matmuls case 1:
// TODO: unsigned -> int64_t to keep things uniform return warpsPerTile<1>(shape, numWarps);
SmallVector<unsigned, 2> ret = {1, 1}; case 2:
SmallVector<int64_t, 2> shapePerWarp = return warpsPerTile<2>(shape, numWarps);
mmaVersionToShapePerWarp(version, shape, numWarps); default:
bool changed = false; assert(false && "not supported version");
// TODO (@daadaada): double-check. return {0, 0};
// original logic in }
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
// seems buggy for shape = [32, 16] ?
do {
changed = false;
if (ret[0] * ret[1] >= numWarps)
break;
if (shape[0] / shapePerWarp[0] / ret[0] >=
shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
if (ret[0] < shape[0] / shapePerWarp[0]) {
ret[0] *= 2;
} else
ret[1] *= 2;
} else {
ret[1] *= 2;
}
} while (true);
return ret;
} }
mlir::LogicalResult mlir::LogicalResult