From 18c7a72973e4706daab58ba57b8cf88641a37332 Mon Sep 17 00:00:00 2001
From: Philippe Tillet <phil@openai.com>
Date: Fri, 6 Jan 2023 14:26:06 -0800
Subject: [PATCH] more pass template

---
 .../Dialect/TritonGPU/Transforms/Passes.h     |  2 ++
 .../Dialect/TritonGPU/Transforms/Passes.td    | 11 ++++++
 .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 36 +++++++++----------
 .../TritonGPU/Transforms/CMakeLists.txt       |  1 +
 lib/Dialect/TritonGPU/Transforms/Combine.cpp  |  4 +--
 .../DecomposeConversionsToDotOperand.cpp      | 36 +++++++++++++++++++
 python/src/triton.cc                          | 18 ++++++----
 python/triton/compiler.py                     |  2 ++
 python/tutorials/06-fused-attention.py        | 16 +++++----
 9 files changed, 92 insertions(+), 34 deletions(-)
 create mode 100644 lib/Dialect/TritonGPU/Transforms/DecomposeConversionsToDotOperand.cpp

diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h
index 37b2f4e04..f2c800b28 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h
+++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h
@@ -17,6 +17,8 @@ std::unique_ptr<Pass> createTritonGPUOptimizeLoadConvertPass();
 
 std::unique_ptr<Pass> createTritonGPUSinkConversionsFromSharedPass();
 
+std::unique_ptr<Pass> createTritonGPUDecomposeConversionsToDotOperandPass();
+
 std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
 
 std::unique_ptr<Pass> createTritonGPUVerifier();
diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
index 15983bf7c..f620e6738 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td
+++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -96,6 +96,17 @@ def TritonGPUSinkConversionsFromShared: Pass<"tritongpu-sink-conversions-from-shared", "mlir::ModuleOp"> {
                            "mlir::triton::TritonDialect"];
 }
 
+def TritonGPUDecomposeConversionsToDotOperand: Pass<"tritongpu-decompose-conversions-to-dot-operand", "mlir::ModuleOp"> {
+  let summary = "Decompose convert[distributed -> dotOperand] into convert[distributed -> shared -> dotOperand]";
+
+  let description = "Decomposing conversions this way makes it possible to use CSE and re-use #shared tensors";
+
+  let constructor = "mlir::createTritonGPUDecomposeConversionsToDotOperandPass()";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::triton::TritonDialect"];
+}
+
 def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
 
   let summary = "canonicalize scf.ForOp ops";
diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 1a601bdd0..ab9fabe2d 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -72,24 +72,24 @@ void storeDistributedToShared(Value src, Value llSrc,
   Value staIdx1 = i32_val(0);
   Value stride0 = dstStrides[outOrd[0]];
   Value stride1 = dstStrides[outOrd[1]];
-  if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
-    if (auto cstRhs =
-            dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
-      unsigned rhsVal =
-          cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
-      unsigned key = (rhsVal / outVec) % maxPhase;
-      if (cache.find(key) == cache.end())
-        cache[key] = dynIdx0;
-      dynIdx0 = cache[key];
-      staIdx0 =
-          i32_val((rhsVal) / (outVec * maxPhase) * (outVec * maxPhase));
-    }
-  if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
-    if (auto cstRhs =
-            dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
-      dynIdx1 = addOp.getLhs();
-      staIdx1 = addOp.getRhs();
-    }
+  // if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
+  //   if (auto cstRhs =
+  //           dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
+  //     unsigned rhsVal =
+  //         cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
+  //     unsigned key = (rhsVal / outVec) % maxPhase;
+  //     if (cache.find(key) == cache.end())
+  //       cache[key] = dynIdx0;
+  //     dynIdx0 = cache[key];
+  //     staIdx0 =
+  //         i32_val((rhsVal) / (outVec * maxPhase) * (outVec * maxPhase));
+  //   }
+  // if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
+  //   if (auto cstRhs =
+  //           dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
+  //     dynIdx1 = addOp.getLhs();
+  //     staIdx1 = addOp.getRhs();
+  //   }
 
   // offset along non-contiguous dimension
   Value off1 = mul(dynIdx1, stride1);
diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
index 84ffe7c42..93e02c998 100644
--- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
+++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(TritonGPUTransforms
   Prefetch.cpp
   OptimizeLoadConvert.cpp
   SinkConversionsFromShared.cpp
+  DecomposeConversionsToDotOperand.cpp
   TritonGPUConversion.cpp
 
 DEPENDS
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
index 2b3aa239d..2c3489c76 100644
--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -19,7 +19,6 @@
 
 #include <memory>
 
-
 using namespace mlir;
 namespace {
 #include "TritonGPUCombine.inc"
@@ -483,8 +482,7 @@ public:
       return op->getBlock() == cvt->getBlock() &&
              !(isa<triton::ReduceOp>(op) &&
               !op->getResult(0).getType().isa<RankedTensorType>()) &&
-             !isa<triton::gpu::ConvertLayoutOp>(op) &&
-             !isa<scf::YieldOp>(op);
+             !isa<triton::gpu::ConvertLayoutOp>(op) && !isa<scf::YieldOp>(op);
     };
     mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
     if (cvtSlices.empty())
diff --git a/lib/Dialect/TritonGPU/Transforms/DecomposeConversionsToDotOperand.cpp b/lib/Dialect/TritonGPU/Transforms/DecomposeConversionsToDotOperand.cpp
new file mode 100644
index 000000000..494a5af0f
--- /dev/null
+++ b/lib/Dialect/TritonGPU/Transforms/DecomposeConversionsToDotOperand.cpp
@@ -0,0 +1,36 @@
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "triton/Analysis/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
+#define GEN_PASS_CLASSES
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
+
+using namespace mlir;
+
+class TritonGPUDecomposeConversionsToDotOperandPass
+    : public TritonGPUDecomposeConversionsToDotOperandBase<
+          TritonGPUDecomposeConversionsToDotOperandPass> {
+public:
+  TritonGPUDecomposeConversionsToDotOperandPass() = default;
+
+  void runOnOperation() override { return; }
+};
+
+std::unique_ptr<Pass>
+mlir::createTritonGPUDecomposeConversionsToDotOperandPass() {
+  return std::make_unique<TritonGPUDecomposeConversionsToDotOperandPass>();
+}
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 7257a66fc..bf6f415dd 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1345,14 +1345,18 @@ void init_triton_ir(py::module &&m) {
                 mlir::createTritonGPUCombineOpsPass(computeCapability));
           })
       .def("add_tritongpu_optimize_load_convert_pass",
-           [](mlir::PassManager &self) {
-             self.addPass(mlir::createTritonGPUOptimizeLoadConvertPass());
-           })
+          [](mlir::PassManager &self) {
+            self.addPass(mlir::createTritonGPUOptimizeLoadConvertPass());
+          })
       .def("add_tritongpu_sink_conversions_from_shared_pass",
-           [](mlir::PassManager &self) {
-             self.addPass(
-                 mlir::createTritonGPUSinkConversionsFromSharedPass());
-           })
+          [](mlir::PassManager &self) {
+            self.addPass(mlir::createTritonGPUSinkConversionsFromSharedPass());
+          })
+      .def("add_tritongpu_decompose_conversions_to_dot_operand_pass",
+          [](mlir::PassManager &self) {
+            self.addPass(
+                mlir::createTritonGPUDecomposeConversionsToDotOperandPass());
+          })
       .def("add_triton_gpu_to_llvm",
            [](mlir::PassManager &self) {
              self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index 3da6d3929..e27184736 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -906,6 +906,8 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
     pm.add_tritongpu_combine_pass(compute_capability)
     pm.add_cse_pass()
     # pm.add_tritongpu_optimize_load_convert_pass()
+    pm.add_tritongpu_sink_conversions_from_shared_pass()
+    pm.add_tritongpu_decompose_conversions_to_dot_operand_pass()
     pm.run(mod)
     return mod
 
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index 0851c148c..79fb70923 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -194,8 +194,10 @@ def _bwd_kernel(
 # _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
 # _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432)
 
+
 empty = torch.empty(128, device="cuda")
 
+
 class _attention(torch.autograd.Function):
 
     @staticmethod
@@ -255,7 +257,7 @@ class _attention(torch.autograd.Function):
             do_scaled, delta,
             BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
-        
+
         # _bwd_kernel[(ctx.grid[1],1,1)](
         #     q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
         #     o.data_ptr(), do_scaled.data_ptr(),
@@ -284,8 +286,8 @@ class _attention(torch.autograd.Function):
             BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
             num_stages=1,
         )
-        print(pgm.asm["ttgir"])
-        exit(1)
+        # print(pgm.asm["ttgir"])
+        # exit()
         return dq, dk, dv, None
 
 
@@ -327,6 +329,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     triton.testing.assert_almost_equal(ref_dk, tri_dk)
     triton.testing.assert_almost_equal(ref_dq, tri_dq)
 
+
 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4
 configs = [triton.testing.Benchmark(
@@ -358,8 +361,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
         do = torch.randn_like(o)
         fn = lambda: o.backward(do, retain_graph=True)
         ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
-    flops_per_matmul = 2.*BATCH*H*N_CTX*N_CTX*D_HEAD*0.5
-    total_flops = 2*flops_per_matmul
+    flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5
+    total_flops = 2 * flops_per_matmul
     # print(total_flops/ms*1e-9)
     print(ms)
     return ms
@@ -376,4 +379,5 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
     ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
     return ms
 
-bench_flash_attention.run(save_path='.', print_data=True)
\ No newline at end of file
+
+# bench_flash_attention.run(save_path='.', print_data=True)
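
Note for review: runOnOperation() in the new pass is intentionally left as a stub in this commit. Below is a minimal sketch (not part of the patch) of what the decomposition described in Passes.td could look like, assuming the TritonGPU APIs at this revision (ConvertLayoutOp, SharedEncodingAttr, DotOperandEncodingAttr, triton::gpu::getOrder); the vec/perPhase/maxPhase values are placeholders, not a tuned swizzling choice.

    // Sketch only -- one possible body for the stub above. It rewrites every
    // convert[distributed -> dotOperand] into a pair of conversions that go
    // through a #shared tensor.
    void runOnOperation() override {
      ModuleOp mod = getOperation();
      mod.walk([&](triton::gpu::ConvertLayoutOp cvt) {
        OpBuilder builder(cvt);
        auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
        auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
        auto srcEncoding = srcType.getEncoding();
        // Only rewrite convert[distributed -> dotOperand]; conversions that
        // already read from #shared are left alone.
        if (srcEncoding.isa<triton::gpu::SharedEncodingAttr>())
          return;
        if (!dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
          return;
        // Placeholder #shared layout; a real implementation would derive
        // vec/perPhase/maxPhase from the dot operand this conversion feeds.
        auto tmpEncoding = triton::gpu::SharedEncodingAttr::get(
            mod.getContext(), /*vec=*/1, /*perPhase=*/1, /*maxPhase=*/1,
            triton::gpu::getOrder(srcEncoding));
        auto tmpType = RankedTensorType::get(
            srcType.getShape(), srcType.getElementType(), tmpEncoding);
        // convert[distributed -> shared]
        auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
            cvt.getLoc(), tmpType, cvt.getOperand());
        // convert[shared -> dotOperand]
        auto decomposed = builder.create<triton::gpu::ConvertLayoutOp>(
            cvt.getLoc(), dstType, tmp);
        cvt.getResult().replaceAllUsesWith(decomposed.getResult());
        cvt.erase();
      });
    }

Going through #shared in two steps is what enables the re-use mentioned in the pass description: identical convert[distributed -> shared] ops become CSE candidates, so several dot operands can read from a single #shared tensor instead of each materializing its own copy.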