From 1ed6ee34baf6c9eb7779f359819e3ccb101a85a8 Mon Sep 17 00:00:00 2001
From: Superjomn
Date: Fri, 4 Nov 2022 16:54:05 +0800
Subject: [PATCH] finish coding

---
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 160 ++++++++++++++----
 lib/Dialect/TritonGPU/IR/Dialect.cpp    |  10 +-
 test/Conversion/tritongpu_to_llvm.mlir  |   9 +-
 3 files changed, 135 insertions(+), 44 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 3d4148e7b..928536f6e 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -3654,18 +3654,6 @@ public:
     return operand.getElementType().isF32();
   }
 
-  SmallVector<unsigned> getOrder() const {
-    SmallVector<unsigned> order(2);
-    if (mmaLayout.getVersion() == 1)
-      order = {0, 1};
-    else if (mmaLayout.getVersion() == 0)
-      order = {1, 0};
-    else {
-      assert(false && "Unexpected MMA version found.");
-    }
-    return order;
-  }
-
   Value loadA(Value tensor, Value llTensor, Value threadId, Location loc,
               Value smem, ConversionPatternRewriter &rewriter) const {
@@ -3674,7 +3662,6 @@ public:
     auto aShape = tensorTy.getShape();
     auto aLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
    auto aOrder = aLayout.getOrder();
-    auto order = getOrder();
 
     bool isARow = aOrder[0] == 1;
 
@@ -3688,14 +3675,115 @@ public:
     int aNumPtr = 8;
     int bNumPtr = 8;
     int aVec = 2;
-    int NK = aShape[isARow ? 1 : 0];
-    return Value{};
+    Value _0 = i32_val(0);
+    Value _1 = i32_val(1);
+
+    Value mContig = _1;
+    Value nContig = _1;
+
+    // Each thread loads a different row of A: offset threadId along M.
+    Value offA0 = isARow ? _0 : mul(threadId, mContig);
+    Value offA1 = isARow ? mul(threadId, mContig) : _0;
+    SmallVector<Value> aOff(aNumPtr);
+    for (int i = 0; i < aNumPtr; ++i) {
+      aOff[i] =
+          add(mul(offA0, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
+    }
+
+    Type f32PtrTy = ptr_ty(f32_ty);
+    SmallVector<Value> aPtrs(aNumPtr);
+    for (int i = 0; i < aNumPtr; ++i)
+      aPtrs[i] = gep(f32PtrTy, llTensor, aOff[i]);
+
+    ValueTable has;
+
+    auto aShapePerCTA = getShapePerCTA(aLayout);
+    auto sizePerThread = getSizePerThread(aLayout);
+    int M = isARow ? aShape[0] : aShape[1];
+    int K = isARow ? aShape[1] : aShape[0];
+
+    for (unsigned k = 0; k < K; k++)
+      for (unsigned m = 0; m < M; m += aShapePerCTA[aOrder[1]])
+        for (unsigned mm = 0; mm < sizePerThread[aOrder[1]]; ++mm) {
+          Value pa = gep(f32PtrTy, aPtrs[0],
+                         i32_val((m + mm) * strideAM + k * strideAK));
+          Value va = load(pa);
+          has[{m + mm, k}] = va;
+        }
+
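+    // Gather the loaded A values (keyed by (m, k)) into one LLVM struct, so
+    // the FMA dot conversion receives the whole operand as a single value.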
+    SmallVector<Value> values;
+    for (auto &item : has)
+      values.push_back(item.second);
+    Type structTy =
+        struct_ty(SmallVector<Type>(values.size(), values[0].getType()));
+
+    return getStructFromElements(loc, values, rewriter, structTy);
   }
 
   Value loadB(Value tensor, Value llTensor, Value threadId, Location loc,
               Value smem, ConversionPatternRewriter &rewriter) const {
-    return Value{};
+    auto *ctx = rewriter.getContext();
+    auto tensorTy = tensor.getType().cast<RankedTensorType>();
+    auto bShape = tensorTy.getShape();
+    auto bLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
+    auto bOrder = bLayout.getOrder();
+
+    bool isBRow = bOrder[0] == 1;
+
+    int strideBN = isBRow ? 1 : bShape[0];
+    int strideBK = isBRow ? bShape[1] : 1;
+    int strideB0 = isBRow ? strideBN : strideBK;
+    int strideB1 = isBRow ? strideBK : strideBN;
+    int ldb = isBRow ? strideBK : strideBN;
+    int bPerPhase = bLayout.getPerPhase();
+    int bMaxPhase = bLayout.getMaxPhase();
+    int bNumPtr = 8;
+    int bVec = 4;
+
+    auto bShapePerCTA = getShapePerCTA(bLayout);
+    auto sizePerThread = getSizePerThread(bLayout);
+
+    Value _0 = i32_val(0);
+    Value _1 = i32_val(1);
+
+    Value mContig = _1;
+    Value nContig = _1;
+
+    // Each thread loads a different column of B: offset threadId along N.
+    Value offB0 = isBRow ? mul(threadId, nContig) : _0;
+    Value offB1 = isBRow ? _0 : mul(threadId, nContig);
+    SmallVector<Value> bOff(bNumPtr);
+    for (int i = 0; i < bNumPtr; ++i) {
+      bOff[i] =
+          add(mul(offB0, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
+    }
+
+    Type f32PtrTy = ptr_ty(f32_ty);
+    SmallVector<Value> bPtrs(bNumPtr);
+    for (int i = 0; i < bNumPtr; ++i)
+      bPtrs[i] = gep(f32PtrTy, llTensor, bOff[i]);
+
+    ValueTable hbs;
+
+    int K = isBRow ? bShape[0] : bShape[1];
+    int N = isBRow ? bShape[1] : bShape[0];
+
+    for (int k = 0; k < K; ++k)
+      for (unsigned n = 0; n < N; n += bShapePerCTA[bOrder[0]])
+        for (unsigned nn = 0; nn < sizePerThread[bOrder[0]]; ++nn) {
+          Value pb = gep(f32PtrTy, bPtrs[0],
+                         i32_val((n + nn) * strideBN + k * strideBK));
+          Value vb = load(pb);
+          hbs[{n + nn, k}] = vb;
+        }
+
+    // Pack the B values the same way as loadA: one struct per operand.
+    SmallVector<Value> values;
+    for (auto &item : hbs)
+      values.push_back(item.second);
+    Type structTy =
+        struct_ty(SmallVector<Type>(values.size(), values[0].getType()));
+
+    return getStructFromElements(loc, values, rewriter, structTy);
   }
 
   ValueTable extractLoadedOperand(Value llTensor) const { return ValueTable{}; }
@@ -3738,18 +3826,15 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
                                          rewriter, getTypeConverter(),
                                          op.getLoc());
-    if (dotOperandLayout.getOpIdx() == 0) {
-      // operand $a
+    if (dotOperandLayout.getOpIdx() == 0) { // operand $a
       res = mmaHelper.loadA(src, adaptor.src());
-    } else if (dotOperandLayout.getOpIdx() == 1) {
-      // operand $b
+    } else if (dotOperandLayout.getOpIdx() == 1) { // operand $b
       res = mmaHelper.loadB(src, adaptor.src());
     }
   } else if (!isOuter && mmaLayout.getVersion() == 1 &&
              isHMMA) { // tensor core v1
     DotOpMmaV1ConversionHelper helper(mmaLayout);
-    if (dotOperandLayout.getOpIdx() == 0) {
-      // operand $a
+    if (dotOperandLayout.getOpIdx() == 0) { // operand $a
       res = helper.loadA(src, adaptor.src(), getThreadId(rewriter, loc),
                          adaptor.src(), loc, rewriter);
     } else if (dotOperandLayout.getOpIdx() == 1) {
@@ -3758,7 +3843,14 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
                          adaptor.src(), loc, rewriter);
     }
   } else if (DotOpFMAConversionHelper::useFMA(dstTensorTy)) { // fmadot
-
+    DotOpMmaV1ConversionHelper helper(mmaLayout);
+    // Note: loc precedes smem here, matching the new loadA/loadB signatures.
+    if (dotOperandLayout.getOpIdx() == 0) { // operand $a
+      res = helper.loadA(src, adaptor.src(), getThreadId(rewriter, loc), loc,
+                         adaptor.src(), rewriter);
+    } else if (dotOperandLayout.getOpIdx() == 1) { // operand $b
+      res = helper.loadB(src, adaptor.src(), getThreadId(rewriter, loc), loc,
+                         adaptor.src(), rewriter);
+    }
   } else {
     assert(false && "Unsupported mma layout found");
   }
@@ -4245,26 +4337,20 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
 
   auto aLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
   auto bLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
-  auto cLayout = cTensorTy.getEncoding().cast<MmaEncodingAttr>();
-  auto dLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
+  auto cLayout = cTensorTy.getEncoding().cast<BlockedEncodingAttr>();
+  auto dLayout = dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
   auto aOrder = aLayout.getOrder();
   auto bOrder = bLayout.getOrder();
 
-  // According to the original logic, if target.sm < 80, get a {0,1} or get a
-  // {1,0}
-  SmallVector<unsigned> order(2);
-  if (dLayout.getVersion() == 1)
-    order = {0, 1};
-  else
-    order = {1, 0};
+  auto order = dLayout.getOrder();
 
   bool isARow = aOrder[0] == 1;
   bool isBRow = bOrder[0] == 1;
 
   int strideAM = isARow ? aShape[1] : 1;
   int strideAK = isARow ? 1 : aShape[0];
-  int strideBN = isBRow ? 1 : aShape[0];
+  int strideBN = isBRow ? 1 : bShape[0];
   int strideBK = isBRow ? bShape[1] : 1;
   int strideA0 = isARow ? strideAK : strideAM;
   int strideA1 = isARow ? strideAM : strideAK;
@@ -4315,9 +4401,9 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
     bPtrs[i] = gep(f32PtrTy, adaptor.b(), bOff[i]);
 
-  // TODO initialize ret with $c.
-  std::map<std::pair<unsigned, unsigned>, Value> has, hbs;
+  DotOpFMAConversionHelper::ValueTable has, hbs;
   auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
-  SmallVector<Value> ret(cShape[0] * cShape[1], cc[0]);
+  // Initialize the accumulator tile directly from the $c operand.
+  SmallVector<Value> ret = cc;
 
   for (unsigned k = 0; k < NK; k++) {
     int z = 0;
@@ -4982,8 +5068,8 @@ void ConvertTritonGPUToLLVM::initSharedMemory(
   OpBuilder b(mod.getBodyRegion());
   auto loc = mod.getLoc();
   auto elemTy = typeConverter.convertType(b.getIntegerType(8));
-  // Set array size 0 and external linkage indicates that we use dynamic shared
-  // allocation to allow a larger shared memory size for each kernel.
+  // An array size of 0 with external linkage indicates that we use dynamic
+  // shared allocation, allowing a larger shared memory size for each kernel.
   auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0);
   auto global = b.create<LLVM::GlobalOp>(
       loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External,
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 96b0925bc..408cf3d29 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -117,10 +117,12 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
                  "BlockedEncodingAttr not implemented");
     }
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
-    assert(mmaLayout.getVersion() == 2 &&
-           "mmaLayout version = 1 is not implemented yet");
-    return {16 * mmaLayout.getWarpsPerCTA()[0],
-            8 * mmaLayout.getWarpsPerCTA()[1]};
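+    // A warp covers a 16x8 tile of D per mma on v2 (Ampere m16n8) and an
+    // effective 16x16 tile on v1 (Volta), hence the two shapes below.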
+    if (mmaLayout.getVersion() == 2)
+      return {16 * mmaLayout.getWarpsPerCTA()[0],
+              8 * mmaLayout.getWarpsPerCTA()[1]};
+    if (mmaLayout.getVersion() == 1)
+      return {16 * mmaLayout.getWarpsPerCTA()[0],
+              16 * mmaLayout.getWarpsPerCTA()[1]};
+    assert(0 && "Unexpected MMA layout version found");
   } else {
     assert(0 && "Unimplemented usage of getShapePerCTA");
   }
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index eae06074a..18c1c034b 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -797,11 +797,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
-  func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
+  func @matmul_fmadot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
   %a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
-    %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma>
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
     // CHECK: llvm.intr.fmuladd
-    %28 = tt.dot %a, %b, %cst {allowTF32 = false, transA = false, transB = false} : tensor<32x16xf32, #shared> * tensor<16x32xf32, #shared> -> tensor<16x16xf32, #mma>
+    %28 = tt.dot %a, %b, %cst {allowTF32 = false, transA = false, transB = false} : tensor<32x16xf32, #shared> * tensor<16x32xf32, #shared> -> tensor<32x32xf32, #blocked>
+    %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
+    %36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
+    tt.store %36, %28 : tensor<32x32xf32, #blocked>
     return
   }
 }
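
Note: the FMA path above unrolls, per thread, a plain multiply-accumulate over
the K dimension: each thread produces its sizePerThread tile of D with one
llvm.intr.fmuladd per (m, n, k) triple. A rough standalone sketch of the
computation being emitted (illustrative C++ only; fmaDotTile, tileM, and tileN
are made-up names, not identifiers from this patch):

    #include <cmath>

    // Each thread owns a tileM x tileN block of D (row-major here) and walks
    // the K dimension, issuing one fused multiply-add per (m, n, k) triple --
    // the scalar analogue of the llvm.intr.fmuladd calls checked in the test.
    void fmaDotTile(const float *a, const float *b, float *d,
                    int tileM, int tileN, int K) {
      for (int k = 0; k < K; ++k)
        for (int m = 0; m < tileM; ++m)
          for (int n = 0; n < tileN; ++n)
            d[m * tileN + n] =
                std::fma(a[m * K + k], b[k * tileN + n], d[m * tileN + n]);
    }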