[OPTIMIZER] Minor bugfixes that affected matmul codegen performance (#834)
This commit is contained in:
@@ -159,6 +159,16 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
|||||||
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||||
newContiguity, newDivisibility, newConstancy);
|
newContiguity, newDivisibility, newConstancy);
|
||||||
}
|
}
|
||||||
|
// TODO: All other binary ops
|
||||||
|
if (llvm::isa<arith::AndIOp, arith::OrIOp>(op)) {
|
||||||
|
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
|
||||||
|
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
|
||||||
|
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||||
|
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
|
||||||
|
};
|
||||||
|
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||||
|
newContiguity, newDivisibility, newConstancy);
|
||||||
|
}
|
||||||
// Splat
|
// Splat
|
||||||
if (llvm::isa<triton::SplatOp>(op)) {
|
if (llvm::isa<triton::SplatOp>(op)) {
|
||||||
Type _retTy = *op->result_type_begin();
|
Type _retTy = *op->result_type_begin();
|
||||||
@@ -200,7 +210,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
|||||||
for (int d = 0; d < retTy.getRank(); ++d) {
|
for (int d = 0; d < retTy.getRank(); ++d) {
|
||||||
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
|
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
|
||||||
divisibility.push_back(opInfo.getDivisibility(d));
|
divisibility.push_back(opInfo.getDivisibility(d));
|
||||||
constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
|
constancy.push_back(opShape[d] == 1 ? retShape[d]
|
||||||
|
: opInfo.getConstancy(d));
|
||||||
}
|
}
|
||||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||||
}
|
}
|
||||||
|
@@ -693,7 +693,8 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
|
|||||||
TypeConverter *typeConverter,
|
TypeConverter *typeConverter,
|
||||||
ConversionPatternRewriter &rewriter, Location loc) {
|
ConversionPatternRewriter &rewriter, Location loc) {
|
||||||
auto tensorTy = resType.cast<RankedTensorType>();
|
auto tensorTy = resType.cast<RankedTensorType>();
|
||||||
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
|
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>() ||
|
||||||
|
tensorTy.getEncoding().isa<SliceEncodingAttr>()) {
|
||||||
auto tensorTy = resType.cast<RankedTensorType>();
|
auto tensorTy = resType.cast<RankedTensorType>();
|
||||||
auto srcType = typeConverter->convertType(elemType);
|
auto srcType = typeConverter->convertType(elemType);
|
||||||
auto llSrc = bitcast(constVal, srcType);
|
auto llSrc = bitcast(constVal, srcType);
|
||||||
|
@@ -533,6 +533,35 @@ public:
|
|||||||
BlockedToMMA(mlir::MLIRContext *context)
|
BlockedToMMA(mlir::MLIRContext *context)
|
||||||
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context) {}
|
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context) {}
|
||||||
|
|
||||||
|
/// Chooses how many warps are laid out along each of the two dimensions of a
/// dot result of the given shape.
///
/// Starting from one warp per dimension, the warp count along exactly one
/// dimension is doubled per step until the product of the two counts is at
/// least `numWarps`. The row count (index 0) is doubled only while both hold:
///   * the number of row tiles left per warp is at least the number of column
///     tiles left per warp, and
///   * there are still more row tiles than warps along the rows;
/// otherwise the column count (index 1) is doubled.
///
/// \param shape    Shape of the dot result; only shape[0] and shape[1] are
///                 read.
/// \param version  MMA encoding version; only version 2 is handled (asserted).
/// \param numWarps Number of warps to distribute over the tile.
/// \return {warps along dim 0, warps along dim 1}.
static SmallVector<unsigned, 2>
getWarpsPerTile(const ArrayRef<int64_t> &shape, int version, int numWarps) {
  assert(version == 2);
  // TODO: Handle one warp per row for fused matmuls
  // TODO: unsigned -> int64_t to keep things uniform
  SmallVector<unsigned, 2> ret = {1, 1};
  // Per-warp tile extents used by the heuristic.
  // NOTE(review): presumably the m16n8 MMA instruction shape -- confirm.
  SmallVector<int64_t, 2> shapePerWarp = {16, 8};
  // TODO (@daadaada): double-check.
  // original logic in
  // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
  // seems buggy for shape = [32, 16] ?

  // Loop-invariant tile counts. The column extent is doubled here, exactly as
  // in the original comparison (shape[1] / (shapePerWarp[1] * 2)).
  const int64_t rowTiles = shape[0] / shapePerWarp[0];
  const int64_t colTiles = shape[1] / (shapePerWarp[1] * 2);

  // Every iteration doubles exactly one entry, so the product doubles each
  // time and the loop ends once it reaches numWarps.
  while (ret[0] * ret[1] < numWarps) {
    const bool growRows =
        rowTiles / ret[0] >= colTiles / ret[1] && ret[0] < rowTiles;
    ret[growRows ? 0 : 1] *= 2;
  }
  return ret;
}
|
||||||
|
|
||||||
mlir::LogicalResult
|
mlir::LogicalResult
|
||||||
matchAndRewrite(mlir::Operation *op,
|
matchAndRewrite(mlir::Operation *op,
|
||||||
mlir::PatternRewriter &rewriter) const override {
|
mlir::PatternRewriter &rewriter) const override {
|
||||||
@@ -541,13 +570,20 @@ public:
|
|||||||
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
|
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
|
||||||
if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
|
if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
|
||||||
return failure();
|
return failure();
|
||||||
// TODO: compute warpsPerCTA
|
// get MMA encoding for the given number of warps
|
||||||
auto newRetType = RankedTensorType::get(
|
auto retShape = oldRetType.getShape();
|
||||||
oldRetType.getShape(), oldRetType.getElementType(),
|
auto mod = op->getParentOfType<mlir::ModuleOp>();
|
||||||
triton::gpu::MmaEncodingAttr::get(oldRetType.getContext(), 2, {2, 2}));
|
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||||
|
auto newRetType =
|
||||||
|
RankedTensorType::get(retShape, oldRetType.getElementType(),
|
||||||
|
triton::gpu::MmaEncodingAttr::get(
|
||||||
|
oldRetType.getContext(), 2,
|
||||||
|
getWarpsPerTile(retShape, 2, numWarps)));
|
||||||
|
// convert accumulator
|
||||||
auto oldAcc = dotOp.getOperand(2);
|
auto oldAcc = dotOp.getOperand(2);
|
||||||
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||||
oldAcc.getLoc(), newRetType, oldAcc);
|
oldAcc.getLoc(), newRetType, oldAcc);
|
||||||
|
// convert output
|
||||||
auto newDot = rewriter.create<triton::DotOp>(
|
auto newDot = rewriter.create<triton::DotOp>(
|
||||||
dotOp.getLoc(), newRetType, dotOp.getOperand(0), dotOp.getOperand(1),
|
dotOp.getLoc(), newRetType, dotOp.getOperand(0), dotOp.getOperand(1),
|
||||||
newAcc, dotOp.allowTF32(), dotOp.transA(), dotOp.transB());
|
newAcc, dotOp.allowTF32(), dotOp.transA(), dotOp.transB());
|
||||||
|
@@ -157,15 +157,6 @@ import triton.language as tl
|
|||||||
@triton.autotune(
|
@triton.autotune(
|
||||||
configs=[
|
configs=[
|
||||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
|
||||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
|
||||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
|
||||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
|
||||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
|
||||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
|
||||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
|
||||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
|
||||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
|
||||||
],
|
],
|
||||||
key=['M', 'N', 'K'],
|
key=['M', 'N', 'K'],
|
||||||
)
|
)
|
||||||
@@ -318,13 +309,13 @@ else:
|
|||||||
triton.testing.Benchmark(
|
triton.testing.Benchmark(
|
||||||
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
|
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
|
||||||
x_vals=[
|
x_vals=[
|
||||||
128 * i for i in range(2, 33)
|
8192
|
||||||
], # different possible values for `x_name`
|
], # different possible values for `x_name`
|
||||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||||
# possible values for `line_arg`
|
# possible values for `line_arg`
|
||||||
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],
|
line_vals=['cublas', 'triton'],
|
||||||
# label name for the lines
|
# label name for the lines
|
||||||
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],
|
line_names=["cuBLAS", "Triton"],
|
||||||
# line styles
|
# line styles
|
||||||
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
|
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
|
||||||
ylabel="TFLOPS", # label name for the y-axis
|
ylabel="TFLOPS", # label name for the y-axis
|
||||||
@@ -336,18 +327,9 @@ def benchmark(M, N, K, provider):
|
|||||||
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
|
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
|
||||||
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
|
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
|
||||||
if provider == 'cublas':
|
if provider == 'cublas':
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
|
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100)
|
||||||
if provider == 'triton':
|
if provider == 'triton':
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
|
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100)
|
||||||
if provider == 'cublas + relu':
|
|
||||||
torch_relu = torch.nn.ReLU(inplace=True)
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
|
||||||
lambda: torch_relu(torch.matmul(a, b))
|
|
||||||
)
|
|
||||||
if provider == 'triton + relu':
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
|
||||||
lambda: matmul(a, b, activation=leaky_relu)
|
|
||||||
)
|
|
||||||
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
|
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
|
||||||
return perf(ms), perf(max_ms), perf(min_ms)
|
return perf(ms), perf(max_ms), perf(min_ms)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user