From e41978197874358abe12e1ac10b76421749492fe Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Tue, 6 Dec 2022 10:57:08 +0800 Subject: [PATCH] [Triton-MLIR][BACKEND] Make mmav1 works on basic cases (#944) TODO: - Add more cases - Currently, we just set vec to 4 to make the basic cases pass Issue: - the vec in shared layout is different compared to master branch - when vec=1, it encounters CUDA misalignment error, it doesn't work in master branch as well - when setting vec to the value identical to master branch, the MMA works --- .github/workflows/integration-tests.yml | 11 +- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 20 ++- lib/Conversion/TritonGPUToLLVM/DotHelpers.h | 138 ++++++++---------- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 37 +++-- lib/Conversion/TritonGPUToLLVM/Utility.h | 2 + lib/Dialect/TritonGPU/Transforms/Combine.cpp | 1 + python/src/triton.cc | 5 + python/tests/test_gemm.py | 20 ++- 8 files changed, 134 insertions(+), 100 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 370575bbb..d54e7061e 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -17,7 +17,7 @@ jobs: id: set-matrix run: | if [ x"${{ github.repository }}" == x"openai/triton" ]; then - echo '::set-output name=matrix::[["self-hosted", "A10"], "macos-10.15"]' + echo '::set-output name=matrix::[["self-hosted", "A10"], ["self-hosted", "V100"], "macos-10.15"]' else echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]' fi @@ -79,11 +79,18 @@ jobs: lit -v "$LIT_TEST_DIR" - name: Run python tests - if: ${{matrix.runner[0] == 'self-hosted'}} + if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'A10'}} run: | cd python/tests pytest + # TODO[Superjomn] Enable all the tests on V100 if available + - name: Run python tests on V100 + if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'V100'}} + run: | + cd python/tests + pytest test_gemm.py::test_gemm_no_scf_for_mmav1 + - name: Run CXX unittests run: | cd python/ diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 7a0cd4324..d4ff8021d 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -87,20 +87,24 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] / // number of rows per phase int perPhase = 128 / (shape[order[0]] * (eltTy.getIntOrFloatBitWidth() / 8)); perPhase = std::max(perPhase, 1); - + // index of the inner dimension in `order` unsigned inner = (opIdx == 0) ? 0 : 1; // ---- begin version 1 ---- if (version == 1) { bool is_row = order[0] != 0; - bool is_vec4 = opIdx == 0 ? is_row && (shape[order[0]] <= 16) : - !is_row && (shape[order[0]] <= 16); + bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) : + is_row && (shape[order[0]] <= 16); + // TODO[Superjomn]: Support the case when is_vec4=false later + // Currently, we only support ld.v2, for the mma layout varies with different ld vector width. + is_vec4 = true; int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) : ((is_row && !is_vec4) ? 2 : 1); int rep = 2 * pack_size; int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase; - return $_get(context, 2 * rep, perPhase, maxPhase, order); + int vec = 2 * rep; + return $_get(context, vec, perPhase, maxPhase, order); } // ---- begin version 2 ---- @@ -110,14 +114,14 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... 
[phase 1] / // for now, disable swizzle when using transposed int8 tensor cores if (eltTy.isInteger(8) && order[0] == inner) return $_get(context, 1, 1, 1, order); - + // --- handle A operand --- if (opIdx == 0) { // compute swizzling for A operand int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2]; int maxPhase = mmaStride / perPhase; return $_get(context, vec, perPhase, maxPhase, order); - } + } // --- handle B operand --- if (opIdx == 1) { @@ -125,8 +129,8 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] / int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1]; int maxPhase = mmaStride / perPhase; return $_get(context, vec, perPhase, maxPhase, order); - } - + } + llvm_unreachable("invalid operand index"); } diff --git a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h index d354f0f4d..e232d6fa8 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h +++ b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h @@ -39,15 +39,6 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::MmaEncodingAttr; using ::mlir::triton::gpu::SharedEncodingAttr; -// Forward declaration necessary functions locates in TritonGPUToLLVM.cpp . -llvm::SmallVector -getElementsFromStruct(mlir::Location loc, mlir::Value llvmStruct, - mlir::ConversionPatternRewriter &rewriter); - -mlir::LLVM::SharedMemoryObject -getSharedMemoryObjectFromStruct(mlir::Location loc, mlir::Value llvmStruct, - mlir::ConversionPatternRewriter &rewriter); - // Helper for conversion of DotOp with mma, that is sm<80 struct DotOpMmaV1ConversionHelper { MmaEncodingAttr mmaLayout; @@ -710,17 +701,13 @@ public: if (kOrder == 1) { elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal)); elems[1] = load(gep(elemPtrTy, ptr2, sOffsetElemVal)); - elems[2] = - load(gep(elemPtrTy, ptr, sOffsetArrElemVal)); - elems[3] = - load(gep(elemPtrTy, ptr2, sOffsetArrElemVal)); + elems[2] = load(gep(elemPtrTy, ptr, sOffsetArrElemVal)); + elems[3] = load(gep(elemPtrTy, ptr2, sOffsetArrElemVal)); } else { elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal)); elems[2] = load(gep(elemPtrTy, ptr2, sOffsetElemVal)); - elems[1] = - load(gep(elemPtrTy, ptr, sOffsetArrElemVal)); - elems[3] = - load(gep(elemPtrTy, ptr2, sOffsetArrElemVal)); + elems[1] = load(gep(elemPtrTy, ptr, sOffsetArrElemVal)); + elems[3] = load(gep(elemPtrTy, ptr2, sOffsetArrElemVal)); } return {elems[0], elems[1], elems[2], elems[3]}; @@ -952,7 +939,6 @@ struct MMA16816ConversionHelper { // Loading $a from smem to registers, returns a LLVM::Struct. 
Value loadA(Value tensor, const SharedMemoryObject &smemObj) const { auto aTensorTy = tensor.getType().cast(); - auto layout = aTensorTy.getEncoding().cast(); SmallVector shape(aTensorTy.getShape().begin(), aTensorTy.getShape().end()); @@ -973,12 +959,13 @@ struct MMA16816ConversionHelper { if (aTensorTy.getEncoding().isa()) { Value warpM = getWarpM(shape[0]); // load from smem - int wpt = std::min(mmaLayout.getWarpsPerCTA()[0], shape[0] / matShapeM); - loadFn = getLoadMatrixFn( - tensor, smemObj, mmaLayout, wpt /*wpt*/, - 1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/, - {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/, - true /*isA*/); + int wpt = + std::min(mmaLayout.getWarpsPerCTA()[0], shape[0] / matShapeM); + loadFn = + getLoadMatrixFn(tensor, smemObj, mmaLayout, wpt /*wpt*/, 1 /*kOrder*/, + {mmaInstrM, mmaInstrK} /*instrShape*/, + {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, + ha /*vals*/, true /*isA*/); } else if (aTensorTy.getEncoding().isa()) { // load from registers, used in gemm fuse // TODO(Superjomn) Port the logic. @@ -1000,7 +987,6 @@ struct MMA16816ConversionHelper { Value loadB(Value tensor, const SharedMemoryObject &smemObj) { ValueTable hb; auto tensorTy = tensor.getType().cast(); - auto layout = tensorTy.getEncoding().cast(); SmallVector shape(tensorTy.getShape().begin(), tensorTy.getShape().end()); @@ -1017,12 +1003,13 @@ struct MMA16816ConversionHelper { int numRepN = getNumRepN(tensorTy, shape[1]); Value warpN = getWarpN(shape[1]); - int wpt = std::min(mmaLayout.getWarpsPerCTA()[1], shape[1] / matShapeN); - auto loadFn = getLoadMatrixFn( - tensor, smemObj, mmaLayout, wpt /*wpt*/, - 0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/, - {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/, - false /*isA*/); + int wpt = + std::min(mmaLayout.getWarpsPerCTA()[1], shape[1] / matShapeN); + auto loadFn = + getLoadMatrixFn(tensor, smemObj, mmaLayout, wpt /*wpt*/, 0 /*kOrder*/, + {mmaInstrK, mmaInstrN} /*instrShape*/, + {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, + hb /*vals*/, false /*isA*/); for (int n = 0; n < std::max(numRepN / 2, 1); ++n) { for (int k = 0; k < numRepK; ++k) @@ -1167,6 +1154,7 @@ private: SmallVector ptrs(numPtrs); Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter); + Type smemPtrTy = helper.getShemPtrTy(); for (int i = 0; i < numPtrs; ++i) { ptrs[i] = @@ -1292,7 +1280,6 @@ struct DotOpFMAConversionHelper { auto blockedLayout = dotOpLayout.getParent().cast(); auto shapePerCTA = getShapePerCTA(blockedLayout); auto sizePerThread = getSizePerThread(blockedLayout); - auto order = blockedLayout.getOrder(); // TODO[Superjomn]: we assume the k aixs is fixed for $a and $b here, fix it // if not. @@ -1342,17 +1329,15 @@ Value DotOpMmaV1ConversionHelper::loadA( SmallVector order(sharedLayout.getOrder().begin(), sharedLayout.getOrder().end()); - // TODO [Superjomn]: transA cannot be accessed in ConvertLayoutOp. - bool transA = false; - if (transA) { - std::swap(shape[0], shape[1]); - std::swap(order[0], order[1]); - } - Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); + Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter); bool isARow = order[0] != 0; bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes + // TODO[Superjomn]: Support the case when isAVec4=false later + // Currently, we only support ld.v2, for the mma layout varies with different + // ld vector width. + isAVec4 = true; int packSize0 = (isARow || isAVec4) ? 
1 : 2; SmallVector fpw({2, 2, 1}); @@ -1362,6 +1347,16 @@ Value DotOpMmaV1ConversionHelper::loadA( SmallVector rep({repM, 0, repK}); // pad N with 0 SmallVector spw({spwM, 0, 1}); // pad N with 0 + auto [offsetAM, offsetAK, _0, _1] = + computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc); + // TODO [Superjomn]: transA cannot be accessed in ConvertLayoutOp. + bool transA = false; + if (transA) { + std::swap(shape[0], shape[1]); + std::swap(offsetAM, offsetAK); + std::swap(order[0], order[1]); + } + int vecA = sharedLayout.getVec(); auto strides = smemObj.strides; @@ -1373,9 +1368,6 @@ Value DotOpMmaV1ConversionHelper::loadA( int strideRepM = wpt[0] * fpw[0] * 8; int strideRepK = 1; - auto [offsetAM, offsetAK, _0, _1] = - computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc); - // swizzling int perPhaseA = sharedLayout.getPerPhase(); int maxPhaseA = sharedLayout.getMaxPhase(); @@ -1398,19 +1390,14 @@ Value DotOpMmaV1ConversionHelper::loadA( } Type f16x2Ty = vec_ty(f16_ty, 2); - // One thread get 8 elements as result - Type retTy = - LLVM::LLVMStructType::getLiteral(ctx, SmallVector(8, type::f32Ty(ctx))); // prepare arguments SmallVector ptrA(numPtrA); std::map, std::pair> has; - auto smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter); for (int i = 0; i < numPtrA; i++) - ptrA[i] = gep(ptr_ty(f16_ty), smem, offA[i]); + ptrA[i] = gep(ptr_ty(f16_ty), smemBase, offA[i]); - auto instrShape = getMmaInstrShape(); unsigned numM = std::max(rep[0] * shape[0] / (spw[0] * wpt[0]), 1); Type f16PtrTy = ptr_ty(f16_ty); @@ -1420,7 +1407,7 @@ Value DotOpMmaV1ConversionHelper::loadA( }; auto loadA = [&](int m, int k) { int offidx = (isARow ? k / 4 : m) % numPtrA; - Value thePtrA = gep(f16PtrTy, smem, offA[offidx]); + Value thePtrA = gep(f16PtrTy, smemBase, offA[offidx]); int stepAM = isARow ? m : m / numPtrA * numPtrA; int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k; @@ -1446,12 +1433,10 @@ Value DotOpMmaV1ConversionHelper::loadA( for (unsigned k = 0; k < NK; k += 4) for (unsigned m = 0; m < numM / 2; ++m) - if (!has.count({m, k})) - loadA(m, k); + loadA(m, k); SmallVector elems; elems.reserve(has.size() * 2); - auto vecTy = vec_ty(f16_ty, 2); for (auto item : has) { // has is a map, the key should be ordered. elems.push_back(item.second.first); elems.push_back(item.second.second); @@ -1466,7 +1451,6 @@ Value DotOpMmaV1ConversionHelper::loadB( Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc, ConversionPatternRewriter &rewriter) const { // smem - Value smem = smemObj.base; auto strides = smemObj.strides; auto *ctx = rewriter.getContext(); @@ -1478,21 +1462,20 @@ Value DotOpMmaV1ConversionHelper::loadB( SmallVector order(sharedLayout.getOrder().begin(), sharedLayout.getOrder().end()); - // TODO [Superjomn]: transB cannot be accessed in ConvertLayoutOp. - bool transB = false; - - if (transB) { - std::swap(order[0], order[1]); - std::swap(shape[0], shape[1]); - } + Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter); bool isBRow = order[0] != 0; bool isBVec4 = isBRow && shape[order[0]] <= 16; + // TODO[Superjomn]: Support the case when isBVec4=false later + // Currently, we only support ld.v2, for the mma layout varies with different + // ld vector width. + isBVec4 = true; int packSize1 = (isBRow && !isBVec4) ? 
2 : 1; SmallVector fpw({2, 2, 1}); SmallVector rep({0, 2 * packSize1, 1}); // pad M with 0 SmallVector spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0 int vecB = sharedLayout.getVec(); + Value strideBN = isBRow ? i32_val(1) : strides[1]; Value strideBK = isBRow ? strides[0] : i32_val(1); Value strideB0 = isBRow ? strideBN : strideBK; @@ -1500,24 +1483,29 @@ Value DotOpMmaV1ConversionHelper::loadB( int strideRepN = wpt[1] * fpw[1] * 8; int strideRepK = 1; + // TODO [Superjomn]: transB cannot be accessed in ConvertLayoutOp. + bool transB = false; + + auto [_0, _1, offsetBN, offsetBK] = + computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc); + if (transB) { + std::swap(order[0], order[1]); + std::swap(shape[0], shape[1]); + std::swap(offsetBK, offsetBN); + } + // swizzling - int perPhaseA = sharedLayout.getPerPhase(); - int maxPhaseA = sharedLayout.getMaxPhase(); int perPhaseB = sharedLayout.getPerPhase(); int maxPhaseB = sharedLayout.getMaxPhase(); int stepB0 = isBRow ? strideRepN : strideRepK; int numPtrB = std::max(2 * perPhaseB * maxPhaseB / stepB0, 1); int NK = shape[0]; - auto [_0, _1, offsetBN, offsetBK] = - computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc); - if (transB) - std::swap(offsetBK, offsetBN); - Value offB0 = isBRow ? offsetBN : offsetBK; Value offB1 = isBRow ? offsetBK : offsetBN; Value phaseB = urem(udiv(offB1, i32_val(perPhaseB)), i32_val(maxPhaseB)); Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); + offB0 = add(offB0, cSwizzleOffset); SmallVector offB(numPtrB); for (int i = 0; i < numPtrB; ++i) { @@ -1549,6 +1537,7 @@ Value DotOpMmaV1ConversionHelper::loadB( Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN), mul(i32_val(stepBK), strideBK)); Value pb = gep(f16PtrTy, thePtrB, offset); + Value hb = load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3))); // record lds that needs to be moved @@ -1651,9 +1640,12 @@ DotOpMmaV1ConversionHelper::extractLoadedOperand( SmallVector elems = getElementsFromStruct(llStruct.getLoc(), llStruct, rewriter); - for (int k = 0, offset = 0, i = 0; k < NK && offset < elems.size(); - k += 4, i++, offset += 2) { - rcds[{i, k}] = std::make_pair(elems[offset], elems[offset + 1]); + int offset = 0; + for (int i = 0; offset < elems.size(); ++i) { + for (int k = 0; k < NK; k += 4) { + rcds[{i, k}] = std::make_pair(elems[offset], elems[offset + 1]); + offset += 2; + } } return rcds; @@ -1675,9 +1667,7 @@ Value DotOpFMAConversionHelper::loadA( int strideAK = isARow ? 1 : aShape[0]; int strideA0 = isARow ? strideAK : strideAM; int strideA1 = isARow ? strideAM : strideAK; - int lda = isARow ? strideAM : strideAK; int aNumPtr = 8; - int bNumPtr = 8; int NK = aShape[1]; auto shapePerCTA = getShapePerCTA(dLayout); @@ -1686,13 +1676,11 @@ Value DotOpFMAConversionHelper::loadA( Value _0 = i32_val(0); Value mContig = i32_val(sizePerThread[order[1]]); - Value nContig = i32_val(sizePerThread[order[0]]); // threadId in blocked layout auto threadIds = getThreadIds(thread, shapePerCTA, order, rewriter, loc); Value threadIdM = threadIds[0]; - Value threadIdN = threadIds[1]; Value offA0 = isARow ? _0 : mul(threadIdM, mContig); Value offA1 = isARow ? mul(threadIdM, mContig) : _0; @@ -1745,7 +1733,6 @@ Value DotOpFMAConversionHelper::loadB( int strideBK = isBRow ? bShape[1] : 1; int strideB0 = isBRow ? strideBN : strideBK; int strideB1 = isBRow ? strideBK : strideBN; - int ldb = isBRow ? 
strideBK : strideBN; int bNumPtr = 8; int NK = bShape[0]; @@ -1754,7 +1741,6 @@ Value DotOpFMAConversionHelper::loadB( Value _0 = i32_val(0); - Value mContig = i32_val(sizePerThread[order[1]]); Value nContig = i32_val(sizePerThread[order[0]]); // threadId in blocked layout diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 19cfce7e0..2bfbbb090 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -62,12 +62,11 @@ namespace LLVM { static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; } // A helper function for using printf in LLVM conversion. -void llPrintf(StringRef msg, ValueRange args, - ConversionPatternRewriter &rewriter); +void vprintf(StringRef msg, ValueRange args, + ConversionPatternRewriter &rewriter); -// Helper function -#define tid_val() getThreadId(rewriter, loc) -#define llprintf(fmt, ...) LLVM::llPrintf(fmt, {__VA_ARGS__}, rewriter) +void vprintf_array(Value thread, ArrayRef arr, std::string info, + std::string elem_repr, ConversionPatternRewriter &builder); } // namespace LLVM } // namespace mlir @@ -3537,8 +3536,8 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor, SmallVector resVals(resSize); auto callMMA = [&](unsigned m, unsigned n, unsigned k) { - auto ha = has[{m, k}]; - auto hb = hbs[{n, k}]; + auto ha = has.at({m, k}); + auto hb = hbs.at({n, k}); std::vector idx{{ (m * 2 + 0) + (n * 4 + 0) * numM, // row0 (m * 2 + 0) + (n * 4 + 1) * numM, @@ -3554,13 +3553,13 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor, auto *resOprs = builder.newListOperand(8, "=f"); auto *AOprs = builder.newListOperand({ - {ha.first, "f"}, - {ha.second, "f"}, + {ha.first, "r"}, + {ha.second, "r"}, }); auto *BOprs = builder.newListOperand({ - {hb.first, "f"}, - {hb.second, "f"}, + {hb.first, "r"}, + {hb.second, "r"}, }); auto *COprs = builder.newListOperand(); for (int i = 0; i < 8; ++i) @@ -4806,11 +4805,23 @@ namespace mlir { namespace LLVM { -void llPrintf(StringRef msg, ValueRange args, - ConversionPatternRewriter &rewriter) { +void vprintf(StringRef msg, ValueRange args, + ConversionPatternRewriter &rewriter) { PrintfOpConversion::llPrintf(msg, args, rewriter); } +void vprintf_array(Value thread, ArrayRef arr, std::string info, + std::string elem_repr, ConversionPatternRewriter &builder) { + std::string fmt = info + " t-%d "; + std::vector new_arr({thread}); + for (int i = 0; i < arr.size(); ++i) { + fmt += elem_repr + ((i == arr.size() - 1) ? 
"" : ", "); + new_arr.push_back(arr[i]); + } + + vprintf(fmt, new_arr, builder); +} + } // namespace LLVM TritonLLVMConversionTarget::TritonLLVMConversionTarget( diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h index 323be4827..82ccc3fe6 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.h +++ b/lib/Conversion/TritonGPUToLLVM/Utility.h @@ -111,6 +111,8 @@ LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(), \ __VA_ARGS__) +#define tid_val() getThreadId(rewriter, loc) + namespace mlir { namespace LLVM { using namespace mlir::triton; diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 163d606e1..2366fe962 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -756,6 +756,7 @@ public: auto mod = op->getParentOfType(); int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); int version = computeCapabilityToMMAVersion(computeCapability); + auto newRetType = RankedTensorType::get( retShape, oldRetType.getElementType(), triton::gpu::MmaEncodingAttr::get( diff --git a/python/src/triton.cc b/python/src/triton.cc index f450a6ed3..95ba9409f 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1383,6 +1383,11 @@ void init_triton_translation(py::module &m) { llvm::SMDiagnostic error; std::unique_ptr module = llvm::parseIR(buffer->getMemBufferRef(), error, context); + if (!module) + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + // translate module to PTX auto ptxCode = triton::translateLLVMIRToPTX(*module, capability, version); diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py index 7c8c4226b..3555359dc 100644 --- a/python/tests/test_gemm.py +++ b/python/tests/test_gemm.py @@ -172,7 +172,7 @@ def get_proper_err(a, b, golden): [128, 64, 128, 4, 128, 64, 128, False, False], [16, 16, 16, 16, 16, 16, 16, False, False], # wpt overflow issue # K-Forloop - #[16, 16, 64, 4, 8, 8, 8, False, False], # Wrap threads + # [16, 16, 64, 4, 8, 8, 8, False, False], # Wrap threads [32, 32, 64, 4, 32, 32, 32, False, False], # Single shared encoding [16, 16, 128, 4, 16, 16, 16, False, False], # Single shared encoding and small k [64, 32, 128, 4, 64, 32, 64, False, False], @@ -292,3 +292,21 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32): torch.testing.assert_close(c, golden, rtol=max(1e-2, 1.5 * golden_rel_err), atol=max(1e-2, 1.5 * golden_abs_err)) else: torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err)) + + +# NOTE this is useful only on Volta GPU. +@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [ + (shape, num_warps, trans_a, trans_b) + for shape in [ + [16, 16, 16], + [16, 16, 32], + [32, 16, 16], + [32, 32, 32], + [128, 16, 16], + ] + for num_warps in [1] + for trans_a in [False] + for trans_b in [False] +]) +def test_gemm_no_scf_for_mmav1(SHAPE, NUM_WARPS, TRANS_A, TRANS_B): + test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B)