From 8edfe813a573c5f209960fd7ba1d96bf7b2b7a25 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 3 Dec 2022 09:58:24 -0800 Subject: [PATCH] [FRONTEND][BACKEND] Added `trans` instruction; made flash attention bwd pass work (#943) --- include/triton/Dialect/Triton/IR/TritonOps.td | 12 + lib/Analysis/Alias.cpp | 2 +- lib/Analysis/Utility.cpp | 2 +- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 34 ++- .../TritonToTritonGPU/TritonToTritonGPU.cpp | 52 +++- lib/Dialect/TritonGPU/Transforms/Combine.cpp | 53 ++++ python/src/triton.cc | 10 + python/tests/test_core.py | 233 +++++++++--------- python/triton/compiler.py | 4 +- python/triton/language/core.py | 4 + python/triton/language/semantic.py | 5 + python/tutorials/06-fused-attention.py | 42 ++-- 12 files changed, 310 insertions(+), 143 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 32512282b..09490952a 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -295,6 +295,18 @@ def TT_CatOp : TT_Op<"cat", [NoSideEffect, let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)"; } +def TT_TransOp : TT_Op<"trans", [NoSideEffect, + SameOperandsAndResultElementType]> { + + let summary = "transpose a tensor"; + + let arguments = (ins TT_Tensor:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; +} + // // SPMD Ops // diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp index 25ba3aeb0..db01e6fc3 100644 --- a/lib/Analysis/Alias.cpp +++ b/lib/Analysis/Alias.cpp @@ -26,7 +26,7 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation( // These ops may allocate a new shared memory buffer. auto result = op->getResult(0); // FIXME(Keren): extract and insert are always alias for now - if (isa(op)) { + if (isa(op)) { // extract_slice %src aliasInfo = AliasInfo(operands[0]->getValue()); pessimistic = false; diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index f23b111ec..57b3cfbfa 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -105,7 +105,7 @@ bool maybeSharedAllocationOp(Operation *op) { } bool maybeAliasOp(Operation *op) { - return isa(op) || + return isa(op) || isa(op) || isa(op) || isa(op); } diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index d8d7648c4..84315f963 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -2716,6 +2716,9 @@ public: auto dstSharedLayout = dstTy.getEncoding().cast(); auto inOrd = srcBlockedLayout.getOrder(); auto outOrd = dstSharedLayout.getOrder(); + if (inOrd != outOrd) + llvm_unreachable( + "blocked -> shared with different order not yet implemented"); unsigned inVec = inOrd == outOrd ? 
srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1; unsigned outVec = dstSharedLayout.getVec(); @@ -2775,7 +2778,8 @@ public: getMultiDimIndex(linearRepIdx, reps, inOrd); for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep; ++linearWordIdx) { - // step 1: recover the multidim_index from the index of input_elements + // step 1: recover the multidim_index from the index of + // input_elements auto multiDimWordIdx = getMultiDimIndex(linearWordIdx, wordsInEachRep, inOrd); SmallVector multiDimIdx(2); @@ -3711,6 +3715,33 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor, /// ====================== mma codegen end ============================ +/// ====================== trans codegen begin ============================ + +struct TransOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::TransOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto srcSmemObj = + getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter); + SmallVector dstStrides = {srcSmemObj.strides[1], + srcSmemObj.strides[0]}; + SmallVector dstOffsets = {srcSmemObj.offsets[1], + srcSmemObj.offsets[0]}; + auto dstSmemObj = + SharedMemoryObject(srcSmemObj.base, dstStrides, dstOffsets); + auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } +}; + +/// ====================== trans codegen end ============================ + Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout, Type resType, Type elemType, Value constVal, @@ -4538,6 +4569,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, patterns.add>(typeConverter, benefit); patterns.add(typeConverter, allocation, smem, benefit); + patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); } diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 563858362..439c3f771 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -252,6 +252,51 @@ struct TritonDotPattern : public OpConversionPattern { } }; +struct TritonTransPattern : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = adaptor.src(); + auto srcType = src.getType().cast(); + Attribute srcEncoding = srcType.getEncoding(); + if (!srcEncoding) + return failure(); + if (!srcEncoding.isa()) { + // TODO: end-to-end correctness is broken if + // the input is blocked and the output is shared + // with different order. Maybe a backend issue in BlockedToShared? 
+ SmallVector order = {1, 0}; + if (auto srcBlockedEncoding = + srcEncoding.dyn_cast()) + llvm::copy(srcBlockedEncoding.getOrder(), order.begin()); + srcEncoding = + triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order); + srcType = RankedTensorType::get(srcType.getShape(), + srcType.getElementType(), srcEncoding); + src = rewriter.create(src.getLoc(), srcType, + src); + } + auto srcSharedEncoding = + srcEncoding.cast(); + SmallVector retOrder(srcSharedEncoding.getOrder().begin(), + srcSharedEncoding.getOrder().end()); + SmallVector retShapes(srcType.getShape().begin(), + srcType.getShape().end()); + std::reverse(retOrder.begin(), retOrder.end()); + std::reverse(retShapes.begin(), retShapes.end()); + auto retEncoding = + triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, retOrder); + auto retType = + RankedTensorType::get(retShapes, srcType.getElementType(), retEncoding); + + rewriter.replaceOpWithNewOp(op, retType, src); + return success(); + } +}; + struct TritonLoadPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -390,9 +435,10 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, TritonGenericPattern, TritonGenericPattern, TritonBroadcastPattern, TritonGenericPattern, TritonReducePattern, - TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern, - TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern, - TritonPrintfPattern, TritonAtomicRMWPattern>(typeConverter, context); + TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern, + TritonDotPattern, TritonLoadPattern, TritonStorePattern, + TritonExtElemwisePattern, TritonPrintfPattern, TritonAtomicRMWPattern>( + typeConverter, context); } // diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 6ec2ee127..163d606e1 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -178,6 +178,10 @@ public: !isSharedEncoding(convert.getResult())) { return mlir::failure(); } + if (isSharedEncoding(convert.getOperand()) && + isSharedEncoding(convert.getResult())) { + return mlir::failure(); + } auto srcType = convert.getOperand().getType().cast(); auto srcShared = srcType.getEncoding().dyn_cast(); @@ -661,6 +665,54 @@ SmallVector warpsPerTileV2(triton::DotOp dotOp, } // namespace +class OptimizeBlockedToShared : public mlir::RewritePattern { +public: + OptimizeBlockedToShared(mlir::MLIRContext *context) + : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1, + context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto cvt = cast(op); + auto srcType = cvt.getOperand().getType().cast(); + auto dstType = cvt.getResult().getType().cast(); + auto srcBlockedLayout = + srcType.getEncoding().dyn_cast(); + auto dstSharedLayout = + dstType.getEncoding().dyn_cast(); + if (!srcBlockedLayout || !dstSharedLayout) + return failure(); + if (srcBlockedLayout.getOrder() == dstSharedLayout.getOrder()) + return failure(); + // For now only works if single use is transpose + // TODO: rematerialize #shared uses + auto users = op->getUsers(); + if (std::distance(users.begin(), users.end()) != 1 || + !isa(*users.begin())) + return failure(); + + auto tmpShared = triton::gpu::SharedEncodingAttr::get( + op->getContext(), dstSharedLayout.getVec(), + dstSharedLayout.getPerPhase(), dstSharedLayout.getMaxPhase(), + srcBlockedLayout.getOrder()); + auto tmpType = 
RankedTensorType::get(srcType.getShape(), + srcType.getElementType(), tmpShared); + auto tmpCvt = rewriter.create( + op->getLoc(), tmpType, cvt.getOperand()); + + auto newDstType = RankedTensorType::get( + users.begin()->getResultTypes()[0].cast().getShape(), + srcType.getElementType(), dstSharedLayout); + + auto newTrans = rewriter.create(op->getLoc(), newDstType, + tmpCvt.getResult()); + + rewriter.replaceOp(*users.begin(), newTrans.getResult()); + return success(); + } +}; + class BlockedToMMA : public mlir::RewritePattern { int computeCapability; @@ -755,6 +807,7 @@ public: mlir::RewritePatternSet patterns(context); + patterns.add(context); patterns.add(context); patterns.add(context); patterns.add(context); diff --git a/python/src/triton.cc b/python/src/triton.cc index f7ef61547..f450a6ed3 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1085,6 +1085,16 @@ void init_triton_ir(py::module &&m) { mlir::RankedTensorType::get(shape, lhsType.getElementType()), lhs, rhs); }) + .def("create_trans", + [](mlir::OpBuilder &self, mlir::Value &arg) -> mlir::Value { + auto loc = self.getUnknownLoc(); + auto argType = arg.getType().dyn_cast(); + auto argEltType = argType.getElementType(); + std::vector retShape = argType.getShape(); + std::reverse(retShape.begin(), retShape.end()); + return self.create( + loc, mlir::RankedTensorType::get(retShape, argEltType), arg); + }) .def("create_broadcast", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector &shape) -> mlir::Value { diff --git a/python/tests/test_core.py b/python/tests/test_core.py index 127250afd..cf26d44e0 100644 --- a/python/tests/test_core.py +++ b/python/tests/test_core.py @@ -667,7 +667,6 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"): tl.atomic_add(Z + off1, z) rs = RandomState(17) x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs) - print(x) # reference result z_ref = np.sum(x, axis=axis, keepdims=False) # triton result @@ -1067,122 +1066,126 @@ def test_permute(dtype_str, shape, perm, device='cuda'): # # --------------- -# @pytest.mark.parametrize("epilogue, allow_tf32, dtype", -# [(epilogue, allow_tf32, dtype) -# for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] -# for allow_tf32 in [True, False] -# for dtype in ['float16'] -# if not (allow_tf32 and (dtype in ['float16']))]) -# def test_dot(epilogue, allow_tf32, dtype, device='cuda'): -# cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) -# if cc < 80: -# if dtype == 'int8': -# pytest.skip("Only test int8 on devices with sm >= 80") -# elif dtype == 'float32' and allow_tf32: -# pytest.skip("Only test tf32 on devices with sm >= 80") +@pytest.mark.parametrize("epilogue, allow_tf32, dtype", + [(epilogue, allow_tf32, dtype) + for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] + for allow_tf32 in [True, False] + for dtype in ['float16'] + if not (allow_tf32 and (dtype in ['float16']))]) +def test_dot(epilogue, allow_tf32, dtype, device='cuda'): + capability = torch.cuda.get_device_capability() + if capability[0] < 80: + if dtype == 'int8': + pytest.skip("Only test int8 on devices with sm >= 80") + elif dtype == 'float32' and allow_tf32: + pytest.skip("Only test tf32 on devices with sm >= 80") -# M, N, K = 128, 128, 64 -# num_warps = 8 -# trans_a, trans_b = False, False + M, N, K = 64, 64, 64 + num_warps = 4 + trans_a, trans_b = False, False -# # triton kernel -# @triton.jit -# def kernel(X, stride_xm, stride_xk, -# Y, 
stride_yk, stride_yn, -# W, stride_wn, stride_wl, -# Z, stride_zm, stride_zn, -# BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, -# ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, -# ALLOW_TF32: tl.constexpr, -# DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, -# TRANS_A: tl.constexpr, TRANS_B: tl.constexpr): -# off_m = tl.arange(0, BLOCK_M) -# off_n = tl.arange(0, BLOCK_N) -# off_l = tl.arange(0, BLOCK_N) -# off_k = tl.arange(0, BLOCK_K) -# Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk -# Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn -# Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl -# Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn -# z = tl.dot(tl.load(Xs), tl.load(Ys), trans_a=TRANS_A, trans_b=TRANS_B, allow_tf32=ALLOW_TF32) -# if ADD_MATRIX: -# z += tl.load(Zs) -# if ADD_ROWS: -# ZRs = Z + off_m * stride_zm -# z += tl.load(ZRs)[:, None] -# if ADD_COLS: -# ZCs = Z + off_n * stride_zn -# z += tl.load(ZCs)[None, :] -# if DO_SOFTMAX: -# max = tl.max(z, 1) -# z = z - max[:, None] -# num = tl.exp(z) -# den = tl.sum(num, 1) -# z = num / den[:, None] -# if CHAIN_DOT: -# # tl.store(Zs, z) -# # tl.debug_barrier() -# z = tl.dot(z.to(tl.float16), tl.load(Ws), trans_a=TRANS_A) -# tl.store(Zs, z) -# # input -# rs = RandomState(17) -# x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1 -# y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1 -# w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1 -# if allow_tf32: -# x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32') -# y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32') -# w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32') -# x_tri = to_triton(x, device=device) -# y_tri = to_triton(y, device=device) -# w_tri = to_triton(w, device=device) -# # triton result -# z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1 -# z_tri = to_triton(z, device=device) -# if epilogue == 'trans': -# z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1]) -# pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), -# y_tri, y_tri.stride(0), y_tri.stride(1), -# w_tri, w_tri.stride(0), w_tri.stride(1), -# z_tri, z_tri.stride(0), z_tri.stride(1), -# TRANS_A=trans_a, TRANS_B=trans_b, -# BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, -# ADD_MATRIX=epilogue == 'add-matrix', -# ADD_ROWS=epilogue == 'add-rows', -# ADD_COLS=epilogue == 'add-cols', -# DO_SOFTMAX=epilogue == 'softmax', -# CHAIN_DOT=epilogue == 'chain-dot', -# ALLOW_TF32=allow_tf32, -# num_warps=num_warps) -# # torch result -# x_ref = x.T if trans_a else x -# y_ref = y.T if trans_b else y -# z_ref = np.matmul(x_ref, y_ref) -# if epilogue == 'add-matrix': -# z_ref += z -# if epilogue == 'add-rows': -# z_ref += z[:, 0][:, None] -# if epilogue == 'add-cols': -# z_ref += z[0, :][None, :] -# if epilogue == 'softmax': -# num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True)) -# denom = np.sum(num, axis=-1, keepdims=True) -# z_ref = num / denom -# if epilogue == 'chain-dot': -# z_ref = np.matmul(z_ref.T if trans_a else z_ref, w) -# # compare -# # print(z_ref[:,0], z_tri[:,0]) -# np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) -# # make sure ld/st are vectorized -# ptx = pgm.asm['ptx'] -# assert 'ld.global.v4' in ptx -# assert 'st.global.v4' in ptx -# if allow_tf32: -# assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx -# elif dtype == 'float32': -# assert 
'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx -# elif dtype == 'int8': -# assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx + # triton kernel + @triton.jit + def kernel(X, stride_xm, stride_xk, + Y, stride_yk, stride_yn, + W, stride_wn, stride_wl, + Z, stride_zm, stride_zn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, + ALLOW_TF32: tl.constexpr, + DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, + TRANS_A: tl.constexpr, TRANS_B: tl.constexpr): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + off_l = tl.arange(0, BLOCK_N) + off_k = tl.arange(0, BLOCK_K) + Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk + Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn + Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl + Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn + x = tl.load(Xs) + y = tl.load(Ys) + x = tl.trans(x) if TRANS_A else x + y = tl.trans(y) if TRANS_B else y + z = tl.dot(x, y, allow_tf32=ALLOW_TF32) + if ADD_MATRIX: + z += tl.load(Zs) + if ADD_ROWS: + ZRs = Z + off_m * stride_zm + z += tl.load(ZRs)[:, None] + if ADD_COLS: + ZCs = Z + off_n * stride_zn + z += tl.load(ZCs)[None, :] + if DO_SOFTMAX: + max = tl.max(z, 1) + z = z - max[:, None] + num = tl.exp(z) + den = tl.sum(num, 1) + z = num / den[:, None] + if CHAIN_DOT: + # tl.store(Zs, z) + # tl.debug_barrier() + z = tl.dot(tl.trans(z.to(tl.float16)), tl.load(Ws)) + tl.store(Zs, z) + # input + rs = RandomState(17) + x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1 + y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1 + w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1 + if allow_tf32: + x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32') + y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32') + w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32') + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + w_tri = to_triton(w, device=device) + # triton result + z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1 + z_tri = to_triton(z, device=device) + if epilogue == 'trans': + z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1]) + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), + y_tri, y_tri.stride(0), y_tri.stride(1), + w_tri, w_tri.stride(0), w_tri.stride(1), + z_tri, z_tri.stride(0), z_tri.stride(1), + TRANS_A=trans_a, TRANS_B=trans_b, + BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, + ADD_MATRIX=epilogue == 'add-matrix', + ADD_ROWS=epilogue == 'add-rows', + ADD_COLS=epilogue == 'add-cols', + DO_SOFTMAX=epilogue == 'softmax', + CHAIN_DOT=epilogue == 'chain-dot', + ALLOW_TF32=allow_tf32, + num_warps=num_warps) + # torch result + x_ref = x.T if trans_a else x + y_ref = y.T if trans_b else y + z_ref = np.matmul(x_ref, y_ref) + if epilogue == 'add-matrix': + z_ref += z + if epilogue == 'add-rows': + z_ref += z[:, 0][:, None] + if epilogue == 'add-cols': + z_ref += z[0, :][None, :] + if epilogue == 'softmax': + num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True)) + denom = np.sum(num, axis=-1, keepdims=True) + z_ref = num / denom + if epilogue == 'chain-dot': + z_ref = np.matmul(z_ref.T, w) + # compare + # print(z_ref[:,0], z_tri[:,0]) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + # make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + assert 'ld.global.v4' in ptx + 
assert 'st.global.v4' in ptx + if allow_tf32: + assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx + elif dtype == 'float32': + assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx + elif dtype == 'int8': + assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx # def test_dot_without_load(): diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 54687170e..ad238cdc0 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -1367,7 +1367,7 @@ prototype_pattern = { "ptx": ptx_prototype_pattern, } -mlir_arg_type_pattern = r'%\w+: ([^,\s]+)(?: \{\S+ = \S+ : \S+\})?,?' +mlir_arg_type_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?' ptx_arg_type_pattern = r"\.param\s+\.(\w+)" arg_type_pattern = { "ttir": mlir_arg_type_pattern, @@ -1424,7 +1424,9 @@ def compile(fn, **kwargs): import re match = re.search(prototype_pattern[ir], src, re.MULTILINE) name, signature = match.group(1), match.group(2) + print(name, signature) types = re.findall(arg_type_pattern[ir], signature) + print(types) param_tys = [convert_type_repr(ty) for ty in types] signature = {k: v for k, v in enumerate(param_tys)} first_stage = list(stages.keys()).index(ir) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 94db39bfc..77f458ba7 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -614,6 +614,7 @@ class tensor: assert False, "unsupported" return ret + # x[:, None, :, None] # x = expand_dims(x, axis=1) # x = expand_dims(x, axis=2) @@ -737,6 +738,9 @@ def broadcast_to(input, shape, _builder=None): """ return semantic.broadcast_impl_shape(input, shape, _builder) +@builtin +def trans(input, _builder=None): + return semantic.trans(input, _builder) @builtin def cat(input, other, _builder=None): diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 36912eb2b..6fc9b33b4 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -502,6 +502,11 @@ def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor: # TODO: check types return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), lhs.type) +def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor: + if len(input.shape) != 2: + raise ValueError("Only 2D tensors can be transposed") + ret_type = tl.block_type(input.type.scalar, [input.shape[1], input.shape[0]]) + return tl.tensor(builder.create_trans(input.handle), ret_type) def broadcast_impl_shape(input: tl.tensor, shape: List[int], diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index e4bc9cb82..14571abb9 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -32,7 +32,7 @@ def _fwd_kernel( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk + off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk # Initialize pointers to Q, K, V q_ptrs = Q + off_q @@ -50,7 +50,7 @@ def _fwd_kernel( # -- compute qk ---- k = tl.load(k_ptrs + start_n * stride_kn) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) + qk += tl.dot(q, tl.trans(k)) qk *= sm_scale qk += tl.where(offs_m[:, None] 
>= (start_n + offs_n[None, :]), 0, float("-inf")) # -- compute m_ij, p, l_ij @@ -165,26 +165,26 @@ def _bwd_kernel( q = tl.load(q_ptrs) # recompute p = softmax(qk, dim=-1).T # NOTE: `do` is pre-divided by `l`; no normalization here - qk = tl.dot(q, k, trans_b=True) + qk = tl.dot(q, tl.trans(k)) qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) m = tl.load(m_ptrs + offs_m_curr) p = tl.exp(qk * sm_scale - m[:, None]) # compute dv do = tl.load(do_ptrs) - dv += tl.dot(p.to(tl.float16), do, trans_a=True) + dv += tl.dot(tl.trans(p.to(tl.float16)), do) # compute dp = dot(v, do) Di = tl.load(D_ptrs + offs_m_curr) dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] - dp += tl.dot(do, v, trans_b=True) + dp += tl.dot(do, tl.trans(v)) # compute ds = p * (dp - delta[:, None]) ds = p * dp * sm_scale # compute dk = dot(ds.T, q) - dk += tl.dot(ds.to(tl.float16), q, trans_a=True) - # # compute dq + dk += tl.dot(tl.trans(ds.to(tl.float16)), q) + # compute dq dq = tl.load(dq_ptrs) dq += tl.dot(ds.to(tl.float16), k) tl.store(dq_ptrs, dq) - # # increment pointers + # increment pointers dq_ptrs += BLOCK_M * stride_qm q_ptrs += BLOCK_M * stride_qm do_ptrs += BLOCK_M * stride_qm @@ -273,7 +273,7 @@ attention = _attention.apply @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)]) def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): torch.manual_seed(20) - q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.1).requires_grad_() + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_() k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_() v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_() sm_scale = 0.2 @@ -287,23 +287,23 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): p = torch.softmax(p.float(), dim=-1).half() # p = torch.exp(p) ref_out = torch.matmul(p, v) - # ref_out.backward(dout) - # ref_dv, v.grad = v.grad.clone(), None - # ref_dk, k.grad = k.grad.clone(), None - # ref_dq, q.grad = q.grad.clone(), None + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None # # triton implementation tri_out = attention(q, k, v, sm_scale) # print(ref_out) # print(tri_out) - # tri_out.backward(dout) - # tri_dv, v.grad = v.grad.clone(), None - # tri_dk, k.grad = k.grad.clone(), None - # tri_dq, q.grad = q.grad.clone(), None + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None # compare triton.testing.assert_almost_equal(ref_out, tri_out) - # triton.testing.assert_almost_equal(ref_dv, tri_dv) - # triton.testing.assert_almost_equal(ref_dk, tri_dk) - # triton.testing.assert_almost_equal(ref_dq, tri_dq) + triton.testing.assert_almost_equal(ref_dv, tri_dv) + triton.testing.assert_almost_equal(ref_dk, tri_dk) + triton.testing.assert_almost_equal(ref_dq, tri_dq) BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 # vary seq length for fixed head and batch=4 @@ -350,4 +350,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep) return ms -bench_flash_attention.run(save_path='.', print_data=True) \ No newline at end of file +# bench_flash_attention.run(save_path='.', print_data=True) \ No 
newline at end of file
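
A minimal usage sketch of the `tl.trans` builtin introduced by this patch (not part of the diff above; the kernel name, parameter names, and block sizes are illustrative assumptions). It shows the pattern the updated `test_dot` and the fused-attention tutorial now use: an explicit transpose of a loaded 2D block in place of the removed `trans_a`/`trans_b` keyword arguments of `tl.dot`.

    import triton
    import triton.language as tl

    @triton.jit
    def dot_bt_kernel(X, Y, Z,                      # hypothetical example kernel
                      stride_xm, stride_xk,
                      stride_yn, stride_yk,         # Y is stored row-major as (N, K)
                      stride_zm, stride_zn,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
        off_m = tl.arange(0, BLOCK_M)
        off_n = tl.arange(0, BLOCK_N)
        off_k = tl.arange(0, BLOCK_K)
        x = tl.load(X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk)  # (BLOCK_M, BLOCK_K)
        y = tl.load(Y + off_n[:, None] * stride_yn + off_k[None, :] * stride_yk)  # (BLOCK_N, BLOCK_K)
        # New in this patch: tl.trans transposes a 2D block, replacing tl.dot(..., trans_b=True)
        z = tl.dot(x, tl.trans(y))                  # (BLOCK_M, BLOCK_N)
        tl.store(Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn, z)

In the backend, the resulting `tt.trans` op is lowered by `TransOpConversion` without moving data: the shared-memory descriptor's strides and offsets are swapped, while `OptimizeBlockedToShared` inserts an intermediate shared layout when the blocked and shared orders differ.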