diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index 09490952a..258cd41b2 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -339,7 +339,7 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
     $d = matrix_multiply($a, $b) + $c
   }];
 
-  let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32, BoolAttr:$transA, BoolAttr:$transB);
+  let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
 
   let results = (outs TT_FpIntTensor:$d);
 
diff --git a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
index e232d6fa8..3943bc1b8 100644
--- a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
+++ b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
@@ -942,11 +942,6 @@ struct MMA16816ConversionHelper {
    SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
                               aTensorTy.getShape().end());
 
-    // TODO[Superjomn]: transA cannot be accessed in ConvertLayoutOp.
-    bool transA = false;
-    if (transA) {
-      std::swap(shape[0], shape[1]);
-    }
     ValueTable ha;
     std::function loadFn;
@@ -1052,8 +1047,6 @@ struct MMA16816ConversionHelper {
    SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
                                aTensorTy.getShape().end());
 
-    if (op.transA())
-      std::swap(aShape[0], aShape[1]);
 
     auto dShape = dTensorTy.getShape();
@@ -1462,8 +1455,8 @@ Value DotOpMmaV1ConversionHelper::loadB(
  SmallVector order(sharedLayout.getOrder().begin(),
                    sharedLayout.getOrder().end());
 
-  Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
+  Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
 
   bool isBRow = order[0] != 0;
   bool isBVec4 = isBRow && shape[order[0]] <= 16;
   // TODO[Superjomn]: Support the case when isBVec4=false later
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 2bfbbb090..33fd3d932 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -3253,7 +3253,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
     // Here we assume the DotOp's operands always comes from shared memory.
     auto AShape = A.getType().cast<RankedTensorType>().getShape();
-    size_t reduceAxis = op.transA() ? 0 : 1;
+    size_t reduceAxis = 1;
     unsigned K = AShape[reduceAxis];
     bool isOuter = K == 1;
@@ -3492,22 +3492,10 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
   auto DShape = DTensorTy.getShape();
   auto wpt = mmaLayout.getWarpsPerCTA();
 
-  bool transA = op.transA();
-  bool transB = op.transB();
-
   // TODO[Superjomn]: order cannot accessed in DotOp.
   SmallVector AOrder({1, 0});
   SmallVector BOrder({1, 0});
-  if (transA) {
-    std::swap(AShape[0], AShape[1]);
-    std::swap(AOrder[0], AOrder[1]);
-  }
-  if (transB) {
-    std::swap(BShape[0], BShape[1]);
-    std::swap(BOrder[0], BOrder[0]);
-  }
-
   bool isARow = AOrder[0] != 0;
   bool isBRow = BOrder[0] != 0;
   bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
index 439c3f771..e4a3e8064 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
@@ -245,9 +245,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
                                        bType.getElementType(), encoding);
       b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
     }
-    rewriter.replaceOpWithNewOp<triton::DotOp>(
-        op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
-        adaptor.transB());
+    rewriter.replaceOpWithNewOp<triton::DotOp>(op, retType, a, b, adaptor.c(),
+                                               adaptor.allowTF32());
     return success();
   }
 };
diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td
index b8caf93ce..1b84cc7f3 100644
--- a/lib/Dialect/Triton/Transforms/Combine.td
+++ b/lib/Dialect/Triton/Transforms/Combine.td
@@ -12,21 +12,21 @@ include "triton/Dialect/Triton/IR/TritonOps.td"
 // AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
 // AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
 def CombineDotAddIPattern : Pat<
-        (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
-        (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
+        (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
+        (TT_DotOp $a, $b, $d, $allowTF32),
         [(Constraint> $c)]>;
 
 def CombineDotAddFPattern : Pat<
-        (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
-        (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
+        (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
+        (TT_DotOp $a, $b, $d, $allowTF32),
         [(Constraint> $c)]>;
 
 def CombineDotAddIRevPattern : Pat<
-        (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
-        (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
+        (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
+        (TT_DotOp $a, $b, $d, $allowTF32),
         [(Constraint> $c)]>;
 
 def CombineDotAddFRevPattern : Pat<
-        (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
-        (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
+        (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
+        (TT_DotOp $a, $b, $d, $allowTF32),
         [(Constraint> $c)]>;
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
index 2366fe962..e14bae003 100644
--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -781,8 +781,7 @@ public:
     a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
     b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
     auto newDot = rewriter.create<triton::DotOp>(
-        dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32(),
-        dotOp.transA(), dotOp.transB());
+        dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32());
     rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
         op, oldRetType, newDot.getResult());
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 95ba9409f..9d0a7f0d1 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1184,11 +1184,10 @@ void init_triton_ir(py::module &&m) {
           })
       .def("create_dot",
            [](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b,
-              mlir::Value &c, bool allowTF32, bool transA,
-              bool transB) -> mlir::Value {
+              mlir::Value &c, bool allowTF32) -> mlir::Value {
             auto loc = self.getUnknownLoc();
             return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c,
-                                                    allowTF32, transA, transB);
+                                                    allowTF32);
           })
       .def("create_exp",
           [](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
diff --git a/python/tests/test_core.py b/python/tests/test_core.py
index a2175469c..68f78272b 100644
--- a/python/tests/test_core.py
+++ b/python/tests/test_core.py
@@ -1126,7 +1126,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
         if CHAIN_DOT:
             # tl.store(Zs, z)
             # tl.debug_barrier()
-            z = tl.dot(tl.trans(z.to(tl.float16)), tl.load(Ws))
+            z = tl.dot(z.to(tl.float16), tl.load(Ws))
             tl.store(Zs, z)
     # input
     rs = RandomState(17)
@@ -1173,7 +1173,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
         denom = np.sum(num, axis=-1, keepdims=True)
         z_ref = num / denom
     if epilogue == 'chain-dot':
-        z_ref = np.matmul(z_ref.T, w)
+        z_ref = np.matmul(z_ref, w)
     # compare
     # print(z_ref[:,0], z_tri[:,0])
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index 86b655f21..fbc85eb45 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -762,6 +762,9 @@ class CodeGenerator(ast.NodeVisitor):
 
     def visit_Attribute(self, node):
         lhs = self.visit(node.value)
+        if isinstance(lhs, triton.language.tensor):
+            if node.attr == "T":
+                return triton.language.semantic.trans(lhs, builder=self.builder)
         return getattr(lhs, node.attr)
 
     def visit_Expr(self, node):
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index d208cc45e..97f313de9 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -603,10 +603,9 @@ class tensor:
             assert False, "unsupported"
         return ret
 
-
-    # x[:, None, :, None]
-    # x = expand_dims(x, axis=1)
-    # x = expand_dims(x, axis=2)
+    @property
+    def T(self):
+        assert False, "Transposition must be created by the AST Visitor"
 
     @builtin
     def to(self, dtype, bitcast=False, _builder=None):
@@ -766,7 +765,7 @@ def view(input, shape, _builder=None):
 
 
 @builtin
-def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=None):
+def dot(input, other, allow_tf32=True, _builder=None):
     """
     Returns the matrix product of two blocks.
@@ -778,7 +777,7 @@ def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=No
     :type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
     """
     allow_tf32 = _constexpr_to_value(allow_tf32)
-    return semantic.dot(input, other, allow_tf32, trans_a, trans_b, _builder)
+    return semantic.dot(input, other, allow_tf32, _builder)
 
 
 # -----------------------
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 6fc9b33b4..104d5ac54 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -976,8 +976,6 @@ def atomic_xchg(ptr: tl.tensor,
 def dot(lhs: tl.tensor,
         rhs: tl.tensor,
         allow_tf32: bool,
-        trans_a: bool,
-        trans_b: bool,
         builder: ir.builder) -> tl.tensor:
     assert lhs.type.is_block() and rhs.type.is_block()
     if lhs.type.scalar.is_int():
@@ -986,11 +984,11 @@ def dot(lhs: tl.tensor,
     else:
         _0 = builder.get_float32(0)
         ret_scalar_ty = tl.float32
-    M = lhs.type.shape[1 if trans_a else 0]
-    N = rhs.type.shape[0 if trans_b else 1]
+    M = lhs.type.shape[0]
+    N = rhs.type.shape[1]
     _0 = builder.create_splat(_0, [M, N])
     ret_ty = tl.block_type(ret_scalar_ty, [M, N])
-    return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, trans_a, trans_b),
+    return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
                      ret_ty)
diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir
index 50ab4c456..b847f3163 100644
--- a/test/Triton/combine.mlir
+++ b/test/Triton/combine.mlir
@@ -13,10 +13,10 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32
     %dot_out = tt.dot %a, %b, %zero {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
 
-    // CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
+    // CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
     %res0 = arith.addf %dot_out, %d : tensor<128x128xf32>
 
-    // CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
+    // CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
     %res1 = arith.addf %d, %dot_out : tensor<128x128xf32>
 
     return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
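Usage sketch: after this change, tl.dot no longer takes trans_a/trans_b flags; a transposed operand is written on the operand itself, either as tl.trans(x) or the new x.T property that visit_Attribute lowers through semantic.trans. The following minimal example assumes a Triton build with this patch applied; the kernel name, tile sizes, and host-side launch are illustrative and not taken from the repository.

import torch
import triton
import triton.language as tl


@triton.jit
def dot_bt_kernel(x_ptr, w_ptr, out_ptr,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # Load one (BLOCK_M, BLOCK_K) tile of x and one (BLOCK_N, BLOCK_K) tile of w,
    # both assumed row-major and contiguous.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    x = tl.load(x_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])
    w = tl.load(w_ptr + offs_n[:, None] * BLOCK_K + offs_k[None, :])
    # Old API: tl.dot(x, w, trans_b=True); new API: transpose the operand instead.
    out = tl.dot(x, w.T)
    tl.store(out_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], out)


x = torch.randn((32, 32), device='cuda', dtype=torch.float16)
w = torch.randn((32, 32), device='cuda', dtype=torch.float16)
out = torch.empty((32, 32), device='cuda', dtype=torch.float32)
dot_bt_kernel[(1,)](x, w, out, BLOCK_M=32, BLOCK_N=32, BLOCK_K=32)
# fp16 inputs accumulate to fp32 in tl.dot, so compare against an fp32 reference.
torch.testing.assert_close(out, (x @ w.T).float(), rtol=1e-2, atol=1e-2)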