From 54ae3e8d6edb51dc514251b5571759c70a2d917b Mon Sep 17 00:00:00 2001
From: Phil Tillet
Date: Wed, 28 Dec 2022 13:42:43 -0800
Subject: [PATCH] cleanup

---
 .../Dialect/TritonGPU/Transforms/Passes.h     |  2 +
 .../Dialect/TritonGPU/Transforms/Passes.td    | 13 +++++
 .../TritonGPU/Transforms/CMakeLists.txt       |  1 +
 lib/Dialect/TritonGPU/Transforms/Combine.cpp  | 56 -------------------
 python/src/triton.cc                          |  6 +-
 python/triton/compiler.py                     |  7 ++-
 python/tutorials/06-fused-attention.py        | 44 +++++++--------
 7 files changed, 47 insertions(+), 82 deletions(-)

diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h
index 7e02fb2b9..eb375c4a2 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h
+++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h
@@ -13,6 +13,8 @@ std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
 
 std::unique_ptr<Pass> createTritonGPUCoalescePass();
 
+std::unique_ptr<Pass> createTritonGPUOptimizeLoadConvertPass();
+
 std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
 
 std::unique_ptr<Pass> createTritonGPUVerifier();
diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
index f22a76c55..c23e2556f 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td
+++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -71,6 +71,19 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
   ];
 }
 
+def TritonGPUOptimizeLoadConvert: Pass<"tritongpu-optimize-load-convert", "mlir::ModuleOp"> {
+  let summary = "Optimize load + convert into insert_slice_async + wait + extract_slice + convert";
+
+  let description = "Transform load + convert into insert_slice_async + wait + extract_slice + convert. "
+ "This decreases registers pressure on architecture with direct pathways between DRAM " + "and shared memory"; + + let constructor = "mlir::createTritonGPUOptimizeLoadConvertPass()"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> { let summary = "canonicalize scf.ForOp ops"; diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index aabcc1901..f297c5e87 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_dialect_library(TritonGPUTransforms Combine.cpp Pipeline.cpp Prefetch.cpp + OptimizeLoadConvert.cpp TritonGPUConversion.cpp DEPENDS diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index c3c397aeb..dd0299323 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -19,7 +19,6 @@ #include -#define int_attr(num) rewriter.getI64IntegerAttr(num) using namespace mlir; namespace { @@ -1155,60 +1154,6 @@ public: } }; -class LoadConvertToInsertSlice : public mlir::RewritePattern{ - -public: - explicit LoadConvertToInsertSlice(mlir::MLIRContext *context) - : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 2, context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto cvt = cast(op); - auto origRetType = cvt.getResult().getType().cast(); - auto shape = origRetType.getShape(); - auto eltType = origRetType.getElementType(); - auto dotOpEncoding = origRetType.getEncoding().dyn_cast(); - if(!dotOpEncoding){ - return failure(); - } - auto loadOp = dyn_cast(*cvt.getOperand().getDefiningOp()); - if(!loadOp){ - return failure(); - } - auto blockedEncoding = loadOp.getType().cast().getEncoding().dyn_cast(); - if(!blockedEncoding) - return failure(); - auto sharedEncoding = triton::gpu::SharedEncodingAttr::get(getContext(), dotOpEncoding, shape, - blockedEncoding.getOrder(), eltType); - auto srcTy = RankedTensorType::get({1, shape[0], shape[1]}, - eltType, - sharedEncoding); - auto loadTensor = rewriter.create(op->getLoc(), srcTy); - - auto newOp = rewriter.create( - op->getLoc(), loadTensor.getType(), - loadOp.ptr(), - loadTensor, rewriter.create(op->getLoc(), 0, 32), - loadOp.mask(), - loadOp.other(), loadOp.cache(), - loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0); - - rewriter.create(op->getLoc(), 0); - auto tmpType = RankedTensorType::get({shape[0], shape[1]}, eltType, sharedEncoding); - auto tmp = rewriter.create(op->getLoc(), tmpType, newOp, - SmallVector{int_attr(0), int_attr(0), int_attr(0)}, - SmallVector{int_attr(1), - int_attr(shape[0]), - int_attr(shape[1])}, - SmallVector{int_attr(1), int_attr(1), int_attr(1)}); - rewriter.replaceOpWithNewOp(op, origRetType, tmp); - return success(); - - } - -}; - class FixupLoop : public mlir::RewritePattern { public: @@ -1280,7 +1225,6 @@ public: patterns.add(context); patterns.add(context); patterns.add(context, computeCapability); - patterns.add(context); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { signalPassFailure(); diff --git a/python/src/triton.cc b/python/src/triton.cc index 31f754add..a8ad58a21 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1339,11 +1339,15 @@ void init_triton_ir(py::module &&m) { 
            [](mlir::PassManager &self) {
              self.addPass(mlir::createTritonGPUPrefetchPass());
            })
-      .def("add_triton_gpu_combine_pass",
+      .def("add_tritongpu_combine_pass",
           [](mlir::PassManager &self, int computeCapability) {
             self.addPass(
                 mlir::createTritonGPUCombineOpsPass(computeCapability));
           })
+      .def("add_tritongpu_optimize_load_convert_pass",
+           [](mlir::PassManager &self) {
+             self.addPass(mlir::createTritonGPUOptimizeLoadConvertPass());
+           })
       .def("add_triton_gpu_to_llvm",
            [](mlir::PassManager &self) {
              self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index f4e3fd3e0..30e9c059a 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -894,17 +894,18 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
     pm.add_coalesce_pass()
     # The combine pass converts blocked layout to mma layout
     # for dot ops so that pipeline can get shared memory swizzled correctly.
-    pm.add_triton_gpu_combine_pass(compute_capability)
+    pm.add_tritongpu_combine_pass(compute_capability)
     pm.add_tritongpu_pipeline_pass(num_stages)
     # Prefetch must be done after pipeline pass because pipeline pass
     # extracts slices from the original tensor.
     pm.add_tritongpu_prefetch_pass()
     pm.add_canonicalizer_pass()
     pm.add_cse_pass()
-    pm.add_triton_gpu_combine_pass(compute_capability)
+    pm.add_tritongpu_combine_pass(compute_capability)
     pm.add_licm_pass()
-    pm.add_triton_gpu_combine_pass(compute_capability)
+    pm.add_tritongpu_combine_pass(compute_capability)
     pm.add_cse_pass()
+    pm.add_tritongpu_optimize_load_convert_pass()
     pm.run(mod)
     return mod
 
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index d10932959..bf3405335 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -191,7 +191,7 @@ def _bwd_kernel(
     tl.store(dv_ptrs, dv)
     tl.store(dk_ptrs, dk)
 
-_fwd_kernel = triton.compile("./flash-attention.ttgir", num_warps=4)
+# _fwd_kernel = triton.compile("./flash-attention.ttgir", num_warps=4)
 
 empty = torch.empty(128, device="cuda")
 
@@ -210,28 +210,28 @@ class _attention(torch.autograd.Function):
         m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         num_warps = 4 if Lk <= 64 else 8
 
-        # _fwd_kernel[grid](
-        #     q, k, v, sm_scale,
-        #     L, m,
-        #     o,
-        #     q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-        #     k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-        #     v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-        #     o.stride(0), o.stride(1), o.stride(2), o.stride(3),
-        #     q.shape[0], q.shape[1], q.shape[2],
-        #     BLOCK_M=BLOCK, BLOCK_N=BLOCK,
-        #     BLOCK_DMODEL=Lk, num_warps=num_warps,
-        #     num_stages=1,
-        # )
         _fwd_kernel[grid](
-            q.data_ptr(), k.data_ptr(), v.data_ptr(), sm_scale,
-            L.data_ptr(), m.data_ptr(),
-            o.data_ptr(),
-            q.stride(0), q.stride(1), q.stride(2),
-            k.stride(0), k.stride(1), k.stride(2),
-            v.stride(0), v.stride(1), v.stride(2),
-            o.stride(0), o.stride(1), o.stride(2),
-            q.shape[0], q.shape[1], q.shape[2])
+            q, k, v, sm_scale,
+            L, m,
+            o,
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
+            q.shape[0], q.shape[1], q.shape[2],
+            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
+            BLOCK_DMODEL=Lk, num_warps=num_warps,
+            num_stages=1,
+        )
+        # _fwd_kernel[grid](
+        #     q.data_ptr(), k.data_ptr(), v.data_ptr(), sm_scale,
+        #     L.data_ptr(), m.data_ptr(),
+        #     o.data_ptr(),
+        #     q.stride(0), q.stride(1), q.stride(2),
+        #     k.stride(0), k.stride(1), k.stride(2),
+        #     v.stride(0), v.stride(1), v.stride(2),
+        #     o.stride(0), o.stride(1), o.stride(2),
+        #     q.shape[0], q.shape[1], q.shape[2])
         ctx.save_for_backward(q, k, v, o, L, m)
         ctx.BLOCK = BLOCK
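
Note: the CMakeLists.txt hunk adds OptimizeLoadConvert.cpp to the build, and Passes.h, Passes.td, and triton.cc declare and expose the new pass, but the new source file itself does not appear in the diff above. For context, here is a minimal sketch of what lib/Dialect/TritonGPU/Transforms/OptimizeLoadConvert.cpp plausibly contains: the LoadConvertToInsertSlice pattern deleted from Combine.cpp, re-hosted as a standalone pass that compiler.py now runs at the end of the TTGIR pipeline. The pass class name, include list, and scaffolding are assumptions inferred from this patch, not the actual file.

// lib/Dialect/TritonGPU/Transforms/OptimizeLoadConvert.cpp -- hypothetical sketch.
//
// Schematically, on TTGIR the LoadConvertToInsertSlice pattern rewrites
//   %a = tt.load %ptr, %mask, %other       (result in a #blocked layout)
//   %b = triton_gpu.convert_layout %a      (#blocked -> #dot_operand)
// into
//   %buf = triton_gpu.alloc_tensor                        (shared-memory tile)
//   %ins = triton_gpu.insert_slice_async %ptr, %buf, ...  (async DRAM->SMEM copy)
//          triton_gpu.async_wait
//   %sl  = tensor.extract_slice %ins ...
//   %b   = triton_gpu.convert_layout %sl   (#shared -> #dot_operand)
// so the tile reaches the dot operand through shared memory instead of living
// in registers in the blocked layout, matching the Passes.td description.

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

#include <memory>

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

using namespace mlir;

namespace {

// The pattern deleted from Combine.cpp in this patch, re-hosted unchanged
// (together with its int_attr helper macro, which Combine.cpp also drops).
// The matchAndRewrite body is stubbed here to avoid repeating the hunk shown
// above; the real body is exactly the code removed there.
class LoadConvertToInsertSlice : public mlir::RewritePattern {
public:
  explicit LoadConvertToInsertSlice(mlir::MLIRContext *context)
      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                             2, context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    // ... body as removed from Combine.cpp above; stubbed in this sketch ...
    return failure();
  }
};

// Standalone pass wrapper, following the GEN_PASS_CLASSES convention used by
// the other TritonGPU transform passes.
class OptimizeLoadConvertPass
    : public TritonGPUOptimizeLoadConvertBase<OptimizeLoadConvertPass> {
public:
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp m = getOperation();
    RewritePatternSet patterns(context);
    patterns.add<LoadConvertToInsertSlice>(context);
    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
      signalPassFailure();
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createTritonGPUOptimizeLoadConvertPass() {
  return std::make_unique<OptimizeLoadConvertPass>();
}

With the tablegen entry above, the pass would also be invocable during development as -tritongpu-optimize-load-convert via the repo's triton-opt tool; compiler.py instead registers it last in ttir_to_ttgir, after the final combine/CSE passes, so it sees the dot-operand layouts the combine pass has already chosen.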