[Triton-MLIR] add_volta_warpsPerTile (#907)
This commit is contained in:
@@ -561,6 +561,7 @@ public:
|
||||
// -----------------------------------------------------------------------------
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
namespace {
|
||||
static int computeCapabilityToMMAVersion(int computeCapability) {
|
||||
if (computeCapability < 80) {
|
||||
return 1;
|
||||
@@ -575,16 +576,68 @@ static int computeCapabilityToMMAVersion(int computeCapability) {
|
||||
static SmallVector<int64_t, 2>
|
||||
mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
|
||||
int numWarps) {
|
||||
if (version == 1) {
|
||||
if (version == 1)
|
||||
return {16, 16};
|
||||
} else if (version == 2) {
|
||||
else if (version == 2)
|
||||
return {16, 8};
|
||||
} else {
|
||||
else {
|
||||
assert(false && "version not supported");
|
||||
return {0, 0};
|
||||
}
|
||||
}
|
||||
|
||||
template <int version>
|
||||
SmallVector<unsigned, 2> warpsPerTile(const ArrayRef<int64_t> shape,
|
||||
int numWarps);
|
||||
|
||||
template <>
|
||||
SmallVector<unsigned, 2> warpsPerTile<1>(const ArrayRef<int64_t> shape,
|
||||
int numWarps) {
|
||||
SmallVector<unsigned, 2> ret = {1, 1};
|
||||
SmallVector<int64_t, 2> shapePerWarp =
|
||||
mmaVersionToShapePerWarp(1, shape, numWarps);
|
||||
bool changed = false;
|
||||
do {
|
||||
changed = false;
|
||||
if (ret[0] * ret[1] < numWarps) {
|
||||
ret[0] = std::clamp<unsigned>(ret[0] * 2, 1, shape[0] / shapePerWarp[0]);
|
||||
changed = true;
|
||||
}
|
||||
if (ret[0] * ret[1] < numWarps) {
|
||||
ret[1] = std::clamp<unsigned>(ret[1] * 2, 1, shape[1] / shapePerWarp[1]);
|
||||
changed = true;
|
||||
}
|
||||
} while (changed);
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <>
|
||||
SmallVector<unsigned, 2> warpsPerTile<2>(const ArrayRef<int64_t> shape,
|
||||
int numWarps) {
|
||||
SmallVector<unsigned, 2> ret = {1, 1};
|
||||
SmallVector<int64_t, 2> shapePerWarp =
|
||||
mmaVersionToShapePerWarp(2, shape, numWarps);
|
||||
// TODO (@daadaada): double-check.
|
||||
// original logic in
|
||||
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
|
||||
// seems buggy for shape = [32, 16] ?
|
||||
do {
|
||||
if (ret[0] * ret[1] >= numWarps)
|
||||
break;
|
||||
if (shape[0] / shapePerWarp[0] / ret[0] >=
|
||||
shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
|
||||
if (ret[0] < shape[0] / shapePerWarp[0]) {
|
||||
ret[0] *= 2;
|
||||
} else
|
||||
ret[1] *= 2;
|
||||
} else {
|
||||
ret[1] *= 2;
|
||||
}
|
||||
} while (true);
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
class BlockedToMMA : public mlir::RewritePattern {
|
||||
int computeCapability;
|
||||
|
||||
@@ -593,34 +646,17 @@ public:
|
||||
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
|
||||
computeCapability(computeCapability) {}
|
||||
|
||||
static SmallVector<unsigned, 2>
|
||||
getWarpsPerTile(const ArrayRef<int64_t> &shape, int version, int numWarps) {
|
||||
assert(version == 2);
|
||||
// TODO: Handle one warp per row for fused matmuls
|
||||
// TODO: unsigned -> int64_t to keep things uniform
|
||||
SmallVector<unsigned, 2> ret = {1, 1};
|
||||
SmallVector<int64_t, 2> shapePerWarp =
|
||||
mmaVersionToShapePerWarp(version, shape, numWarps);
|
||||
bool changed = false;
|
||||
// TODO (@daadaada): double-check.
|
||||
// original logic in
|
||||
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
|
||||
// seems buggy for shape = [32, 16] ?
|
||||
do {
|
||||
changed = false;
|
||||
if (ret[0] * ret[1] >= numWarps)
|
||||
break;
|
||||
if (shape[0] / shapePerWarp[0] / ret[0] >=
|
||||
shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
|
||||
if (ret[0] < shape[0] / shapePerWarp[0]) {
|
||||
ret[0] *= 2;
|
||||
} else
|
||||
ret[1] *= 2;
|
||||
} else {
|
||||
ret[1] *= 2;
|
||||
}
|
||||
} while (true);
|
||||
return ret;
|
||||
static SmallVector<unsigned, 2> getWarpsPerTile(const ArrayRef<int64_t> shape,
|
||||
int version, int numWarps) {
|
||||
switch (version) {
|
||||
case 1:
|
||||
return warpsPerTile<1>(shape, numWarps);
|
||||
case 2:
|
||||
return warpsPerTile<2>(shape, numWarps);
|
||||
default:
|
||||
assert(false && "not supported version");
|
||||
return {0, 0};
|
||||
}
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
|
Reference in New Issue
Block a user