[Triton-MLIR] add_volta_warpsPerTile (#907)
This commit is contained in:
@@ -561,6 +561,7 @@ public:
|
|||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
//
|
//
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
namespace {
|
||||||
static int computeCapabilityToMMAVersion(int computeCapability) {
|
static int computeCapabilityToMMAVersion(int computeCapability) {
|
||||||
if (computeCapability < 80) {
|
if (computeCapability < 80) {
|
||||||
return 1;
|
return 1;
|
||||||
@@ -575,39 +576,52 @@ static int computeCapabilityToMMAVersion(int computeCapability) {
|
|||||||
static SmallVector<int64_t, 2>
|
static SmallVector<int64_t, 2>
|
||||||
mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
|
mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
|
||||||
int numWarps) {
|
int numWarps) {
|
||||||
if (version == 1) {
|
if (version == 1)
|
||||||
return {16, 16};
|
return {16, 16};
|
||||||
} else if (version == 2) {
|
else if (version == 2)
|
||||||
return {16, 8};
|
return {16, 8};
|
||||||
} else {
|
else {
|
||||||
assert(false && "version not supported");
|
assert(false && "version not supported");
|
||||||
return {0, 0};
|
return {0, 0};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class BlockedToMMA : public mlir::RewritePattern {
|
template <int version>
|
||||||
int computeCapability;
|
SmallVector<unsigned, 2> warpsPerTile(const ArrayRef<int64_t> shape,
|
||||||
|
int numWarps);
|
||||||
|
|
||||||
public:
|
template <>
|
||||||
BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
|
SmallVector<unsigned, 2> warpsPerTile<1>(const ArrayRef<int64_t> shape,
|
||||||
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
|
int numWarps) {
|
||||||
computeCapability(computeCapability) {}
|
|
||||||
|
|
||||||
static SmallVector<unsigned, 2>
|
|
||||||
getWarpsPerTile(const ArrayRef<int64_t> &shape, int version, int numWarps) {
|
|
||||||
assert(version == 2);
|
|
||||||
// TODO: Handle one warp per row for fused matmuls
|
|
||||||
// TODO: unsigned -> int64_t to keep things uniform
|
|
||||||
SmallVector<unsigned, 2> ret = {1, 1};
|
SmallVector<unsigned, 2> ret = {1, 1};
|
||||||
SmallVector<int64_t, 2> shapePerWarp =
|
SmallVector<int64_t, 2> shapePerWarp =
|
||||||
mmaVersionToShapePerWarp(version, shape, numWarps);
|
mmaVersionToShapePerWarp(1, shape, numWarps);
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
|
do {
|
||||||
|
changed = false;
|
||||||
|
if (ret[0] * ret[1] < numWarps) {
|
||||||
|
ret[0] = std::clamp<unsigned>(ret[0] * 2, 1, shape[0] / shapePerWarp[0]);
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
if (ret[0] * ret[1] < numWarps) {
|
||||||
|
ret[1] = std::clamp<unsigned>(ret[1] * 2, 1, shape[1] / shapePerWarp[1]);
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
} while (changed);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
SmallVector<unsigned, 2> warpsPerTile<2>(const ArrayRef<int64_t> shape,
|
||||||
|
int numWarps) {
|
||||||
|
SmallVector<unsigned, 2> ret = {1, 1};
|
||||||
|
SmallVector<int64_t, 2> shapePerWarp =
|
||||||
|
mmaVersionToShapePerWarp(2, shape, numWarps);
|
||||||
// TODO (@daadaada): double-check.
|
// TODO (@daadaada): double-check.
|
||||||
// original logic in
|
// original logic in
|
||||||
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
|
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
|
||||||
// seems buggy for shape = [32, 16] ?
|
// seems buggy for shape = [32, 16] ?
|
||||||
do {
|
do {
|
||||||
changed = false;
|
|
||||||
if (ret[0] * ret[1] >= numWarps)
|
if (ret[0] * ret[1] >= numWarps)
|
||||||
break;
|
break;
|
||||||
if (shape[0] / shapePerWarp[0] / ret[0] >=
|
if (shape[0] / shapePerWarp[0] / ret[0] >=
|
||||||
@@ -621,6 +635,28 @@ public:
|
|||||||
}
|
}
|
||||||
} while (true);
|
} while (true);
|
||||||
return ret;
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
class BlockedToMMA : public mlir::RewritePattern {
|
||||||
|
int computeCapability;
|
||||||
|
|
||||||
|
public:
|
||||||
|
BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
|
||||||
|
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
|
||||||
|
computeCapability(computeCapability) {}
|
||||||
|
|
||||||
|
static SmallVector<unsigned, 2> getWarpsPerTile(const ArrayRef<int64_t> shape,
|
||||||
|
int version, int numWarps) {
|
||||||
|
switch (version) {
|
||||||
|
case 1:
|
||||||
|
return warpsPerTile<1>(shape, numWarps);
|
||||||
|
case 2:
|
||||||
|
return warpsPerTile<2>(shape, numWarps);
|
||||||
|
default:
|
||||||
|
assert(false && "not supported version");
|
||||||
|
return {0, 0};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::LogicalResult
|
mlir::LogicalResult
|
||||||
|
Reference in New Issue
Block a user