[Triton-MLIR] add_volta_warpsPerTile (#907)

This commit is contained in:
ben-zhang-609
2022-11-24 09:44:29 +08:00
committed by GitHub
parent 8925c2cd11
commit b688f7b7b8

View File

@@ -561,6 +561,7 @@ public:
// -----------------------------------------------------------------------------
//
// -----------------------------------------------------------------------------
namespace {
static int computeCapabilityToMMAVersion(int computeCapability) {
if (computeCapability < 80) {
return 1;
@@ -575,16 +576,68 @@ static int computeCapabilityToMMAVersion(int computeCapability) {
// Maps an MMA version to a two-element per-warp shape:
//   version 1 -> {16, 16}
//   version 2 -> {16, 8}
// Any other version trips the assertion and yields {0, 0}.
// `shape` and `numWarps` are accepted but not read in the visible body.
// NOTE(review): this listing is a diff rendered without +/- markers, so each
// branch header appears twice (the braced pre-change form followed by the
// unbraced post-change form).
static SmallVector<int64_t, 2>
mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
int numWarps) {
if (version == 1) {
if (version == 1)
return {16, 16};
} else if (version == 2) {
else if (version == 2)
return {16, 8};
} else {
else {
assert(false && "version not supported");
return {0, 0};
}
}
// Primary template: returns a two-element warp count (one entry per tile
// dimension) for the MMA `version` given as the template argument.
// Only declared here; the visible code provides explicit specializations for
// version 1 and version 2.
template <int version>
SmallVector<unsigned, 2> warpsPerTile(const ArrayRef<int64_t> shape,
int numWarps);
// Computes how many warps are placed along each of the two tile dimensions
// for MMA version 1.
//
// Starting from {1, 1}, dimension 0 and then dimension 1 are doubled in
// alternation while the product of the two entries is still below `numWarps`.
// Each entry is capped at shape[d] / shapePerWarp[d], and never below 1.
//
// The loop ends as soon as a full pass leaves both entries unchanged. As a
// consequence the returned product can be smaller than `numWarps` when
// `shape` is too small to host that many warps; previously that situation
// made the loop spin forever.
template <>
SmallVector<unsigned, 2> warpsPerTile<1>(const ArrayRef<int64_t> shape,
                                         int numWarps) {
  SmallVector<unsigned, 2> ret = {1, 1};
  SmallVector<int64_t, 2> shapePerWarp =
      mmaVersionToShapePerWarp(1, shape, numWarps);
  // Per-dimension upper bound. It is forced to be at least 1 because
  // std::clamp has undefined behaviour when its lower bound exceeds the upper.
  const unsigned maxWarps[2] = {
      std::max<unsigned>(1, static_cast<unsigned>(shape[0] / shapePerWarp[0])),
      std::max<unsigned>(1, static_cast<unsigned>(shape[1] / shapePerWarp[1]))};
  bool changed = false;
  do {
    changed = false;
    for (int dim = 0; dim < 2; ++dim) {
      if (ret[0] * ret[1] >= numWarps)
        break;
      const unsigned grown =
          std::clamp<unsigned>(ret[dim] * 2, 1, maxWarps[dim]);
      // Only a real change justifies another pass; a saturated dimension must
      // not keep the loop alive.
      changed |= (grown != ret[dim]);
      ret[dim] = grown;
    }
  } while (changed);
  return ret;
}
// Computes how many warps are placed along each of the two tile dimensions
// for MMA version 2.
//
// Starting from {1, 1}, exactly one entry is doubled per step until the
// product reaches `numWarps`. Dimension 0 is doubled only when its remaining
// extent is at least that of dimension 1 (which is measured against twice the
// per-warp width) and it still has room to grow; otherwise dimension 1 is
// doubled.
template <>
SmallVector<unsigned, 2> warpsPerTile<2>(const ArrayRef<int64_t> shape,
                                         int numWarps) {
  SmallVector<unsigned, 2> warps = {1, 1};
  SmallVector<int64_t, 2> perWarp =
      mmaVersionToShapePerWarp(2, shape, numWarps);
  // TODO (@daadaada): double-check.
  // original logic in
  // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
  // seems buggy for shape = [32, 16] ?
  while (warps[0] * warps[1] < numWarps) {
    const bool rowsDominate = shape[0] / perWarp[0] / warps[0] >=
                              shape[1] / (perWarp[1] * 2) / warps[1];
    const bool growRows = rowsDominate && warps[0] < shape[0] / perWarp[0];
    if (growRows)
      warps[0] *= 2;
    else
      warps[1] *= 2;
  }
  return warps;
}
} // namespace
class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
@@ -593,34 +646,17 @@ public:
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
computeCapability(computeCapability) {}
static SmallVector<unsigned, 2>
getWarpsPerTile(const ArrayRef<int64_t> &shape, int version, int numWarps) {
assert(version == 2);
// TODO: Handle one warp per row for fused matmuls
// TODO: unsigned -> int64_t to keep things uniform
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp =
mmaVersionToShapePerWarp(version, shape, numWarps);
bool changed = false;
// TODO (@daadaada): double-check.
// original logic in
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
// seems buggy for shape = [32, 16] ?
do {
changed = false;
if (ret[0] * ret[1] >= numWarps)
break;
if (shape[0] / shapePerWarp[0] / ret[0] >=
shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
if (ret[0] < shape[0] / shapePerWarp[0]) {
ret[0] *= 2;
} else
ret[1] *= 2;
} else {
ret[1] *= 2;
}
} while (true);
return ret;
// Forwards to the warpsPerTile specialization that matches `version`.
// Versions other than 1 and 2 trip the assertion and yield {0, 0}.
static SmallVector<unsigned, 2> getWarpsPerTile(const ArrayRef<int64_t> shape,
                                                int version, int numWarps) {
  if (version == 1)
    return warpsPerTile<1>(shape, numWarps);
  if (version == 2)
    return warpsPerTile<2>(shape, numWarps);
  assert(false && "not supported version");
  return {0, 0};
}
mlir::LogicalResult