diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 33c7d889f..5985536eb 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -25,6 +25,10 @@ namespace gpu { unsigned getElemsPerThread(Type type); +SmallVector getThreadsPerWarp(Attribute layout); + +SmallVector getWarpsPerCTA(Attribute layout); + SmallVector getSizePerThread(Attribute layout); SmallVector getThreadsPerCTA(const Attribute &layout); diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 8216a6317..788522ba5 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -326,7 +326,8 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> { ); let extraClassDeclaration = extraBaseClassDeclaration # [{ - SmallVector paddedShape(ArrayRef shape) const; + template + SmallVector paddedShape(ArrayRef shape) const; }]; } diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index e70dc935e..19b791df1 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -87,22 +87,22 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, SmallVector getScratchConfigForReduce(triton::ReduceOp op) { auto srcTy = op.operand().getType().cast(); - auto srcLayout = srcTy.getEncoding().cast(); + auto srcLayout = srcTy.getEncoding(); auto srcShape = srcTy.getShape(); auto axis = op.axis(); - bool fastReduce = axis == srcLayout.getOrder()[0]; + bool fastReduce = axis == getOrder(srcLayout)[0]; SmallVector smemShape; for (auto d : srcShape) smemShape.push_back(d); if (fastReduce) { - unsigned sizeInterWarps = srcLayout.getWarpsPerCTA()[axis]; + unsigned sizeInterWarps = gpu::getWarpsPerCTA(srcLayout)[axis]; smemShape[axis] = sizeInterWarps; } else { - unsigned threadsPerCTAAxis = - srcLayout.getThreadsPerWarp()[axis] * srcLayout.getWarpsPerCTA()[axis]; + unsigned threadsPerCTAAxis = gpu::getThreadsPerWarp(srcLayout)[axis] * + gpu::getWarpsPerCTA(srcLayout)[axis]; smemShape[axis] = threadsPerCTAAxis; } @@ -161,16 +161,11 @@ private: // TODO(Keren): Reduce with index is not supported yet. 
auto value = op->getOperand(0); if (auto tensorType = value.getType().dyn_cast()) { - if (tensorType.getEncoding().isa()) { - auto smemShape = getScratchConfigForReduce(reduceOp); - unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), - 1, std::multiplies{}); - auto bytes = elems * tensorType.getElementTypeBitWidth() / 8; - allocation->addBuffer(op, bytes); - } else { - assert(0 && "ReduceOp with input layout other than blocked layout is " - "not implemented yet"); - } + auto smemShape = getScratchConfigForReduce(reduceOp); + unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, + std::multiplies{}); + auto bytes = elems * tensorType.getElementTypeBitWidth() / 8; + allocation->addBuffer(op, bytes); } } else if (auto cvtLayout = dyn_cast(op)) { auto srcTy = cvtLayout.src().getType().cast(); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index b34715c99..0e21b925d 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -339,7 +339,7 @@ Value getStructFromElements(Location loc, ValueRange resultVals, } Value llvmStruct = rewriter.create(loc, structType); - for (const auto& v : llvm::enumerate(resultVals)) { + for (const auto &v : llvm::enumerate(resultVals)) { assert(v.value() && "can not insert null values"); llvmStruct = insert_val(structType, llvmStruct, v.value(), rewriter.getI64ArrayAttr(v.index())); @@ -494,6 +494,10 @@ public: rewriter.getIntegerAttr(rewriter.getIndexType(), value)); } + // ----------------------------------------------------------------------- + // Utilities + // ----------------------------------------------------------------------- + // Convert an \param index to a multi-dim coordinate given \param shape and // \param order. SmallVector delinearize(ConversionPatternRewriter &rewriter, @@ -556,6 +560,10 @@ public: return ret; } + // ----------------------------------------------------------------------- + // Blocked layout indices + // ----------------------------------------------------------------------- + // Get an index-base for each dimension for a \param blocked_layout. 
SmallVector emitBaseIndexForBlockedLayout(Location loc, @@ -599,52 +607,6 @@ public: return multiDimBase; } - SmallVector> emitIndices(Location loc, - ConversionPatternRewriter &b, - const Attribute &layout, - ArrayRef shape) const { - if (auto blocked = layout.dyn_cast()) { - return emitIndicesForBlockedLayout(loc, b, blocked, shape); - } else if (auto slice = layout.dyn_cast()) { - return emitIndicesForSliceLayout(loc, b, slice, shape); - } else { - assert(0 && "emitIndices for layouts other than blocked & slice not " - "implemented yet"); - return {}; - } - } - - SmallVector> - emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter, - const SliceEncodingAttr &sliceLayout, - ArrayRef shape) const { - auto parent = sliceLayout.getParent(); - unsigned dim = sliceLayout.getDim(); - size_t rank = shape.size(); - if (auto blockedParent = parent.dyn_cast()) { - auto paddedIndices = emitIndicesForBlockedLayout( - loc, rewriter, blockedParent, sliceLayout.paddedShape(shape)); - unsigned numIndices = paddedIndices.size(); - SmallVector> resultIndices(numIndices); - for (unsigned i = 0; i < numIndices; ++i) - for (unsigned d = 0; d < rank + 1; ++d) - if (d != dim) - resultIndices[i].push_back(paddedIndices[i][d]); - - return resultIndices; - - } else if (auto sliceParent = parent.dyn_cast()) { - assert(0 && "emitIndicesForSliceLayout with parent of sliceLayout" - "is not implemented yet"); - return {}; - - } else { - assert(0 && "emitIndicesForSliceLayout with parent other than blocked & " - "slice not implemented yet"); - return {}; - } - } - SmallVector> emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout, ArrayRef shape) const { @@ -696,23 +658,109 @@ public: return reorderedOffset; } + // ----------------------------------------------------------------------- + // Mma layout indices + // ----------------------------------------------------------------------- + + SmallVector + emitBaseIndexForMmaLayoutV1(Location loc, ConversionPatternRewriter &rewriter, + const MmaEncodingAttr &mmaLayout, + ArrayRef shape) const { + llvm_unreachable("emitIndicesForMmaLayoutV1 not implemented"); + } + + SmallVector> + emitOffsetForMmaLayoutV1(const MmaEncodingAttr &mmaLayout, + ArrayRef shape) const { + llvm_unreachable("emitOffsetForMmaLayoutV1 not implemented"); + } + + SmallVector + emitBaseIndexForMmaLayoutV2(Location loc, ConversionPatternRewriter &rewriter, + const MmaEncodingAttr &mmaLayout, + ArrayRef shape) const { + auto _warpsPerCTA = mmaLayout.getWarpsPerCTA(); + assert(_warpsPerCTA.size() == 2); + SmallVector warpsPerCTA = {idx_val(_warpsPerCTA[0]), + idx_val(_warpsPerCTA[1])}; + Value threadId = getThreadId(rewriter, loc); + Value warpSize = idx_val(32); + Value laneId = urem(threadId, warpSize); + Value warpId = udiv(threadId, warpSize); + Value warpId0 = urem(warpId, warpsPerCTA[0]); + Value warpId1 = urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]); + Value offWarp0 = mul(warpId0, idx_val(16)); + Value offWarp1 = mul(warpId1, idx_val(8)); + + SmallVector multiDimBase(2); + multiDimBase[0] = add(udiv(laneId, idx_val(4)), offWarp0); + multiDimBase[1] = add(mul(idx_val(2), urem(laneId, idx_val(4))), offWarp1); + return multiDimBase; + } + + SmallVector> + emitOffsetForMmaLayoutV2(const MmaEncodingAttr &mmaLayout, + ArrayRef shape) const { + SmallVector> ret; + for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) { + for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) { + ret.push_back({i, j}); + ret.push_back({i, j + 1}); + 
} + for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) { + ret.push_back({i + 8, j}); + ret.push_back({i + 8, j + 1}); + } + } + return ret; + } + + // ----------------------------------------------------------------------- + // Get offsets / indices for any layout + // ----------------------------------------------------------------------- + + SmallVector emitBaseIndexForLayout(Location loc, + ConversionPatternRewriter &rewriter, + const Attribute &layout, + ArrayRef shape) const { + if (auto blockedLayout = layout.dyn_cast()) + return emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape); + if (auto mmaLayout = layout.dyn_cast()) { + if (mmaLayout.getVersion() == 1) + return emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape); + if (mmaLayout.getVersion() == 2) + return emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape); + } + llvm_unreachable("unsupported emitBaseIndexForLayout"); + } + + SmallVector> + emitOffsetForLayout(const Attribute &layout, ArrayRef shape) const { + if (auto blockedLayout = layout.dyn_cast()) + return emitOffsetForBlockedLayout(blockedLayout, shape); + if (auto mmaLayout = layout.dyn_cast()) { + if (mmaLayout.getVersion() == 1) + return emitOffsetForMmaLayoutV1(mmaLayout, shape); + if (mmaLayout.getVersion() == 2) + return emitOffsetForMmaLayoutV2(mmaLayout, shape); + } + llvm_unreachable("unsupported emitOffsetForLayout"); + } + // Emit indices calculation within each ConversionPattern, and returns a // [elemsPerThread X rank] index matrix. - // TODO: [goostavz] Double confirm the redundant indices calculations will - // be eliminated in the consequent MLIR/LLVM optimization. We might - // implement an indexCache if necessary. - SmallVector> - emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &rewriter, - const BlockedEncodingAttr &blockedLayout, - ArrayRef shape) const { + + // TODO: [phil] redundant index computations do not appear to hurt + // performance much, but they could still significantly slow down + // computations.
+ SmallVector> emitIndicesForDistributedLayout( + Location loc, ConversionPatternRewriter &rewriter, + const Attribute &layout, ArrayRef shape) const { + // step 1, delinearize threadId to get the base index - auto multiDimBase = - emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape); - + auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, shape); // step 2, get offset of each element - SmallVector> offset = - emitOffsetForBlockedLayout(blockedLayout, shape); - + auto offset = emitOffsetForLayout(layout, shape); // step 3, add offset to base, and reorder the sequence of indices to // guarantee that elems in the same sizePerThread are adjacent in order unsigned rank = shape.size(); @@ -726,6 +774,49 @@ public: return multiDimIdx; } + SmallVector> + emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter, + const SliceEncodingAttr &sliceLayout, + ArrayRef shape) const { + auto parent = sliceLayout.getParent(); + unsigned dim = sliceLayout.getDim(); + size_t rank = shape.size(); + auto paddedIndices = + emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape)); + unsigned numIndices = paddedIndices.size(); + SmallVector> resultIndices(numIndices); + for (unsigned i = 0; i < numIndices; ++i) + for (unsigned d = 0; d < rank + 1; ++d) + if (d != dim) + resultIndices[i].push_back(paddedIndices[i][d]); + + return resultIndices; + } + + // ----------------------------------------------------------------------- + // Emit indices + // ----------------------------------------------------------------------- + SmallVector> emitIndices(Location loc, + ConversionPatternRewriter &b, + const Attribute &layout, + ArrayRef shape) const { + if (auto blocked = layout.dyn_cast()) { + return emitIndicesForDistributedLayout(loc, b, blocked, shape); + } else if (auto mma = layout.dyn_cast()) { + return emitIndicesForDistributedLayout(loc, b, mma, shape); + } else if (auto slice = layout.dyn_cast()) { + return emitIndicesForSliceLayout(loc, b, slice, shape); + } else { + assert(0 && "emitIndices for layouts other than blocked & slice not " + "implemented yet"); + return {}; + } + } + + // ----------------------------------------------------------------------- + // Shared memory utilities + // ----------------------------------------------------------------------- + template Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter, T value) const { @@ -954,8 +1045,8 @@ struct LoadOpConversion // Determine the vectorization size Type valueTy = op.getResult().getType(); - Type valueElemTy = typeConverter->convertType( - getElementTypeOrSelf(valueTy)); + Type valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(valueTy)); unsigned vec = getVectorSize(ptr); unsigned numElems = getElemsPerThread(ptr.getType()); if (llMask) @@ -1086,7 +1177,8 @@ struct LoadOpConversion : retTys[0]; // TODO: if (has_l2_evict_policy) - // auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), + // auto asmDialectAttr = + // LLVM::AsmDialectAttr::get(rewriter.getContext(), // LLVM::AsmDialect::AD_ATT); Value ret = ptxBuilder.launch(rewriter, loc, retTy); @@ -1149,8 +1241,8 @@ struct StoreOpConversion MLIRContext *ctx = rewriter.getContext(); auto valueTy = value.getType(); - Type valueElemTy = typeConverter->convertType( - getElementTypeOrSelf(valueTy)); + Type valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(valueTy)); unsigned vec = getVectorSize(ptr); unsigned numElems = getElemsPerThread(ptr.getType()); @@ -1571,19 
+1663,19 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( unsigned axis = adaptor.axis(); auto srcTy = op.operand().getType().cast(); - auto srcLayout = srcTy.getEncoding().cast(); + auto srcLayout = srcTy.getEncoding(); auto srcShape = srcTy.getShape(); auto srcRank = srcTy.getRank(); - auto threadsPerWarp = srcLayout.getThreadsPerWarp(); - auto warpsPerCTA = srcLayout.getWarpsPerCTA(); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout); auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); smemBase = bitcast(smemBase, elemPtrTy); - auto order = srcLayout.getOrder(); + auto order = getOrder(srcLayout); unsigned sizeIntraWarps = threadsPerWarp[axis]; unsigned sizeInterWarps = warpsPerCTA[axis]; @@ -1592,7 +1684,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter); SmallVector> offset = - emitOffsetForBlockedLayout(srcLayout, srcShape); + emitOffsetForLayout(srcLayout, srcShape); std::map, Value> accs; std::map, SmallVector> indices; @@ -1655,7 +1747,8 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( // each thread needs to process: // elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads unsigned elems = product(smemShape); - unsigned numThreads = product(srcLayout.getWarpsPerCTA()) * 32; + unsigned numThreads = + product(triton::gpu::getWarpsPerCTA(srcLayout)) * 32; unsigned elemsPerThread = std::max(elems / numThreads, 1); Value readOffset = threadId; for (unsigned round = 0; round < elemsPerThread; ++round) { @@ -2104,9 +2197,10 @@ struct FpToFpOpConversion using ConvertTritonGPUOpToLLVMPattern< triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern; - static SmallVector convertFp8x4ToFp16x4( - Location loc, ConversionPatternRewriter &rewriter, - const Value& v0, const Value& v1, const Value& v2, const Value& v3) { + static SmallVector + convertFp8x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter, + const Value &v0, const Value &v1, const Value &v2, + const Value &v3) { auto ctx = rewriter.getContext(); auto fp8x4VecTy = vec_ty(i8_ty, 4); Value fp8x4Vec = undef(fp8x4VecTy); @@ -2117,18 +2211,17 @@ struct FpToFpOpConversion fp8x4Vec = bitcast(fp8x4Vec, i32_ty); PTXBuilder builder; - auto *ptxAsm = - "{ \n" - ".reg .b32 a<2>, b<2>; \n" - "prmt.b32 a0, 0, $2, 0x5040; \n" - "prmt.b32 a1, 0, $2, 0x7060; \n" - "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" - "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" - "shr.b32 b0, b0, 1; \n" - "shr.b32 b1, b1, 1; \n" - "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" - "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" - "}"; + auto *ptxAsm = "{ \n" + ".reg .b32 a<2>, b<2>; \n" + "prmt.b32 a0, 0, $2, 0x5040; \n" + "prmt.b32 a1, 0, $2, 0x7060; \n" + "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" + "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" + "shr.b32 b0, b0, 1; \n" + "shr.b32 b1, b1, 1; \n" + "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" + "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" + "}"; auto &call = *builder.create(ptxAsm); auto *o0 = builder.newOperand("=r"); @@ -2141,21 +2234,20 @@ struct FpToFpOpConversion struct_ty(SmallVector{fp16x2VecTy, fp16x2VecTy}); auto fp16x2x2Struct = builder.launch(rewriter, loc, fp16x2x2StructTy, false); - auto fp16x2Vec0 = extract_val(fp16x2VecTy, fp16x2x2Struct, - rewriter.getI32ArrayAttr({0})); - auto 
fp16x2Vec1 = extract_val(fp16x2VecTy, fp16x2x2Struct, - rewriter.getI32ArrayAttr({1})); - return { - extract_element(f16_ty, fp16x2Vec0, i32_val(0)), - extract_element(f16_ty, fp16x2Vec0, i32_val(1)), - extract_element(f16_ty, fp16x2Vec1, i32_val(0)), - extract_element(f16_ty, fp16x2Vec1, i32_val(1)) - }; + auto fp16x2Vec0 = + extract_val(fp16x2VecTy, fp16x2x2Struct, rewriter.getI32ArrayAttr({0})); + auto fp16x2Vec1 = + extract_val(fp16x2VecTy, fp16x2x2Struct, rewriter.getI32ArrayAttr({1})); + return {extract_element(f16_ty, fp16x2Vec0, i32_val(0)), + extract_element(f16_ty, fp16x2Vec0, i32_val(1)), + extract_element(f16_ty, fp16x2Vec1, i32_val(0)), + extract_element(f16_ty, fp16x2Vec1, i32_val(1))}; } - static SmallVector convertFp16x4ToFp8x4( - Location loc, ConversionPatternRewriter &rewriter, - const Value& v0, const Value& v1, const Value& v2, const Value& v3) { + static SmallVector + convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter, + const Value &v0, const Value &v1, const Value &v2, + const Value &v3) { auto ctx = rewriter.getContext(); auto fp16x2VecTy = vec_ty(f16_ty, 2); Value fp16x2Vec0 = undef(fp16x2VecTy); @@ -2168,19 +2260,18 @@ struct FpToFpOpConversion fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty); PTXBuilder builder; - auto *ptxAsm = - "{ \n" - ".reg .b32 a<2>, b<2>; \n" - "shl.b32 a0, $1, 1; \n" - "shl.b32 a1, $2, 1; \n" - "lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n" - "lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n" - "add.u32 a0, a0, 0x00800080; \n" - "add.u32 a1, a1, 0x00800080; \n" - "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n" - "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n" - "prmt.b32 $0, b0, b1, 0x7531; \n" - "}"; + auto *ptxAsm = "{ \n" + ".reg .b32 a<2>, b<2>; \n" + "shl.b32 a0, $1, 1; \n" + "shl.b32 a1, $2, 1; \n" + "lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n" + "lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n" + "add.u32 a0, a0, 0x00800080; \n" + "add.u32 a1, a1, 0x00800080; \n" + "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n" + "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n" + "prmt.b32 $0, b0, b1, 0x7531; \n" + "}"; auto &call = *builder.create(ptxAsm); auto *o = builder.newOperand("=r"); @@ -2190,17 +2281,16 @@ struct FpToFpOpConversion auto fp8x4VecTy = vec_ty(i8_ty, 4); auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false); - return { - extract_element(i8_ty, fp8x4Vec, i32_val(0)), - extract_element(i8_ty, fp8x4Vec, i32_val(1)), - extract_element(i8_ty, fp8x4Vec, i32_val(2)), - extract_element(i8_ty, fp8x4Vec, i32_val(3)) - }; + return {extract_element(i8_ty, fp8x4Vec, i32_val(0)), + extract_element(i8_ty, fp8x4Vec, i32_val(1)), + extract_element(i8_ty, fp8x4Vec, i32_val(2)), + extract_element(i8_ty, fp8x4Vec, i32_val(3))}; } - static SmallVector convertFp8x4ToBf16x4( - Location loc, ConversionPatternRewriter &rewriter, - const Value& v0, const Value& v1, const Value& v2, const Value& v3) { + static SmallVector + convertFp8x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter, + const Value &v0, const Value &v1, const Value &v2, + const Value &v3) { auto ctx = rewriter.getContext(); auto fp8x4VecTy = vec_ty(i8_ty, 4); Value fp8x4Vec = undef(fp8x4VecTy); @@ -2211,22 +2301,21 @@ struct FpToFpOpConversion fp8x4Vec = bitcast(fp8x4Vec, i32_ty); PTXBuilder builder; - auto *ptxAsm = - "{ \n" - ".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n" - "prmt.b32 a0, 0, $2, 0x5040; \n" - "prmt.b32 a1, 0, $2, 0x7060; \n" - "and.b32 sign0, a0, 0x80008000; \n" - "and.b32 sign1, a1, 0x80008000; \n" - "and.b32 nosign0, a0, 0x7fff7fff; \n" - "and.b32 nosign1, a1, 
0x7fff7fff; \n" - "shr.b32 nosign0, nosign0, 4; \n" - "shr.b32 nosign1, nosign1, 4; \n" - "add.u32 nosign0, nosign0, 0x38003800; \n" - "add.u32 nosign1, nosign1, 0x38003800; \n" - "or.b32 $0, sign0, nosign0; \n" - "or.b32 $1, sign1, nosign1; \n" - "}"; + auto *ptxAsm = "{ \n" + ".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n" + "prmt.b32 a0, 0, $2, 0x5040; \n" + "prmt.b32 a1, 0, $2, 0x7060; \n" + "and.b32 sign0, a0, 0x80008000; \n" + "and.b32 sign1, a1, 0x80008000; \n" + "and.b32 nosign0, a0, 0x7fff7fff; \n" + "and.b32 nosign1, a1, 0x7fff7fff; \n" + "shr.b32 nosign0, nosign0, 4; \n" + "shr.b32 nosign1, nosign1, 4; \n" + "add.u32 nosign0, nosign0, 0x38003800; \n" + "add.u32 nosign1, nosign1, 0x38003800; \n" + "or.b32 $0, sign0, nosign0; \n" + "or.b32 $1, sign1, nosign1; \n" + "}"; auto &call = *builder.create(ptxAsm); auto *o0 = builder.newOperand("=r"); @@ -2239,21 +2328,20 @@ struct FpToFpOpConversion struct_ty(SmallVector{bf16x2VecTy, bf16x2VecTy}); auto bf16x2x2Struct = builder.launch(rewriter, loc, bf16x2x2StructTy, false); - auto bf16x2Vec0 = extract_val(bf16x2VecTy, bf16x2x2Struct, - rewriter.getI32ArrayAttr({0})); - auto bf16x2Vec1 = extract_val(bf16x2VecTy, bf16x2x2Struct, - rewriter.getI32ArrayAttr({1})); - return { - extract_element(bf16_ty, bf16x2Vec0, i32_val(0)), - extract_element(bf16_ty, bf16x2Vec0, i32_val(1)), - extract_element(bf16_ty, bf16x2Vec1, i32_val(0)), - extract_element(bf16_ty, bf16x2Vec1, i32_val(1)) - }; + auto bf16x2Vec0 = + extract_val(bf16x2VecTy, bf16x2x2Struct, rewriter.getI32ArrayAttr({0})); + auto bf16x2Vec1 = + extract_val(bf16x2VecTy, bf16x2x2Struct, rewriter.getI32ArrayAttr({1})); + return {extract_element(bf16_ty, bf16x2Vec0, i32_val(0)), + extract_element(bf16_ty, bf16x2Vec0, i32_val(1)), + extract_element(bf16_ty, bf16x2Vec1, i32_val(0)), + extract_element(bf16_ty, bf16x2Vec1, i32_val(1))}; } - static SmallVector convertBf16x4ToFp8x4( - Location loc, ConversionPatternRewriter &rewriter, - const Value& v0, const Value& v1, const Value& v2, const Value& v3) { + static SmallVector + convertBf16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter, + const Value &v0, const Value &v1, const Value &v2, + const Value &v3) { auto ctx = rewriter.getContext(); auto bf16x2VecTy = vec_ty(bf16_ty, 2); Value bf16x2Vec0 = undef(bf16x2VecTy); @@ -2266,43 +2354,42 @@ struct FpToFpOpConversion bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty); PTXBuilder builder; - auto *ptxAsm = - "{ \n" - ".reg .u32 sign, sign<2>, nosign, nosign<2>; \n" - ".reg .u32 fp8_min, fp8_max, rn_, zero; \n" - "mov.u32 fp8_min, 0x38003800; \n" - "mov.u32 fp8_max, 0x3ff03ff0; \n" - "mov.u32 rn_, 0x80008; \n" - "mov.u32 zero, 0; \n" - "and.b32 sign0, $1, 0x80008000; \n" - "and.b32 sign1, $2, 0x80008000; \n" - "prmt.b32 sign, sign0, sign1, 0x7531; \n" - "and.b32 nosign0, $1, 0x7fff7fff; \n" - "and.b32 nosign1, $2, 0x7fff7fff; \n" - ".reg .u32 nosign_0_<2>, nosign_1_<2>; \n" - "and.b32 nosign_0_0, nosign0, 0xffff0000; \n" - "max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n" - "min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n" - "and.b32 nosign_0_1, nosign0, 0x0000ffff; \n" - "max.u32 nosign_0_1, nosign_0_1, 0x3800; \n" - "min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n" - "or.b32 nosign0, nosign_0_0, nosign_0_1; \n" - "and.b32 nosign_1_0, nosign1, 0xffff0000; \n" - "max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n" - "min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n" - "and.b32 nosign_1_1, nosign1, 0x0000ffff; \n" - "max.u32 nosign_1_1, nosign_1_1, 0x3800; \n" - "min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n" - 
"or.b32 nosign1, nosign_1_0, nosign_1_1; \n" - "add.u32 nosign0, nosign0, rn_; \n" - "add.u32 nosign1, nosign1, rn_; \n" - "sub.u32 nosign0, nosign0, 0x38003800; \n" - "sub.u32 nosign1, nosign1, 0x38003800; \n" - "shr.u32 nosign0, nosign0, 4; \n" - "shr.u32 nosign1, nosign1, 4; \n" - "prmt.b32 nosign, nosign0, nosign1, 0x6420; \n" - "or.b32 $0, nosign, sign; \n" - "}"; + auto *ptxAsm = "{ \n" + ".reg .u32 sign, sign<2>, nosign, nosign<2>; \n" + ".reg .u32 fp8_min, fp8_max, rn_, zero; \n" + "mov.u32 fp8_min, 0x38003800; \n" + "mov.u32 fp8_max, 0x3ff03ff0; \n" + "mov.u32 rn_, 0x80008; \n" + "mov.u32 zero, 0; \n" + "and.b32 sign0, $1, 0x80008000; \n" + "and.b32 sign1, $2, 0x80008000; \n" + "prmt.b32 sign, sign0, sign1, 0x7531; \n" + "and.b32 nosign0, $1, 0x7fff7fff; \n" + "and.b32 nosign1, $2, 0x7fff7fff; \n" + ".reg .u32 nosign_0_<2>, nosign_1_<2>; \n" + "and.b32 nosign_0_0, nosign0, 0xffff0000; \n" + "max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n" + "min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n" + "and.b32 nosign_0_1, nosign0, 0x0000ffff; \n" + "max.u32 nosign_0_1, nosign_0_1, 0x3800; \n" + "min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n" + "or.b32 nosign0, nosign_0_0, nosign_0_1; \n" + "and.b32 nosign_1_0, nosign1, 0xffff0000; \n" + "max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n" + "min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n" + "and.b32 nosign_1_1, nosign1, 0x0000ffff; \n" + "max.u32 nosign_1_1, nosign_1_1, 0x3800; \n" + "min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n" + "or.b32 nosign1, nosign_1_0, nosign_1_1; \n" + "add.u32 nosign0, nosign0, rn_; \n" + "add.u32 nosign1, nosign1, rn_; \n" + "sub.u32 nosign0, nosign0, 0x38003800; \n" + "sub.u32 nosign1, nosign1, 0x38003800; \n" + "shr.u32 nosign0, nosign0, 4; \n" + "shr.u32 nosign1, nosign1, 4; \n" + "prmt.b32 nosign, nosign0, nosign1, 0x6420; \n" + "or.b32 $0, nosign, sign; \n" + "}"; auto &call = *builder.create(ptxAsm); auto *o = builder.newOperand("=r"); @@ -2312,51 +2399,49 @@ struct FpToFpOpConversion auto fp8x4VecTy = vec_ty(i8_ty, 4); auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false); - return { - extract_element(i8_ty, fp8x4Vec, i32_val(0)), - extract_element(i8_ty, fp8x4Vec, i32_val(1)), - extract_element(i8_ty, fp8x4Vec, i32_val(2)), - extract_element(i8_ty, fp8x4Vec, i32_val(3)) - }; + return {extract_element(i8_ty, fp8x4Vec, i32_val(0)), + extract_element(i8_ty, fp8x4Vec, i32_val(1)), + extract_element(i8_ty, fp8x4Vec, i32_val(2)), + extract_element(i8_ty, fp8x4Vec, i32_val(3))}; } - static SmallVector convertFp8x4ToFp32x4( - Location loc, ConversionPatternRewriter &rewriter, - const Value& v0, const Value& v1, const Value& v2, const Value& v3) { + static SmallVector + convertFp8x4ToFp32x4(Location loc, ConversionPatternRewriter &rewriter, + const Value &v0, const Value &v1, const Value &v2, + const Value &v3) { auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3); - return { - rewriter.create(loc, f32_ty, fp16Values[0]), - rewriter.create(loc, f32_ty, fp16Values[1]), - rewriter.create(loc, f32_ty, fp16Values[2]), - rewriter.create(loc, f32_ty, fp16Values[3]) - }; + return {rewriter.create(loc, f32_ty, fp16Values[0]), + rewriter.create(loc, f32_ty, fp16Values[1]), + rewriter.create(loc, f32_ty, fp16Values[2]), + rewriter.create(loc, f32_ty, fp16Values[3])}; } - static SmallVector convertFp32x4ToFp8x4( - Location loc, ConversionPatternRewriter &rewriter, - const Value& v0, const Value& v1, const Value& v2, const Value& v3) { - auto c0 = rewriter.create(loc, f16_ty, v0); - auto c1 = 
rewriter.create(loc, f16_ty, v1); - auto c2 = rewriter.create(loc, f16_ty, v2); - auto c3 = rewriter.create(loc, f16_ty, v3); - return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3); + static SmallVector + convertFp32x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter, + const Value &v0, const Value &v1, const Value &v2, + const Value &v3) { + auto c0 = rewriter.create(loc, f16_ty, v0); + auto c1 = rewriter.create(loc, f16_ty, v1); + auto c2 = rewriter.create(loc, f16_ty, v2); + auto c3 = rewriter.create(loc, f16_ty, v3); + return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3); } - static SmallVector convertFp8x4ToFp64x4( - Location loc, ConversionPatternRewriter &rewriter, - const Value& v0, const Value& v1, const Value& v2, const Value& v3) { + static SmallVector + convertFp8x4ToFp64x4(Location loc, ConversionPatternRewriter &rewriter, + const Value &v0, const Value &v1, const Value &v2, + const Value &v3) { auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3); - return { - rewriter.create(loc, f64_ty, fp16Values[0]), - rewriter.create(loc, f64_ty, fp16Values[1]), - rewriter.create(loc, f64_ty, fp16Values[2]), - rewriter.create(loc, f64_ty, fp16Values[3]) - }; + return {rewriter.create(loc, f64_ty, fp16Values[0]), + rewriter.create(loc, f64_ty, fp16Values[1]), + rewriter.create(loc, f64_ty, fp16Values[2]), + rewriter.create(loc, f64_ty, fp16Values[3])}; } - static SmallVector convertFp64x4ToFp8x4( - Location loc, ConversionPatternRewriter &rewriter, - const Value& v0, const Value& v1, const Value& v2, const Value& v3) { + static SmallVector + convertFp64x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter, + const Value &v0, const Value &v1, const Value &v2, + const Value &v3) { auto c0 = rewriter.create(loc, f16_ty, v0); auto c1 = rewriter.create(loc, f16_ty, v1); auto c2 = rewriter.create(loc, f16_ty, v2); @@ -2379,9 +2464,10 @@ struct FpToFpOpConversion this->getTypeConverter()->convertType(dstEltType); // Select convertor - std::function(Location, ConversionPatternRewriter&, - const Value&, const Value&, - const Value&, const Value&)> convertor; + std::function(Location, ConversionPatternRewriter &, + const Value &, const Value &, + const Value &, const Value &)> + convertor; if (srcEltType.isa() && dstEltType.isF16()) { convertor = convertFp8x4ToFp16x4; } else if (srcEltType.isF16() && dstEltType.isa()) { @@ -2410,8 +2496,7 @@ struct FpToFpOpConversion auto elements = getElementsFromStruct(loc, adaptor.from(), rewriter); SmallVector resultVals; for (size_t i = 0; i < elems; i += 4) { - auto converted = convertor(loc, rewriter, - elements[i], elements[i + 1], + auto converted = convertor(loc, rewriter, elements[i], elements[i + 1], elements[i + 2], elements[i + 3]); resultVals.append(converted); } @@ -2626,6 +2711,82 @@ public: } private: + SmallVector getMultiDimOffset(Attribute layout, Location loc, + ConversionPatternRewriter &rewriter, + unsigned elemId, ArrayRef shape, + ArrayRef multiDimCTAInRepId, + ArrayRef shapePerCTA) const { + unsigned rank = shape.size(); + if (auto blockedLayout = layout.dyn_cast()) { + auto multiDimOffsetFirstElem = + emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape); + SmallVector multiDimOffset(rank); + SmallVector multiDimElemId = + getMultiDimIndex(elemId, blockedLayout.getSizePerThread()); + for (unsigned d = 0; d < rank; ++d) { + multiDimOffset[d] = add(multiDimOffsetFirstElem[d], + idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] + + multiDimElemId[d])); + } + return multiDimOffset; + } + 
if (auto sliceLayout = layout.dyn_cast()) { + unsigned dim = sliceLayout.getDim(); + auto multiDimOffsetParent = + getMultiDimOffset(sliceLayout.getParent(), loc, rewriter, elemId, + sliceLayout.paddedShape(shape), + sliceLayout.paddedShape(multiDimCTAInRepId), + sliceLayout.paddedShape(shapePerCTA)); + SmallVector multiDimOffset(rank); + for (unsigned d = 0; d < rank + 1; ++d) { + if (d == dim) + continue; + unsigned slicedD = d < dim ? d : (d - 1); + multiDimOffset[slicedD] = multiDimOffsetParent[d]; + } + return multiDimOffset; + } + if (auto mmaLayout = layout.dyn_cast()) { + SmallVector mmaColIdx(2); + SmallVector mmaRowIdx(2); + Value threadId = getThreadId(rewriter, loc); + Value warpSize = idx_val(32); + Value laneId = urem(threadId, warpSize); + Value warpId = udiv(threadId, warpSize); + // auto multiDimWarpId = + // delinearize(rewriter, loc, warpId, mmaLayout.getWarpsPerCTA()); + // TODO: double confirm if its document bug or DotConversion's Bug + SmallVector multiDimWarpId(2); + multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0])); + multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0])); + Value four = idx_val(4); + Value mmaGrpId = udiv(laneId, four); + Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8)); + Value mmaThreadIdInGrp = urem(laneId, four); + Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, idx_val(2)); + Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, idx_val(1)); + Value colWarpOffset = mul(multiDimWarpId[0], idx_val(16)); + mmaColIdx[0] = add(mmaGrpId, colWarpOffset); + mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset); + Value rowWarpOffset = mul(multiDimWarpId[1], idx_val(8)); + mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset); + mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset); + + assert(rank == 2); + assert(mmaLayout.getVersion() == 2 && + "mmaLayout ver1 not implemented yet"); + SmallVector multiDimOffset(rank); + multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1]; + multiDimOffset[1] = elemId % 2 == 0 ? 
mmaRowIdx[0] : mmaRowIdx[1]; + multiDimOffset[0] = add(multiDimOffset[0], + idx_val(multiDimCTAInRepId[0] * shapePerCTA[0])); + multiDimOffset[1] = add(multiDimOffset[1], + idx_val(multiDimCTAInRepId[1] * shapePerCTA[1])); + return multiDimOffset; + } + llvm_unreachable("unexpected layout in getMultiDimOffset"); + } + // shared memory rd/st for blocked or mma layout with data padding void processReplica(Location loc, ConversionPatternRewriter &rewriter, bool stNotRd, RankedTensorType type, @@ -2693,47 +2854,6 @@ void ConvertLayoutOpConversion::processReplica( elemTy = IntegerType::get(elemTy.getContext(), 8); auto llvmElemTy = getTypeConverter()->convertType(elemTy); - SmallVector multiDimOffsetFirstElem; - SmallVector mmaColIdx(2); - SmallVector mmaRowIdx(2); - if (blockedLayout) { - multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout( - loc, rewriter, blockedLayout, type.getShape()); - } else if (sliceLayout) { - auto parent = sliceLayout.getParent(); - if (auto blockedParent = parent.dyn_cast()) { - SmallVector paddedShape = - sliceLayout.paddedShape(type.getShape()); - multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout( - loc, rewriter, blockedParent, paddedShape); - } else { - assert(0 && "SliceEncodingAttr with parent other than " - "BlockedEncodingAttr not implemented"); - } - } else if (mmaLayout) { - Value threadId = getThreadId(rewriter, loc); - Value warpSize = idx_val(32); - Value laneId = urem(threadId, warpSize); - Value warpId = udiv(threadId, warpSize); - // auto multiDimWarpId = - // delinearize(rewriter, loc, warpId, mmaLayout.getWarpsPerCTA()); - // TODO: double confirm if its document bug or DotConversion's Bug - SmallVector multiDimWarpId(2); - multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0])); - multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0])); - Value four = idx_val(4); - Value mmaGrpId = udiv(laneId, four); - Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8)); - Value mmaThreadIdInGrp = urem(laneId, four); - Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, idx_val(2)); - Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, idx_val(1)); - Value colWarpOffset = mul(multiDimWarpId[0], idx_val(16)); - mmaColIdx[0] = add(mmaGrpId, colWarpOffset); - mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset); - Value rowWarpOffset = mul(multiDimWarpId[1], idx_val(8)); - mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset); - mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset); - } for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) { auto multiDimCTAInRepId = getMultiDimIndex(ctaId, numCTAsEachRep); SmallVector multiDimCTAId(rank); @@ -2747,48 +2867,9 @@ void ConvertLayoutOpConversion::processReplica( // consider of caching the index calculation result in case // of performance issue observed. 
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) { - SmallVector multiDimOffset(rank); - if (blockedLayout) { - SmallVector multiDimElemId = getMultiDimIndex( - elemId, blockedLayout.getSizePerThread()); - for (unsigned d = 0; d < rank; ++d) { - multiDimOffset[d] = - add(multiDimOffsetFirstElem[d], - idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] + - multiDimElemId[d])); - } - } else if (sliceLayout) { - unsigned dim = sliceLayout.getDim(); - auto parent = sliceLayout.getParent(); - if (auto blockedParent = parent.dyn_cast()) { - SmallVector multiDimElemId = getMultiDimIndex( - elemId, blockedParent.getSizePerThread()); - for (unsigned d = 0; d < rank + 1; ++d) { - if (d == dim) - continue; - unsigned slicedD = d < dim ? d : (d - 1); - multiDimOffset[slicedD] = - add(multiDimOffsetFirstElem[d], - idx_val(multiDimCTAInRepId[slicedD] * shapePerCTA[slicedD] + - multiDimElemId[d])); - } - } else { - assert(0 && "SliceEncodingAttr with parent other than " - "BlockedEncodingAttr not implemented"); - } - } else if (mmaLayout) { - assert(rank == 2); - assert(mmaLayout.getVersion() == 2 && - "mmaLayout ver1 not implemented yet"); - multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1]; - multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1]; - multiDimOffset[0] = add( - multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0])); - multiDimOffset[1] = add( - multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1])); - } else { - assert(0 && "unexpected layout in processReplica"); - } + SmallVector multiDimOffset = + getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(), + multiDimCTAInRepId, shapePerCTA); Value offset = linearize(rewriter, loc, reorder(multiDimOffset, outOrd), reorder(paddedRepShape, outOrd)); @@ -3637,10 +3718,10 @@ struct DotOpMmaV1ConversionHelper { // Compute the offset of the matrix to load. // Returns offsetAM, offsetAK, offsetBN, offsetBK. // NOTE, the information M(from $a) and N(from $b) couldn't be retrieved at - // the same time in the usage in convert_layout[shared->dot_op], we leave the - // noexist info to be 0 and only use the desired argument from the composed - // result. In this way we want to retain the original code structure in - // convert_mma884 method for easier debugging. + // the same time in the usage in convert_layout[shared->dot_op], we leave + // the noexist info to be 0 and only use the desired argument from the + // composed result. In this way we want to retain the original code + // structure in convert_mma884 method for easier debugging. 
std::tuple computeOffsets(Value threadId, bool isARow, bool isBRow, ArrayRef fpw, ArrayRef spw, ArrayRef rep, @@ -5231,8 +5312,8 @@ struct InsertSliceAsyncOpConversion srcStrides.emplace_back(smemObj.strides[i]); } } - // Compute the offset based on the original dimensions of the shared memory - // object + // Compute the offset based on the original dimensions of the shared + // memory object auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides); auto dstPtrTy = ptr_ty(getTypeConverter()->convertType(resTy.getElementType()), 3); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index e91ed6d69..6c7cd86dc 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -42,15 +42,11 @@ static Type getPointeeType(Type type) { namespace gpu { -// TODO: Inheritation of layout attributes -unsigned getElemsPerThread(Type type) { - if (type.isIntOrIndexOrFloat() || - type.isa() || - type.isa()) - return 1; - auto tensorType = type.cast(); - auto layout = tensorType.getEncoding(); - auto shape = tensorType.getShape(); +// TODO: Inheritance of layout attributes +// so that all distributed layouts implement +// these utilities + +unsigned getElemsPerThread(Attribute layout, ArrayRef shape) { if (auto blockedLayout = layout.dyn_cast()) { return blockedLayout.getElemsPerThread(shape); } else if (auto sliceLayout = layout.dyn_cast()) { @@ -67,6 +63,43 @@ unsigned getElemsPerThread(Type type) { } } +unsigned getElemsPerThread(Type type) { + if (type.isIntOrIndexOrFloat() || + type.isa() || + type.isa()) + return 1; + auto tensorType = type.cast(); + return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape()); +} + +SmallVector getThreadsPerWarp(Attribute layout) { + if (auto blockedLayout = layout.dyn_cast()) { + return SmallVector(blockedLayout.getThreadsPerWarp().begin(), + blockedLayout.getThreadsPerWarp().end()); + } + if (auto mmaLayout = layout.dyn_cast()) { + if (mmaLayout.getVersion() == 1) + return SmallVector{4, 8}; + if (mmaLayout.getVersion() == 2) + return SmallVector{8, 4}; + } + assert(0 && "getThreadsPerWarp not implemented"); + return {}; +} + +SmallVector getWarpsPerCTA(Attribute layout) { + if (auto blockedLayout = layout.dyn_cast()) { + return SmallVector(blockedLayout.getWarpsPerCTA().begin(), + blockedLayout.getWarpsPerCTA().end()); + } + if (auto mmaLayout = layout.dyn_cast()) { + return SmallVector(mmaLayout.getWarpsPerCTA().begin(), + mmaLayout.getWarpsPerCTA().end()); + } + assert(0 && "getWarpsPerCTA not implemented"); + return {}; +} + SmallVector getSizePerThread(Attribute layout) { if (auto blockedLayout = layout.dyn_cast()) { return SmallVector(blockedLayout.getSizePerThread().begin(), @@ -129,17 +162,11 @@ SmallVector getShapePerCTA(const Attribute &layout) { } else if (auto sliceLayout = layout.dyn_cast()) { unsigned dim = sliceLayout.getDim(); auto parent = sliceLayout.getParent(); - if (auto blockedParent = parent.dyn_cast()) { - for (unsigned d = 0, n = blockedParent.getOrder().size(); d < n; ++d) { - if (d == dim) - continue; - shape.push_back(blockedParent.getSizePerThread()[d] * - blockedParent.getThreadsPerWarp()[d] * - blockedParent.getWarpsPerCTA()[d]); - } - } else { - assert(0 && "SliceEncodingAttr with parent other than " - "BlockedEncodingAttr not implemented"); + for (unsigned d = 0, n = getOrder(parent).size(); d < n; ++d) { + if (d == dim) + continue; + shape.push_back(getSizePerThread(parent)[d] * + getThreadsPerWarp(parent)[d] * getWarpsPerCTA(parent)[d]); } 
} else if (auto mmaLayout = layout.dyn_cast()) { if (mmaLayout.getVersion() == 2) @@ -289,11 +316,11 @@ unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef shape) const { return product(elemsPerThread); } -SmallVector -SliceEncodingAttr::paddedShape(ArrayRef shape) const { +template +SmallVector SliceEncodingAttr::paddedShape(ArrayRef shape) const { size_t rank = shape.size(); unsigned dim = getDim(); - SmallVector retShape(rank + 1); + SmallVector retShape(rank + 1); for (unsigned d = 0; d < rank + 1; ++d) { if (d < dim) retShape[d] = shape[d]; @@ -304,18 +331,15 @@ SliceEncodingAttr::paddedShape(ArrayRef shape) const { } return retShape; } +template SmallVector +SliceEncodingAttr::paddedShape(ArrayRef shape) const; +template SmallVector +SliceEncodingAttr::paddedShape(ArrayRef shape) const; unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef shape) const { size_t rank = shape.size(); auto parent = getParent(); - if (auto blockedParent = parent.dyn_cast()) { - assert(rank == blockedParent.getSizePerThread().size() - 1 && - "unexpected rank in SliceEncodingAttr::getElemsPerThread"); - return blockedParent.getElemsPerThread(paddedShape(shape)); - } else { - assert(0 && "getElemsPerThread not implemented"); - return 0; - } + return ::getElemsPerThread(parent, paddedShape(shape)); } unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef shape) const { diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py index 5e9a7445a..dd8b4ecba 100644 --- a/python/tests/test_reduce.py +++ b/python/tests/test_reduce.py @@ -117,8 +117,7 @@ def test_reduce2d(op, dtype, shape, axis): z = torch.empty(reduced_shape, device=x.device, dtype=reduced_dtype) kernel = patch_kernel(reduce2d_kernel, {'OP': op}) - grid = (1,) - kernel[grid](x_ptr=x, z_ptr=z, axis=axis, block_m=shape[0], block_n=shape[1]) + kernel[(1,)](x_ptr=x, z_ptr=z, axis=axis, block_m=shape[0], block_n=shape[1]) if op == 'sum': golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=reduced_dtype)
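A few illustrative sketches follow. They restate logic from the hunks above in standalone C++ so the intent is easier to check; the struct and function names defined in them are illustrative only, not part of the Triton API. First, the reduce scratch-buffer sizing in lib/Analysis/Allocation.cpp now goes through the layout-agnostic helpers (getOrder, getThreadsPerWarp, getWarpsPerCTA) instead of requiring a blocked encoding. Assuming a layout that can report those three quantities, the shape computation amounts to:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-in for a distributed layout attribute: per-dimension
// thread/warp counts plus the dimension order (order[0] = fastest-varying).
struct LayoutInfo {
  std::vector<unsigned> threadsPerWarp;
  std::vector<unsigned> warpsPerCTA;
  std::vector<unsigned> order;
};

// Mirrors getScratchConfigForReduce: the scratch shape equals the tensor
// shape, except along the reduced axis, which shrinks to the number of
// participating warps (fast case) or threads (general case).
std::vector<int64_t> scratchShapeForReduce(const LayoutInfo &layout,
                                           const std::vector<int64_t> &shape,
                                           unsigned axis) {
  std::vector<int64_t> smemShape(shape);
  bool fastReduce = (axis == layout.order[0]);
  if (fastReduce)
    smemShape[axis] = layout.warpsPerCTA[axis];
  else
    smemShape[axis] = layout.threadsPerWarp[axis] * layout.warpsPerCTA[axis];
  return smemShape;
}

// Byte count for the buffer, as in the ReduceOp case above.
int64_t scratchBytes(const std::vector<int64_t> &smemShape,
                     unsigned elemBitWidth) {
  int64_t elems = std::accumulate(smemShape.begin(), smemShape.end(),
                                  int64_t{1}, std::multiplies<int64_t>{});
  return elems * elemBitWidth / 8;
}

For an mma-layout input, getThreadsPerWarp now supplies {4, 8} (v1) or {8, 4} (v2), which is what lets the ReduceOp case in Allocation.cpp drop its blocked-layout assert.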
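Second, the lane-to-coordinate arithmetic added in emitBaseIndexForMmaLayoutV2 is easier to audit without the rewriter calls around it. A sketch of the same mapping, again with an illustrative name and signature; the 16x8 warp tile and the 8x4 lane arrangement follow the constants in the added code:

#include <array>

// Hypothetical standalone model of emitBaseIndexForMmaLayoutV2: maps a linear
// thread id to the coordinate of the first element that thread owns in an
// mma v2 tile. warpsPerCTA is the per-dimension warp grid of the layout.
std::array<unsigned, 2> mmaV2BaseIndex(unsigned threadId,
                                       std::array<unsigned, 2> warpsPerCTA) {
  const unsigned warpSize = 32;
  unsigned laneId = threadId % warpSize;
  unsigned warpId = threadId / warpSize;
  // Warps are linearized with dim 0 fastest-varying, matching the udiv/urem
  // sequence in the diff.
  unsigned warpId0 = warpId % warpsPerCTA[0];
  unsigned warpId1 = (warpId / warpsPerCTA[0]) % warpsPerCTA[1];
  // Each warp covers a 16x8 tile; lanes form an 8x4 grid inside it, with two
  // adjacent elements per lane along dim 1.
  unsigned base0 = laneId / 4 + warpId0 * 16;
  unsigned base1 = 2 * (laneId % 4) + warpId1 * 8;
  return {base0, base1};
}

emitOffsetForMmaLayoutV2 then tiles this base over the tensor, pushing {i, j}, {i, j + 1} and, in a second pass over j, {i + 8, j}, {i + 8, j + 1} for each shapePerCTA-sized step (i, j).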
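Finally, the slice-layout handling (paddedShape in TritonGPUAttrDefs.td/Dialect.cpp and emitIndicesForSliceLayout) follows a pad-then-squeeze pattern: indices are emitted for the parent layout on a rank+1 shape with a 1 inserted at the sliced dimension, and that dimension is then dropped from every index. The same pattern on plain vectors, with illustrative helper names:

#include <cstdint>
#include <vector>

// Insert a size-1 dimension at `dim`, as SliceEncodingAttr::paddedShape does.
std::vector<int64_t> paddedShape(const std::vector<int64_t> &shape,
                                 unsigned dim) {
  std::vector<int64_t> ret(shape.size() + 1);
  for (unsigned d = 0; d < ret.size(); ++d) {
    if (d < dim)
      ret[d] = shape[d];
    else if (d == dim)
      ret[d] = 1;
    else
      ret[d] = shape[d - 1];
  }
  return ret;
}

// Drop the sliced dimension from a padded index, as emitIndicesForSliceLayout
// does for each per-element index returned by the parent layout.
std::vector<int64_t> dropDim(const std::vector<int64_t> &paddedIndex,
                             unsigned dim) {
  std::vector<int64_t> ret;
  for (unsigned d = 0; d < paddedIndex.size(); ++d)
    if (d != dim)
      ret.push_back(paddedIndex[d]);
  return ret;
}

The paddedShape template (now explicitly instantiated for unsigned and int64_t in Dialect.cpp) exists so the same padding applies both to shapes and to the unsigned multiDimCTAInRepId / shapePerCTA vectors that getMultiDimOffset passes to the parent layout.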