// Review preamble (text before the `diff --git` header is not part of the patch).
//
// Unified diff with three hunks against
// lib/Dialect/TritonGPU/Transforms/Combine.cpp:
//
//   Hunk 1 (@@ -561,6 +561,7 @@): opens an anonymous `namespace {` just before
//     computeCapabilityToMMAVersion.
//   Hunk 2 (@@ -575,16 +576,68 @@): removes the braces around the first two
//     branches of the if/else chain in mmaVersionToShapePerWarp (returned values
//     are unchanged), then adds a `warpsPerTile` template declaration, explicit
//     specializations `warpsPerTile<1>` and `warpsPerTile<2>`, and the closing
//     `} // namespace`.
//   Hunk 3 (@@ -593,34 +646,17 @@): replaces the body of
//     BlockedToMMA::getWarpsPerTile. The removed body asserted `version == 2`
//     and ran the loop inline; the added body switches on `version`, forwards
//     to `warpsPerTile<1>` / `warpsPerTile<2>`, and in the default case asserts
//     false and returns {0, 0}. The `shape` parameter changes from
//     `const ArrayRef &` to `const ArrayRef` (by value).
//
// `warpsPerTile<2>` carries the same loop as the removed getWarpsPerTile body,
// minus the `changed` flag that the old loop assigned but never tested.
//
// NOTE(review): in `warpsPerTile<1>`, `changed` is set to true whenever
//   `ret[0] * ret[1] < numWarps`, even if std::clamp leaves `ret` unchanged.
//   If the shape caps both dimensions before the product reaches numWarps
//   (e.g. shape {16, 16} with the v1 per-warp shape {16, 16} and numWarps > 1),
//   the `do ... while (changed)` loop looks like it never exits. Suggest
//   setting `changed` only when the clamped value differs from the previous
//   one -- TODO confirm against callers' shape/numWarps ranges.
// NOTE(review): `shape[i] / shapePerWarp[i]` is the upper bound passed to
//   std::clamp with lower bound 1; if a dimension is smaller than the per-warp
//   shape the bound would be 0, which violates std::clamp's lo <= hi
//   precondition -- TODO confirm such shapes cannot reach this code.
// NOTE(review): this copy of the patch appears to have lost its line breaks
//   and its angle-bracket template lists (bare `template`, `SmallVector`,
//   `ArrayRef`, `std::clamp`); it presumably will not apply as-is. Regenerate
//   it from the repository before use.
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index fb6124d52..c9723f268 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -561,6 +561,7 @@ public: // ----------------------------------------------------------------------------- // // ----------------------------------------------------------------------------- +namespace { static int computeCapabilityToMMAVersion(int computeCapability) { if (computeCapability < 80) { return 1; @@ -575,16 +576,68 @@ static int computeCapabilityToMMAVersion(int computeCapability) { static SmallVector mmaVersionToShapePerWarp(int version, const ArrayRef &shape, int numWarps) { - if (version == 1) { + if (version == 1) return {16, 16}; - } else if (version == 2) { + else if (version == 2) return {16, 8}; - } else { + else { assert(false && "version not supported"); return {0, 0}; } } +template +SmallVector warpsPerTile(const ArrayRef shape, + int numWarps); + +template <> +SmallVector warpsPerTile<1>(const ArrayRef shape, + int numWarps) { + SmallVector ret = {1, 1}; + SmallVector shapePerWarp = + mmaVersionToShapePerWarp(1, shape, numWarps); + bool changed = false; + do { + changed = false; + if (ret[0] * ret[1] < numWarps) { + ret[0] = std::clamp(ret[0] * 2, 1, shape[0] / shapePerWarp[0]); + changed = true; + } + if (ret[0] * ret[1] < numWarps) { + ret[1] = std::clamp(ret[1] * 2, 1, shape[1] / shapePerWarp[1]); + changed = true; + } + } while (changed); + return ret; +} + +template <> +SmallVector warpsPerTile<2>(const ArrayRef shape, + int numWarps) { + SmallVector ret = {1, 1}; + SmallVector shapePerWarp = + mmaVersionToShapePerWarp(2, shape, numWarps); + // TODO (@daadaada): double-check. + // original logic in + // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252 + // seems buggy for shape = [32, 16] ? 
+ do { + if (ret[0] * ret[1] >= numWarps) + break; + if (shape[0] / shapePerWarp[0] / ret[0] >= + shape[1] / (shapePerWarp[1] * 2) / ret[1]) { + if (ret[0] < shape[0] / shapePerWarp[0]) { + ret[0] *= 2; + } else + ret[1] *= 2; + } else { + ret[1] *= 2; + } + } while (true); + return ret; +} + +} // namespace class BlockedToMMA : public mlir::RewritePattern { int computeCapability; @@ -593,34 +646,17 @@ public: : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context), computeCapability(computeCapability) {} - static SmallVector - getWarpsPerTile(const ArrayRef &shape, int version, int numWarps) { - assert(version == 2); - // TODO: Handle one warp per row for fused matmuls - // TODO: unsigned -> int64_t to keep things uniform - SmallVector ret = {1, 1}; - SmallVector shapePerWarp = - mmaVersionToShapePerWarp(version, shape, numWarps); - bool changed = false; - // TODO (@daadaada): double-check. - // original logic in - // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252 - // seems buggy for shape = [32, 16] ? - do { - changed = false; - if (ret[0] * ret[1] >= numWarps) - break; - if (shape[0] / shapePerWarp[0] / ret[0] >= - shape[1] / (shapePerWarp[1] * 2) / ret[1]) { - if (ret[0] < shape[0] / shapePerWarp[0]) { - ret[0] *= 2; - } else - ret[1] *= 2; - } else { - ret[1] *= 2; - } - } while (true); - return ret; + static SmallVector getWarpsPerTile(const ArrayRef shape, + int version, int numWarps) { + switch (version) { + case 1: + return warpsPerTile<1>(shape, numWarps); + case 2: + return warpsPerTile<2>(shape, numWarps); + default: + assert(false && "not supported version"); + return {0, 0}; + } } mlir::LogicalResult