From da2993e1c75f930653cef5489e8b7bbd7dcb6ef3 Mon Sep 17 00:00:00 2001
From: Superjomn
Date: Wed, 2 Nov 2022 18:02:49 +0800
Subject: [PATCH] init code

---
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp       | 135 +++++++++++++++++-
 1 file changed, 131 insertions(+), 4 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index a8da162a3..08b547468 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -2836,10 +2836,7 @@ private:
                               ConversionPatternRewriter &rewriter) const;
 
   LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adaptor,
-                              ConversionPatternRewriter &rewriter) const {
-    assert(false && "Not implemented yet.");
-    return failure();
-  }
+                              ConversionPatternRewriter &rewriter) const;
 };
 
 // Helper for conversion of DotOp with mma, that is sm<80
@@ -4095,6 +4092,136 @@ DotOpMmaV1ConversionHelper::extractLoadedOperand(
   return rcds;
 }
 
+LogicalResult
+DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
+                               ConversionPatternRewriter &rewriter) const {
+  auto *ctx = rewriter.getContext();
+  auto loc = op.getLoc();
+  auto threadId = getThreadId(rewriter, loc);
+
+  auto A = op.a();
+  auto B = op.b();
+  auto C = op.c();
+  auto D = op.getResult();
+
+  auto aTensorTy = A.getType().cast<RankedTensorType>();
+  auto bTensorTy = B.getType().cast<RankedTensorType>();
+  auto cTensorTy = C.getType().cast<RankedTensorType>();
+  auto dTensorTy = D.getType().cast<RankedTensorType>();
+
+  auto aShape = aTensorTy.getShape();
+  auto bShape = bTensorTy.getShape();
+  auto cShape = cTensorTy.getShape();
+
+  auto aLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
+  auto bLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
+  auto cLayout = cTensorTy.getEncoding().cast<BlockedEncodingAttr>();
+  auto dLayout = dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
+
+  auto aOrder = aLayout.getOrder();
+  auto bOrder = bLayout.getOrder();
+
+  // Following the original logic: if target.sm < 80 (mma v1), the order is
+  // {0, 1}; otherwise it is {1, 0}.
+  SmallVector<unsigned> order(2);
+  if (dLayout.getVersion() == 1)
+    order = {0, 1};
+  else
+    order = {1, 0};
+
+  bool isARow = aOrder[0] == 1;
+  bool isBRow = bOrder[0] == 1;
+
+  int strideAM = isARow ? aShape[1] : 1;
+  int strideAK = isARow ? 1 : aShape[0];
+  int strideBN = isBRow ? 1 : bShape[0];
+  int strideBK = isBRow ? bShape[1] : 1;
+  int strideA0 = isARow ? strideAK : strideAM;
+  int strideA1 = isARow ? strideAM : strideAK;
+  int strideB0 = isBRow ? strideBN : strideBK;
+  int strideB1 = isBRow ? strideBK : strideBN;
+  int lda = isARow ? strideAM : strideAK;
+  int ldb = isBRow ? strideBK : strideBN;
+  int aPerPhase = aLayout.getPerPhase();
+  int aMaxPhase = aLayout.getMaxPhase();
+  int bPerPhase = bLayout.getPerPhase();
+  int bMaxPhase = bLayout.getMaxPhase();
+  int aNumPtr = 8;
+  int bNumPtr = 8;
+  int aVec = 2;
+  int bVec = 4;
+  int NK = aShape[isARow ? 1 : 0];
+
+  auto cShapePerCTA = getShapePerCTA(cLayout);
+  auto sizePerThread = getSizePerThread(dLayout);
+
+  Value _0 = i32_val(0);
+  Value _1 = i32_val(1);
+
+  Value mContig = _1;
+  Value nContig = _1;
+
+  // Per-thread offsets into the A operand.
+  Value offA0 = isARow ? _0 : mul(threadId, mContig);
+  Value offA1 = isARow ? mul(threadId, mContig) : _0;
+  SmallVector<Value> aOff(aNumPtr);
+  for (int i = 0; i < aNumPtr; ++i) {
+    aOff[i] = add(mul(offA0, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
+  }
+
+  Value offB0 = isBRow ? mul(threadId, nContig) : _0;
+  Value offB1 = isBRow ? _0 : mul(threadId, nContig);
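+  // B-side per-thread offsets, mirroring the A case above. As with aOff, the
+  // loop below does not yet depend on the index i, so all bNumPtr entries of
+  // bOff hold the same offset at this stage of the patch.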
+  SmallVector<Value> bOff(bNumPtr);
+  for (int i = 0; i < bNumPtr; ++i) {
+    bOff[i] = add(mul(offB0, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
+  }
+
+  Type f32PtrTy = ptr_ty(f32_ty);
+  SmallVector<Value> aPtrs(aNumPtr);
+  for (int i = 0; i < aNumPtr; ++i)
+    aPtrs[i] = gep(f32PtrTy, adaptor.a(), aOff[i]);
+
+  SmallVector<Value> bPtrs(bNumPtr);
+  for (int i = 0; i < bNumPtr; ++i)
+    bPtrs[i] = gep(f32PtrTy, adaptor.b(), bOff[i]);
+
+  // has/hbs cache the A and B elements already loaded, so each (row, k) and
+  // (col, k) element is loaded only once across the accumulation loops.
+  std::map<std::pair<int, int>, Value> has, hbs;
+  // TODO: initialize ret with zeros.
+  SmallVector<Value> ret(NK);
+  for (unsigned k = 0; k < NK; k++) {
+    int z = 0;
+    for (unsigned i = 0; i < cShape[order[1]]; i += cShapePerCTA[order[1]])
+      for (unsigned j = 0; j < cShape[order[0]]; j += cShapePerCTA[order[0]])
+        for (unsigned ii = 0; ii < sizePerThread[order[1]]; ++ii)
+          for (unsigned jj = 0; jj < sizePerThread[order[0]]; ++jj) {
+            unsigned m = order[0] == 1 ? i : j;
+            unsigned n = order[0] == 1 ? j : i;
+            unsigned mm = order[0] == 1 ? ii : jj;
+            unsigned nn = order[0] == 1 ? jj : ii;
+            if (!has.count({m + mm, k})) {
+              Value pa = gep(f32PtrTy, aPtrs[0],
+                             i32_val((m + mm) * strideAM + k * strideAK));
+              Value va = load(pa);
+              has[{m + mm, k}] = va;
+            }
+            if (!hbs.count({n + nn, k})) {
+              Value pb = gep(f32PtrTy, bPtrs[0],
+                             i32_val((n + nn) * strideBN + k * strideBK));
+              Value vb = load(pb);
+              hbs[{n + nn, k}] = vb;
+            }
+            ret[z] = rewriter.create<LLVM::FMulAddOp>(
+                loc, has[{m + mm, k}], hbs[{n + nn, k}], ret[z]);
+            ++z;
+          }
+  }
+
+  auto res = getStructFromElements(
+      loc, ret, rewriter,
+      struct_ty(SmallVector<Type>(ret.size(), ret[0].getType())));
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
+
 /// ====================== mma codegen end ============================
 
 Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,