diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index d671f377d..c13fcde86 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -104,6 +104,7 @@ SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
     return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
                                  blockedLayout.getSizePerThread().end());
   } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
+    return {1};
     return getSizePerThread(sliceLayout.getParent());
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
     if (mmaLayout.isAmpere()) {
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
index 7dcdc0162..c3c397aeb 100644
--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -19,6 +19,8 @@
 #include <memory>
 
+#define int_attr(num) rewriter.getI64IntegerAttr(num)
+
 using namespace mlir;
 
 namespace {
 #include "TritonGPUCombine.inc"
@@ -1153,6 +1155,60 @@ public:
   }
 };
 
+// Rewrites load -> convert_layout(#dot_operand) into an async copy through
+// shared memory (alloc_tensor + insert_slice_async + async_wait +
+// extract_slice), so the loaded tile reaches the dot operand from shared
+// memory instead of going through registers.
+class LoadConvertToInsertSlice : public mlir::RewritePattern {
+
+public:
+  explicit LoadConvertToInsertSlice(mlir::MLIRContext *context)
+      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
+                             2, context) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
+    auto origRetType = cvt.getResult().getType().cast<RankedTensorType>();
+    auto shape = origRetType.getShape();
+    auto eltType = origRetType.getElementType();
+    auto dotOpEncoding = origRetType.getEncoding()
+                             .dyn_cast<triton::gpu::DotOperandEncodingAttr>();
+    if (!dotOpEncoding) {
+      return failure();
+    }
+    auto loadOp = dyn_cast<triton::LoadOp>(*cvt.getOperand().getDefiningOp());
+    if (!loadOp) {
+      return failure();
+    }
+    auto blockedEncoding = loadOp.getType()
+                               .cast<RankedTensorType>()
+                               .getEncoding()
+                               .dyn_cast<triton::gpu::BlockedEncodingAttr>();
+    if (!blockedEncoding)
+      return failure();
+    auto sharedEncoding = triton::gpu::SharedEncodingAttr::get(
+        getContext(), dotOpEncoding, shape, blockedEncoding.getOrder(),
+        eltType);
+    auto srcTy = RankedTensorType::get({1, shape[0], shape[1]}, eltType,
+                                       sharedEncoding);
+    auto loadTensor =
+        rewriter.create<triton::gpu::AllocTensorOp>(op->getLoc(), srcTy);
+
+    auto newOp = rewriter.create<triton::gpu::InsertSliceAsyncOp>(
+        op->getLoc(), loadTensor.getType(), loadOp.ptr(), loadTensor,
+        rewriter.create<arith::ConstantIntOp>(op->getLoc(), 0, 32),
+        loadOp.mask(), loadOp.other(), loadOp.cache(), loadOp.evict(),
+        loadOp.isVolatile(), /*axis*/ 0);
+
+    rewriter.create<triton::gpu::AsyncWaitOp>(op->getLoc(), 0);
+    auto tmpType =
+        RankedTensorType::get({shape[0], shape[1]}, eltType, sharedEncoding);
+    auto tmp = rewriter.create<tensor::ExtractSliceOp>(
+        op->getLoc(), tmpType, newOp,
+        SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
+        SmallVector<OpFoldResult>{int_attr(1), int_attr(shape[0]),
+                                  int_attr(shape[1])},
+        SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
+    rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(op, origRetType,
+                                                              tmp);
+    return success();
+  }
+};
+
 class FixupLoop : public mlir::RewritePattern {
 
 public:
@@ -1224,6 +1280,7 @@ public:
     patterns.add(context);
     patterns.add(context);
     patterns.add(context, computeCapability);
+    patterns.add<LoadConvertToInsertSlice>(context);
 
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
       signalPassFailure();
diff --git a/python/examples/copy_strided.py b/python/examples/copy_strided.py
deleted file mode 100644
index 922c5ba5c..000000000
--- a/python/examples/copy_strided.py
+++ /dev/null
@@ -1,19 +0,0 @@
-
-import triton
-import triton.language as tl
-
-
-# triton kernel
-@triton.jit
-def kernel(X, stride_xm,
-           Z, stride_zn,
-           BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
-    off_m = tl.arange(0, BLOCK_M)
-    off_n = tl.arange(0, BLOCK_N)
-    Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * 1
-    Zs = Z + off_m[:, None] * 1 + off_n[None, :] * stride_zn
-    tl.store(Zs, tl.load(Xs))
-
-
-ret = triton.compile(kernel, "*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
-print(ret)
diff --git a/python/examples/empty.py b/python/examples/empty.py
deleted file mode 100644
index df313fb85..000000000
--- a/python/examples/empty.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import torch
-
-import triton
-import triton.language as tl
-
-
-@triton.jit
-def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
-    pass
-
-
-X = torch.randn(1, device="cuda")
-pgm = kernel[(1,)](X, 1, 1, BLOCK=1024)
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index 99e4e5928..f4e3fd3e0 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -1562,7 +1562,7 @@ class CompiledKernel:
         if self.shared > max_shared:
             raise OutOfResources(self.shared, max_shared, "shared memory")
         mod, func, n_regs, n_spills = cuda_utils.load_binary(self.metadata["name"], self.asm["cubin"], self.shared, device)
-        print(n_regs, n_spills)
+        print(self.shared, n_regs, n_spills)
         self.cu_module = mod
         self.cu_function = func
 
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index 1bd787aaa..fa8562166 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -194,6 +194,7 @@ def _bwd_kernel(
 
 empty = torch.empty(128, device="cuda")
 
+
 class _attention(torch.autograd.Function):
 
     @staticmethod
@@ -220,7 +221,7 @@ class _attention(torch.autograd.Function):
            q.shape[0], q.shape[1], q.shape[2],
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=Lk, num_warps=num_warps,
-           num_stages=2,
+           num_stages=1,
         )
 
         ctx.save_for_backward(q, k, v, o, L, m)
@@ -335,7 +336,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
         ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
         flops_per_matmul = 2.*BATCH*H*N_CTX*N_CTX*D_HEAD*0.5
         total_flops = 2*flops_per_matmul
-        print(total_flops/ms*1e-9)
+        # print(total_flops/ms*1e-9)
+        print(ms)
         return ms
     if provider == "flash":
         lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
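Not part of the patch: a minimal sketch of a kernel that should exercise the new LoadConvertToInsertSlice pattern. The pattern matches a convert_layout into a dot-operand encoding whose input is produced directly by a tl.load with a blocked encoding, so a kernel that loads tiles and feeds them straight into tl.dot is expected to hit it. The kernel name, pointer arguments, and launch parameters below are made up for illustration and assume a build of this branch.

import torch

import triton
import triton.language as tl


# Hypothetical repro kernel (not from this branch): tl.load feeds tl.dot
# directly, producing the load -> convert_layout(#dot_operand) chain that
# LoadConvertToInsertSlice rewrites.
@triton.jit
def dot_kernel(A, B, C, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    a = tl.load(A + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(B + offs[:, None] * BLOCK + offs[None, :])
    c = tl.dot(a, b)
    tl.store(C + offs[:, None] * BLOCK + offs[None, :], c)


a = torch.randn((64, 64), device="cuda", dtype=torch.float16)
b = torch.randn((64, 64), device="cuda", dtype=torch.float16)
c = torch.empty((64, 64), device="cuda", dtype=torch.float32)
dot_kernel[(1,)](a, b, c, BLOCK=64)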