[OPTIMIZER] Minor bugfixes that affected matmul codegen performance (#834)

This commit is contained in:
Philippe Tillet
2022-11-02 22:58:09 -07:00
committed by GitHub
parent 847a318a03
commit 91a9773b38
4 changed files with 59 additions and 29 deletions

View File

@@ -159,6 +159,16 @@ ChangeResult AxisInfoAnalysis::visitOperation(
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(), curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy); newContiguity, newDivisibility, newConstancy);
} }
// TODO: All other binary ops
if (llvm::isa<arith::AndIOp, arith::OrIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
// Splat // Splat
if (llvm::isa<triton::SplatOp>(op)) { if (llvm::isa<triton::SplatOp>(op)) {
Type _retTy = *op->result_type_begin(); Type _retTy = *op->result_type_begin();
@@ -200,7 +210,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
for (int d = 0; d < retTy.getRank(); ++d) { for (int d = 0; d < retTy.getRank(); ++d) {
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d)); contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
divisibility.push_back(opInfo.getDivisibility(d)); divisibility.push_back(opInfo.getDivisibility(d));
constancy.push_back(opShape[d] == 1 ? retShape[d] : 1); constancy.push_back(opShape[d] == 1 ? retShape[d]
: opInfo.getConstancy(d));
} }
curr = AxisInfo(contiguity, divisibility, constancy); curr = AxisInfo(contiguity, divisibility, constancy);
} }

View File

@@ -693,7 +693,8 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
TypeConverter *typeConverter, TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc) { ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = resType.cast<RankedTensorType>(); auto tensorTy = resType.cast<RankedTensorType>();
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>()) { if (tensorTy.getEncoding().isa<BlockedEncodingAttr>() ||
tensorTy.getEncoding().isa<SliceEncodingAttr>()) {
auto tensorTy = resType.cast<RankedTensorType>(); auto tensorTy = resType.cast<RankedTensorType>();
auto srcType = typeConverter->convertType(elemType); auto srcType = typeConverter->convertType(elemType);
auto llSrc = bitcast(constVal, srcType); auto llSrc = bitcast(constVal, srcType);

View File

@@ -533,6 +533,35 @@ public:
BlockedToMMA(mlir::MLIRContext *context) BlockedToMMA(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context) {} : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context) {}
static SmallVector<unsigned, 2>
getWarpsPerTile(const ArrayRef<int64_t> &shape, int version, int numWarps) {
assert(version == 2);
// TODO: Handle one warp per row for fused matmuls
// TODO: unsigned -> int64_t to keep things uniform
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
bool changed = false;
// TODO (@daadaada): double-check.
// original logic in
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
// seems buggy for shape = [32, 16] ?
do {
changed = false;
if (ret[0] * ret[1] >= numWarps)
break;
if (shape[0] / shapePerWarp[0] / ret[0] >=
shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
if (ret[0] < shape[0] / shapePerWarp[0]) {
ret[0] *= 2;
} else
ret[1] *= 2;
} else {
ret[1] *= 2;
}
} while (true);
return ret;
}
mlir::LogicalResult mlir::LogicalResult
matchAndRewrite(mlir::Operation *op, matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override { mlir::PatternRewriter &rewriter) const override {
@@ -541,13 +570,20 @@ public:
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>(); auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>()) if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return failure(); return failure();
// TODO: compute warpsPerCTA // get MMA encoding for the given number of warps
auto newRetType = RankedTensorType::get( auto retShape = oldRetType.getShape();
oldRetType.getShape(), oldRetType.getElementType(), auto mod = op->getParentOfType<mlir::ModuleOp>();
triton::gpu::MmaEncodingAttr::get(oldRetType.getContext(), 2, {2, 2})); int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
auto newRetType =
RankedTensorType::get(retShape, oldRetType.getElementType(),
triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), 2,
getWarpsPerTile(retShape, 2, numWarps)));
// convert accumulator
auto oldAcc = dotOp.getOperand(2); auto oldAcc = dotOp.getOperand(2);
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>( auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
oldAcc.getLoc(), newRetType, oldAcc); oldAcc.getLoc(), newRetType, oldAcc);
// convert output
auto newDot = rewriter.create<triton::DotOp>( auto newDot = rewriter.create<triton::DotOp>(
dotOp.getLoc(), newRetType, dotOp.getOperand(0), dotOp.getOperand(1), dotOp.getLoc(), newRetType, dotOp.getOperand(0), dotOp.getOperand(1),
newAcc, dotOp.allowTF32(), dotOp.transA(), dotOp.transB()); newAcc, dotOp.allowTF32(), dotOp.transA(), dotOp.transB());

View File

@@ -157,15 +157,6 @@ import triton.language as tl
@triton.autotune( @triton.autotune(
configs=[ configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
], ],
key=['M', 'N', 'K'], key=['M', 'N', 'K'],
) )
@@ -318,13 +309,13 @@ else:
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[ x_vals=[
128 * i for i in range(2, 33) 8192
], # different possible values for `x_name` ], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot line_arg='provider', # argument name whose value corresponds to a different line in the plot
# possible values for `line_arg`` # possible values for `line_arg``
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'], line_vals=['cublas', 'triton'],
# label name for the lines # label name for the lines
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"], line_names=["cuBLAS", "Triton"],
# line styles # line styles
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
ylabel="TFLOPS", # label name for the y-axis ylabel="TFLOPS", # label name for the y-axis
@@ -336,18 +327,9 @@ def benchmark(M, N, K, provider):
a = torch.randn((M, K), device='cuda', dtype=torch.float16) a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16) b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if provider == 'cublas': if provider == 'cublas':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b)) ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100)
if provider == 'triton': if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b)) ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100)
if provider == 'cublas + relu':
torch_relu = torch.nn.ReLU(inplace=True)
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch_relu(torch.matmul(a, b))
)
if provider == 'triton + relu':
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: matmul(a, b, activation=leaky_relu)
)
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms) return perf(ms), perf(max_ms), perf(min_ms)