From 16e973edf2a92f792ff84c6fdf5a50ceeab3bf9f Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Tue, 6 Dec 2022 09:08:55 -0800
Subject: [PATCH 1/3] [BACKEND] Fix dependency analysis in pipeline (#946)

---
 lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 30 ++++++++++++-------
 python/tutorials/05-layer-norm.py             |  2 +-
 2 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
index b2292d901..dd5aa9d2a 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -123,9 +123,13 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
     return;
 
   if (auto arg = v.dyn_cast<BlockArgument>()) {
-    deps.insert(v);
-    // Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1
-    collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
+    if (arg.getArgNumber() > 0) {
+      // Skip the first arg (loop induction variable)
+      // Otherwise the op idx is arg.getArgNumber()-1
+      deps.insert(v);
+      collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1,
+                  deps);
+    }
   } else { // value
     // v might be in deps, but we still need to visit v.
     // This is because v might depend on value in previous iterations
@@ -376,11 +380,11 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   OpBuilder builder(forOp);
 
   // Order of new args:
-  // (original args),
-  // (insertSliceAsync buffer at stage numStages - 1) for each load
-  // (extracted tensor) for each load
-  // (depArgs at stage numStages-1)
-  // (iv at stage numStages-1)
+  // (original args)
+  // (insertSliceAsync buffer at stage numStages - 1) for each load
+  // (extracted tensor) for each load
+  // (depArgs at stage numStages - 1)
+  // (iv at stage numStages - 2)
   // (pipeline iteration index)
   // (loop iteration index)
   SmallVector<Value> newLoopArgs;
@@ -421,6 +425,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   BlockAndValueMapping mapping;
   for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
+  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
 
   // 2.1 clone the loop body, replace original args with args of the new ForOp
   // Insert async wait if necessary.
@@ -469,6 +474,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
     Value nextLoopCond =
         builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
                                       nextIV, newForOp.getUpperBound());
+    nextMapping.map(forOp.getInductionVar(), nextIV);
 
     // Slice index
     SmallVector<Value> nextBuffers;
@@ -598,9 +604,11 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   for (Value nextSlice : extractSlices)
     yieldValues.push_back(nextSlice);
 
-  for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
-    yieldValues.push_back(
-        depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
+  for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i) {
+    auto arg = newForOp.getRegionIterArgs()[i];
+    assert(depArgsMapping.count(arg) && "Missing loop-carried value");
+    yieldValues.push_back(depArgsMapping[arg]);
+  }
   yieldValues.push_back(nextIV);
   yieldValues.push_back(pipelineIterIdx);
   yieldValues.push_back(loopIterIdx);
diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py
index 1cefc60b9..110351af5 100644
--- a/python/tutorials/05-layer-norm.py
+++ b/python/tutorials/05-layer-norm.py
@@ -257,5 +257,5 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
                                                  grad_to_none=[x], rep=500)
     return gbps(ms), gbps(max_ms), gbps(min_ms)
 
-
+# test_layer_norm(1151, 8192, torch.float16)
 bench_layer_norm.run(save_path='.', print_data=True)

From 532e10cf87906060809470984e407c665529f6db Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Tue, 6 Dec 2022 09:32:13 -0800
Subject: [PATCH 2/3] [FRONTEND][BACKEND] Clean-up transpositions (#953)

---
 include/triton/Dialect/Triton/IR/TritonOps.td |  2 +-
 lib/Conversion/TritonGPUToLLVM/DotHelpers.h   |  9 +--------
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp       | 14 +-------------
 .../TritonToTritonGPU/TritonToTritonGPU.cpp   |  5 ++---
 lib/Dialect/Triton/Transforms/Combine.td      | 16 ++++++++--------
 lib/Dialect/TritonGPU/Transforms/Combine.cpp  |  3 +--
 python/src/triton.cc                          |  5 ++---
 python/tests/test_core.py                     |  4 ++--
 python/triton/compiler.py                     |  3 +++
 python/triton/language/core.py                | 11 +++++------
 python/triton/language/semantic.py            |  8 +++-----
 test/Triton/combine.mlir                      |  4 ++--
 12 files changed, 31 insertions(+), 53 deletions(-)

diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index 09490952a..258cd41b2 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -339,7 +339,7 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
     $d = matrix_multiply($a, $b) + $c
   }];
 
-  let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32, BoolAttr:$transA, BoolAttr:$transB);
+  let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
 
   let results = (outs TT_FpIntTensor:$d);
 
diff --git a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
index e232d6fa8..3943bc1b8 100644
--- a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
+++ b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
@@ -942,11 +942,6 @@ struct MMA16816ConversionHelper {
     SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
                                aTensorTy.getShape().end());
 
-    // TODO[Superjomn]: transA cannot be accessed in ConvertLayoutOp.
-    bool transA = false;
-    if (transA) {
-      std::swap(shape[0], shape[1]);
-    }
 
     ValueTable ha;
     std::function<void(int, int)> loadFn;
@@ -1052,8 +1047,6 @@ struct MMA16816ConversionHelper {
 
     SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
                                 aTensorTy.getShape().end());
-    if (op.transA())
-      std::swap(aShape[0], aShape[1]);
 
     auto dShape = dTensorTy.getShape();
 
@@ -1462,8 +1455,8 @@ Value DotOpMmaV1ConversionHelper::loadB(
 
   SmallVector<int> order(sharedLayout.getOrder().begin(),
                          sharedLayout.getOrder().end());
-  Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
 
+  Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
   bool isBRow = order[0] != 0;
   bool isBVec4 = isBRow && shape[order[0]] <= 16;
   // TODO[Superjomn]: Support the case when isBVec4=false later
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 2bfbbb090..33fd3d932 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -3253,7 +3253,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
 
     // Here we assume the DotOp's operands always comes from shared memory.
    auto AShape = A.getType().cast<RankedTensorType>().getShape();
-    size_t reduceAxis = op.transA() ? 0 : 1;
+    size_t reduceAxis = 1;
     unsigned K = AShape[reduceAxis];
     bool isOuter = K == 1;
 
@@ -3492,22 +3492,10 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
 
   auto DShape = DTensorTy.getShape();
   auto wpt = mmaLayout.getWarpsPerCTA();
 
-  bool transA = op.transA();
-  bool transB = op.transB();
-
   // TODO[Superjomn]: order cannot accessed in DotOp.
   SmallVector<unsigned> AOrder({1, 0});
   SmallVector<unsigned> BOrder({1, 0});
-  if (transA) {
-    std::swap(AShape[0], AShape[1]);
-    std::swap(AOrder[0], AOrder[1]);
-  }
-  if (transB) {
-    std::swap(BShape[0], BShape[1]);
-    std::swap(BOrder[0], BOrder[0]);
-  }
-
   bool isARow = AOrder[0] != 0;
   bool isBRow = BOrder[0] != 0;
   bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
index 439c3f771..e4a3e8064 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
@@ -245,9 +245,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
                                      bType.getElementType(), encoding);
       b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
     }
-    rewriter.replaceOpWithNewOp<triton::DotOp>(
-        op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
-        adaptor.transB());
+    rewriter.replaceOpWithNewOp<triton::DotOp>(op, retType, a, b, adaptor.c(),
+                                               adaptor.allowTF32());
     return success();
   }
};
diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td
index b8caf93ce..1b84cc7f3 100644
--- a/lib/Dialect/Triton/Transforms/Combine.td
+++ b/lib/Dialect/Triton/Transforms/Combine.td
@@ -12,21 +12,21 @@ include "triton/Dialect/Triton/IR/TritonOps.td"
 
 // AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
 // AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
 def CombineDotAddIPattern : Pat<
-        (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
-        (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
+        (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
+        (TT_DotOp $a, $b, $d, $allowTF32),
         [(Constraint<CPred<"isZero($0)">> $c)]>;
 def CombineDotAddFPattern : Pat<
-        (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
-        (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
+        (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
+        (TT_DotOp $a, $b, $d, $allowTF32),
         [(Constraint<CPred<"isZero($0)">> $c)]>;
 
 def CombineDotAddIRevPattern : Pat<
-        (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
-        (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
+        (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
+        (TT_DotOp $a, $b, $d, $allowTF32),
         [(Constraint<CPred<"isZero($0)">> $c)]>;
 def CombineDotAddFRevPattern : Pat<
-        (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
-        (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
+        (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
+        (TT_DotOp $a, $b, $d, $allowTF32),
         [(Constraint<CPred<"isZero($0)">> $c)]>;
 
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
index 2366fe962..e14bae003 100644
--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -781,8 +781,7 @@ public:
     a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
     b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
     auto newDot = rewriter.create<triton::DotOp>(
-        dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32(),
-        dotOp.transA(), dotOp.transB());
+        dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32());
     rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
         op, oldRetType, newDot.getResult());
 
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 95ba9409f..9d0a7f0d1 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1184,11 +1184,10 @@ void init_triton_ir(py::module &&m) {
           })
      .def("create_dot",
           [](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b,
-             mlir::Value &c, bool allowTF32, bool transA,
-             bool transB) -> mlir::Value {
+             mlir::Value &c, bool allowTF32) -> mlir::Value {
            auto loc = self.getUnknownLoc();
            return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c,
-                                                   allowTF32, transA, transB);
+                                                   allowTF32);
          })
      .def("create_exp",
           [](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
diff --git a/python/tests/test_core.py b/python/tests/test_core.py
index a2175469c..68f78272b 100644
--- a/python/tests/test_core.py
+++ b/python/tests/test_core.py
@@ -1126,7 +1126,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
         if CHAIN_DOT:
             # tl.store(Zs, z)
             # tl.debug_barrier()
-            z = tl.dot(tl.trans(z.to(tl.float16)), tl.load(Ws))
+            z = tl.dot(z.to(tl.float16), tl.load(Ws))
         tl.store(Zs, z)
     # input
     rs = RandomState(17)
@@ -1173,7 +1173,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
         denom = np.sum(num, axis=-1, keepdims=True)
         z_ref = num / denom
     if epilogue == 'chain-dot':
-        z_ref = np.matmul(z_ref.T, w)
+        z_ref = np.matmul(z_ref, w)
     # compare
     # print(z_ref[:,0], z_tri[:,0])
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index 86b655f21..fbc85eb45 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -762,6 +762,9 @@ class CodeGenerator(ast.NodeVisitor):
 
     def visit_Attribute(self, node):
         lhs = self.visit(node.value)
+        if isinstance(lhs, triton.language.tensor):
+            if node.attr == "T":
+                return triton.language.semantic.trans(lhs, builder=self.builder)
         return getattr(lhs, node.attr)
 
     def visit_Expr(self, node):
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index d208cc45e..97f313de9 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -603,10 +603,9 @@ class tensor:
             assert False, "unsupported"
         return ret
 
-
-    # x[:, None, :, None]
-    # x = expand_dims(x, axis=1)
-    # x = expand_dims(x, axis=2)
+    @property
+    def T(self):
assert False, "Transposition must be created by the AST Visitor" @builtin def to(self, dtype, bitcast=False, _builder=None): @@ -766,7 +765,7 @@ def view(input, shape, _builder=None): @builtin -def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=None): +def dot(input, other, allow_tf32=True, _builder=None): """ Returns the matrix product of two blocks. @@ -778,7 +777,7 @@ def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=No :type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} """ allow_tf32 = _constexpr_to_value(allow_tf32) - return semantic.dot(input, other, allow_tf32, trans_a, trans_b, _builder) + return semantic.dot(input, other, allow_tf32, _builder) # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 6fc9b33b4..104d5ac54 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -976,8 +976,6 @@ def atomic_xchg(ptr: tl.tensor, def dot(lhs: tl.tensor, rhs: tl.tensor, allow_tf32: bool, - trans_a: bool, - trans_b: bool, builder: ir.builder) -> tl.tensor: assert lhs.type.is_block() and rhs.type.is_block() if lhs.type.scalar.is_int(): @@ -986,11 +984,11 @@ def dot(lhs: tl.tensor, else: _0 = builder.get_float32(0) ret_scalar_ty = tl.float32 - M = lhs.type.shape[1 if trans_a else 0] - N = rhs.type.shape[0 if trans_b else 1] + M = lhs.type.shape[0] + N = rhs.type.shape[1] _0 = builder.create_splat(_0, [M, N]) ret_ty = tl.block_type(ret_scalar_ty, [M, N]) - return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, trans_a, trans_b), + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), ret_ty) diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir index 50ab4c456..b847f3163 100644 --- a/test/Triton/combine.mlir +++ b/test/Triton/combine.mlir @@ -13,10 +13,10 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32 %dot_out = tt.dot %a, %b, %zero {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> - // CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + // CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> %res0 = arith.addf %dot_out, %d : tensor<128x128xf32> - // CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + // CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> %res1 = arith.addf %d, %dot_out : tensor<128x128xf32> return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32> From 115cd3ac47d6f684de89d455cfa7a91753003d5b Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 6 Dec 2022 09:57:05 -0800 Subject: [PATCH 3/3] [FRONTEND] Added `reshape` as an alias for `view` (for now) (#956) --- python/triton/language/core.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 97f313de9..e44d706c9 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -758,6 +758,10 @@ def view(input, shape, _builder=None): shape = [x.value for x 
in shape] return semantic.view(input, shape, _builder) +@builtin +def reshape(input, shape, _builder=None): + # TODO: should be more than just a view + return view(input, shape, _builder) # ----------------------- # Linear Algebra
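
Usage sketch (appended after the series, not part of the patches): after PATCH 2, `tl.dot` no longer takes `trans_a`/`trans_b`; a transpose is written explicitly with `tl.trans(x)` or the new `x.T` property, and after PATCH 3 `tl.reshape` is available as an alias for `tl.view`. The snippet below assumes a CUDA device and Triton built at these commits; the kernel name, tile size, and shapes are illustrative assumptions only.

import torch
import triton
import triton.language as tl


@triton.jit
def matmul_at_b_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    # Single-tile toy kernel: C = A^T @ B for one BLOCK x BLOCK tile.
    offs = tl.arange(0, BLOCK)
    idx = offs[:, None] * BLOCK + offs[None, :]
    a = tl.load(a_ptr + idx)
    b = tl.load(b_ptr + idx)
    # Before #953 this would have been tl.dot(a, b, trans_a=True);
    # the transpose is now explicit (a.T is equivalent to tl.trans(a)).
    c = tl.dot(a.T, b)
    # reshape is an alias for view for now (#956); a same-shape call is a no-op.
    c = tl.reshape(c, (BLOCK, BLOCK))
    tl.store(c_ptr + idx, c)


BLOCK = 32
a = torch.randn((BLOCK, BLOCK), device='cuda', dtype=torch.float16)
b = torch.randn((BLOCK, BLOCK), device='cuda', dtype=torch.float16)
c = torch.empty((BLOCK, BLOCK), device='cuda', dtype=torch.float32)
matmul_at_b_kernel[(1,)](a, b, c, BLOCK=BLOCK)
torch.testing.assert_close(c, a.T.float() @ b.float(), rtol=1e-2, atol=1e-2)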