diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h
index eb375c4a2..37b2f4e04 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h
+++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h
@@ -15,6 +15,8 @@ std::unique_ptr<Pass> createTritonGPUCoalescePass();
 
 std::unique_ptr<Pass> createTritonGPUOptimizeLoadConvertPass();
 
+std::unique_ptr<Pass> createTritonGPUSinkConversionsFromSharedPass();
+
 std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
 
 std::unique_ptr<Pass> createTritonGPUVerifier();
diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
index c23e2556f..15983bf7c 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td
+++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -84,6 +84,18 @@ def TritonGPUOptimizeLoadConvert: Pass<"tritongpu-optimize-load-convert", "mlir::ModuleOp"> {
                           "mlir::triton::TritonDialect"];
 }
 
+def TritonGPUSinkConversionsFromShared: Pass<"tritongpu-sink-conversions-from-shared", "mlir::ModuleOp"> {
+  let summary = "Sink conversions from shared into loops";
+
+  let description = "This pass sinks conversions from shared memory into loops. This will lead the codegen "
+                    "to keep data in shared memory throughout loops, which will reduce register pressure.";
+
+  let constructor = "mlir::createTritonGPUSinkConversionsFromSharedPass()";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::triton::TritonDialect"];
+}
+
 def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
   let summary = "canonicalize scf.ForOp ops";
diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
index f297c5e87..84ffe7c42 100644
--- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
+++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(TritonGPUTransforms
   Pipeline.cpp
   Prefetch.cpp
   OptimizeLoadConvert.cpp
+  SinkConversionsFromShared.cpp
   TritonGPUConversion.cpp
 
   DEPENDS
diff --git a/python/src/triton.cc b/python/src/triton.cc
index a8ad58a21..7257a66fc 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1348,6 +1348,11 @@ void init_triton_ir(py::module &&m) {
            [](mlir::PassManager &self) {
              self.addPass(mlir::createTritonGPUOptimizeLoadConvertPass());
            })
+      .def("add_tritongpu_sink_conversions_from_shared_pass",
+           [](mlir::PassManager &self) {
+             self.addPass(
+                 mlir::createTritonGPUSinkConversionsFromSharedPass());
+           })
       .def("add_triton_gpu_to_llvm",
            [](mlir::PassManager &self) {
              self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index 2da77fbc6..0851c148c 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -191,7 +191,7 @@ def _bwd_kernel(
         tl.store(dv_ptrs, dv)
         tl.store(dk_ptrs, dk)
 
-_bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
+# _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
 # _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432)
 
 empty = torch.empty(128, device="cuda")
@@ -256,36 +256,36 @@ class _attention(torch.autograd.Function):
             BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
-        _bwd_kernel[(ctx.grid[1],1,1)](
-            q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
-            o.data_ptr(), do_scaled.data_ptr(),
-            dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
-            l.data_ptr(), m.data_ptr(),
-            delta.data_ptr(),
-            q.stride(0), q.stride(1), q.stride(2),
-            k.stride(0), k.stride(1), k.stride(2),
-            v.stride(0), v.stride(1), v.stride(2),
-            q.shape[0], q.shape[1], q.shape[2],
-            ctx.grid[0]
-        )
-
-        # pgm = _bwd_kernel[(ctx.grid[1],)](
-        #     q, k, v, ctx.sm_scale,
-        #     o, do_scaled,
-        #     dq, dk, dv,
-        #     l, m,
-        #     delta,
-        #     q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-        #     k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-        #     v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+        # _bwd_kernel[(ctx.grid[1],1,1)](
+        #     q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
+        #     o.data_ptr(), do_scaled.data_ptr(),
+        #     dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
+        #     l.data_ptr(), m.data_ptr(),
+        #     delta.data_ptr(),
+        #     q.stride(0), q.stride(1), q.stride(2),
+        #     k.stride(0), k.stride(1), k.stride(2),
+        #     v.stride(0), v.stride(1), v.stride(2),
         #     q.shape[0], q.shape[1], q.shape[2],
-        #     ctx.grid[0],
-        #     BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
-        #     BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
-        #     num_stages=1,
+        #     ctx.grid[0]
         # )
-        # print(pgm.asm["ttgir"])
-        # # exit(1)
+
+        pgm = _bwd_kernel[(ctx.grid[1],)](
+            q, k, v, ctx.sm_scale,
+            o, do_scaled,
+            dq, dk, dv,
+            l, m,
+            delta,
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+            q.shape[0], q.shape[1], q.shape[2],
+            ctx.grid[0],
+            BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
+            num_stages=1,
+        )
+        print(pgm.asm["ttgir"])
+        exit(1)
        return dq, dk, dv, None
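
Note (not part of the patch): a minimal sketch of how the new binding might be driven from the Python side. Only add_tritongpu_sink_conversions_from_shared_pass comes from this diff; the module import path, the pass_manager constructor, and the run call are assumptions about the pybind11 API registered in python/src/triton.cc on this branch, and the helper name is hypothetical.

# Hypothetical usage sketch -- names other than
# add_tritongpu_sink_conversions_from_shared_pass are assumptions.
import triton._C.libtriton.triton as _triton

def sink_shared_conversions(mod, context):
    # Build a pass manager on the given MLIR context and run the new
    # TritonGPU pass that sinks shared-memory conversions into loops,
    # so loop bodies keep operands in shared memory and register
    # pressure drops (per the pass description in Passes.td).
    pm = _triton.ir.pass_manager(context)
    pm.add_tritongpu_sink_conversions_from_shared_pass()
    pm.run(mod)
    return mod

In a full pipeline this call would presumably sit alongside the other tritongpu-* passes (coalesce, pipeline, combine, optimize-load-convert) when lowering TTGIR, but the exact ordering is not specified by this change.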