diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 3c2dca2e6..5391574dd 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -159,6 +159,16 @@ ChangeResult AxisInfoAnalysis::visitOperation( curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(), newContiguity, newDivisibility, newConstancy); } + // TODO: All other binary ops + if (llvm::isa(op)) { + auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; }; + auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; }; + auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) { + return gcd(lhs.getConstancy(d), rhs.getConstancy(d)); + }; + curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(), + newContiguity, newDivisibility, newConstancy); + } // Splat if (llvm::isa(op)) { Type _retTy = *op->result_type_begin(); @@ -200,7 +210,8 @@ ChangeResult AxisInfoAnalysis::visitOperation( for (int d = 0; d < retTy.getRank(); ++d) { contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d)); divisibility.push_back(opInfo.getDivisibility(d)); - constancy.push_back(opShape[d] == 1 ? retShape[d] : 1); + constancy.push_back(opShape[d] == 1 ? retShape[d] + : opInfo.getConstancy(d)); } curr = AxisInfo(contiguity, divisibility, constancy); } diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 44caad4e5..93ad97f80 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -693,7 +693,8 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, TypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc) { auto tensorTy = resType.cast(); - if (tensorTy.getEncoding().isa()) { + if (tensorTy.getEncoding().isa() || + tensorTy.getEncoding().isa()) { auto tensorTy = resType.cast(); auto srcType = typeConverter->convertType(elemType); auto llSrc = bitcast(constVal, srcType); diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 36ea7030f..456ce1200 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -533,6 +533,35 @@ public: BlockedToMMA(mlir::MLIRContext *context) : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context) {} + static SmallVector + getWarpsPerTile(const ArrayRef &shape, int version, int numWarps) { + assert(version == 2); + // TODO: Handle one warp per row for fused matmuls + // TODO: unsigned -> int64_t to keep things uniform + SmallVector ret = {1, 1}; + SmallVector shapePerWarp = {16, 8}; + bool changed = false; + // TODO (@daadaada): double-check. + // original logic in + // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252 + // seems buggy for shape = [32, 16] ? + do { + changed = false; + if (ret[0] * ret[1] >= numWarps) + break; + if (shape[0] / shapePerWarp[0] / ret[0] >= + shape[1] / (shapePerWarp[1] * 2) / ret[1]) { + if (ret[0] < shape[0] / shapePerWarp[0]) { + ret[0] *= 2; + } else + ret[1] *= 2; + } else { + ret[1] *= 2; + } + } while (true); + return ret; + } + mlir::LogicalResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { @@ -541,13 +570,20 @@ public: auto oldRetType = dotOp.getResult().getType().cast(); if (oldRetType.getEncoding().isa()) return failure(); - // TODO: compute warpsPerCTA - auto newRetType = RankedTensorType::get( - oldRetType.getShape(), oldRetType.getElementType(), - triton::gpu::MmaEncodingAttr::get(oldRetType.getContext(), 2, {2, 2})); + // get MMA encoding for the given number of warps + auto retShape = oldRetType.getShape(); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + auto newRetType = + RankedTensorType::get(retShape, oldRetType.getElementType(), + triton::gpu::MmaEncodingAttr::get( + oldRetType.getContext(), 2, + getWarpsPerTile(retShape, 2, numWarps))); + // convert accumulator auto oldAcc = dotOp.getOperand(2); auto newAcc = rewriter.create( oldAcc.getLoc(), newRetType, oldAcc); + // convert output auto newDot = rewriter.create( dotOp.getLoc(), newRetType, dotOp.getOperand(0), dotOp.getOperand(1), newAcc, dotOp.allowTF32(), dotOp.transA(), dotOp.transB()); diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index f685a6059..82a264dc2 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -157,15 +157,6 @@ import triton.language as tl @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), ], key=['M', 'N', 'K'], ) @@ -318,13 +309,13 @@ else: triton.testing.Benchmark( x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot x_vals=[ - 128 * i for i in range(2, 33) + 8192 ], # different possible values for `x_name` line_arg='provider', # argument name whose value corresponds to a different line in the plot # possible values for `line_arg`` - line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'], + line_vals=['cublas', 'triton'], # label name for the lines - line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"], + line_names=["cuBLAS", "Triton"], # line styles styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], ylabel="TFLOPS", # label name for the y-axis @@ -336,18 +327,9 @@ def benchmark(M, N, K, provider): a = torch.randn((M, K), device='cuda', dtype=torch.float16) b = torch.randn((K, N), device='cuda', dtype=torch.float16) if provider == 'cublas': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b)) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100) if provider == 'triton': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b)) - if provider == 'cublas + relu': - torch_relu = torch.nn.ReLU(inplace=True) - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: torch_relu(torch.matmul(a, b)) - ) - if provider == 'triton + relu': - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: matmul(a, b, activation=leaky_relu) - ) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100) perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms)