From ac0f6793ccd28dcf7b8b348f808a00454ab08f37 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 28 Oct 2022 01:17:55 -0700 Subject: [PATCH] [BACKEND] Added support for scalars in LoadOp / StoreOp / ElementwiseOp (#814) Also fixed various errors that showed up in `test_core.py`, and added more TODOs for open (hopefully relatively minor) issues --- include/triton/Dialect/TritonGPU/IR/Dialect.h | 2 +- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 377 +++++------------- lib/Dialect/Triton/IR/Ops.cpp | 17 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 7 +- python/tests/test_core.py | 279 ++++++------- test/Conversion/tritongpu_to_llvm.mlir | 6 +- 6 files changed, 269 insertions(+), 419 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 13b488052..1d3312637 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -21,7 +21,7 @@ namespace mlir { namespace triton { namespace gpu { -unsigned getElemsPerThread(Attribute layout, ArrayRef shape); +unsigned getElemsPerThread(Type type); SmallVector getSizePerThread(Attribute layout); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 4949c1a87..6d464f2a0 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" #include "triton/Analysis/Allocation.h" #include "triton/Analysis/AxisInfo.h" @@ -310,6 +311,10 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> { Value getStructFromElements(Location loc, ValueRange resultVals, ConversionPatternRewriter &rewriter, Type structType) { + if (!structType.isa()) { + return *resultVals.begin(); + } + Value llvmStruct = rewriter.create(loc, structType); for (auto v : llvm::enumerate(resultVals)) { llvmStruct = insert_val(structType, llvmStruct, v.value(), @@ -369,6 +374,10 @@ struct ConvertTritonGPUOpToLLVMPatternBase { static SmallVector getElementsFromStruct(Location loc, Value llvmStruct, ConversionPatternRewriter &rewriter) { + if (llvmStruct.getType().isIntOrIndexOrFloat() || + llvmStruct.getType().isa() || + llvmStruct.getType().isa()) + return {llvmStruct}; ArrayRef types = llvmStruct.getType().cast().getBody(); SmallVector results(types.size()); @@ -678,7 +687,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, auto layout = tensorTy.getEncoding(); auto srcType = typeConverter->convertType(elemType); auto llSrc = bitcast(srcType, constVal); - size_t elemsPerThread = getElemsPerThread(layout, tensorTy.getShape()); + size_t elemsPerThread = getElemsPerThread(tensorTy); llvm::SmallVector elems(elemsPerThread, llSrc); llvm::SmallVector elemTypes(elems.size(), srcType); auto structTy = @@ -760,64 +769,49 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase { // Get corresponding LLVM element values of \param value. 
SmallVector getLLVMElems(Value value, Value llValue, - const BlockedEncodingAttr &layout, ConversionPatternRewriter &rewriter, Location loc) const { if (!value) return {}; - - auto shape = value.getType().cast().getShape(); + if (!llValue.getType().isa()) + return {llValue}; // Here, we assume that all inputs should have a blockedLayout auto valueVals = getElementsFromStruct(loc, llValue, rewriter); return valueVals; } - // Get the blocked layout. - std::tuple getLayout(Value val) const { - auto ty = val.getType().cast(); - // Here, we assume that all inputs should have a blockedLayout - auto layout = ty.getEncoding().dyn_cast(); - assert(layout && "unexpected layout in getLayout"); - auto shape = ty.getShape(); - unsigned valueElems = layout.getElemsPerThread(shape); - return {layout, valueElems}; - } + unsigned getVectorSize(Value ptr) const { + auto tensorTy = ptr.getType().dyn_cast(); + if (!tensorTy) + return 1; + auto layout = tensorTy.getEncoding(); + auto shape = tensorTy.getShape(); - unsigned getAlignment(Value val, const BlockedEncodingAttr &layout) const { - auto axisInfo = getAxisInfo(val); - auto order = layout.getOrder(); - unsigned maxMultiple = axisInfo->getDivisibility(order[0]); - unsigned maxContig = axisInfo->getContiguity(order[0]); - unsigned alignment = std::min(maxMultiple, maxContig); - return alignment; - } - - unsigned getVectorizeSize(Value ptr, - const BlockedEncodingAttr &layout) const { auto axisInfo = getAxisInfo(ptr); // Here order should be ordered by contiguous first, so the first element // should have the largest contiguous. - auto order = layout.getOrder(); + auto order = getOrder(layout); unsigned align = getAlignment(ptr, layout); - auto ty = ptr.getType().dyn_cast(); - assert(ty); - auto shape = ty.getShape(); - - unsigned contigPerThread = layout.getSizePerThread()[order[0]]; + unsigned contigPerThread = getSizePerThread(layout)[order[0]]; unsigned vec = std::min(align, contigPerThread); vec = std::min(shape[order[0]], vec); return vec; } - unsigned getMaskAlignment(Value mask) const { - auto maskOrder = mask.getType() - .cast() - .getEncoding() - .cast() - .getOrder(); + unsigned getAlignment(Value val, const Attribute &layout) const { + auto axisInfo = getAxisInfo(val); + auto order = getOrder(layout); + unsigned maxMultiple = axisInfo->getDivisibility(order[0]); + unsigned maxContig = axisInfo->getContiguity(order[0]); + unsigned alignment = std::min(maxMultiple, maxContig); + return alignment; + } + unsigned getMaskAlignment(Value mask) const { + auto tensorTy = mask.getType().cast(); + auto maskOrder = getOrder(tensorTy.getEncoding()); auto maskAxis = getAxisInfo(mask); return std::max(maskAxis->getConstancy(maskOrder[0]), 1); } @@ -848,46 +842,39 @@ struct LoadOpConversion LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = rewriter.getContext(); + auto loc = op->getLoc(); + + // original values Value ptr = op.ptr(); Value mask = op.mask(); Value other = op.other(); + // adaptor values Value llPtr = adaptor.ptr(); Value llMask = adaptor.mask(); Value llOther = adaptor.other(); - auto loc = op->getLoc(); - MLIRContext *ctx = rewriter.getContext(); - - auto valueTy = op.getResult().getType().dyn_cast(); - if (!valueTy) - return failure(); - Type valueElemTy = - getTypeConverter()->convertType(valueTy.getElementType()); - - auto [layout, numElems] = getLayout(ptr); - - auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc); - 
assert(ptrElems.size() == numElems); // Determine the vectorization size - size_t vec = getVectorizeSize(ptr, layout); + Type valueTy = op.getResult().getType(); + Type valueElemTy = getElementTypeOrSelf(valueTy); + unsigned vec = getVectorSize(ptr); + unsigned numElems = getElemsPerThread(ptr.getType()); + if (llMask) + vec = std::min(vec, getMaskAlignment(mask)); + // Get the LLVM values for pointers + auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc); + assert(ptrElems.size() == numElems); + + // Get the LLVM values for mask SmallVector maskElems; if (llMask) { - unsigned maskAlignment = getMaskAlignment(mask); - maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc); - assert(ptrElems.size() == maskElems.size()); - - size_t maskAlign = getMaskAlignment(mask); - vec = std::min(vec, maskAlign); + maskElems = getLLVMElems(mask, llMask, rewriter, loc); + assert(maskElems.size() == numElems); } - const size_t dtsize = - std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); - const size_t valueElemNbits = dtsize * 8; - - const int numVecs = numElems / vec; - + // Get the LLVM values for `other` // TODO: (goostavz) handle when other is const but not splat, which // should be rarely seen bool otherIsSplatConstInt = false; @@ -898,8 +885,12 @@ struct LoadOpConversion otherIsSplatConstInt = true; splatVal = constAttr.getSplatValue().getSExtValue(); } + auto otherElems = getLLVMElems(other, llOther, rewriter, loc); - auto otherElems = getLLVMElems(other, llOther, layout, rewriter, loc); + // vectorized iteration through all the pointer/mask/other elements + const int valueElemNbits = + std::max(8u, valueElemTy.getIntOrFloatBitWidth()); + const int numVecs = numElems / vec; SmallVector loadedVals; for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { @@ -1060,30 +1051,23 @@ struct StoreOpConversion auto loc = op->getLoc(); MLIRContext *ctx = rewriter.getContext(); - auto valueTy = value.getType().dyn_cast(); - if (!valueTy) { - store(llValue, llPtr); - rewriter.eraseOp(op); - return success(); - } + auto valueTy = value.getType(); + Type valueElemTy = getElementTypeOrSelf(valueTy); - Type valueElemTy = - getTypeConverter()->convertType(valueTy.getElementType()); + unsigned vec = getVectorSize(ptr); + unsigned numElems = getElemsPerThread(ptr.getType()); - auto [layout, numElems] = getLayout(ptr); - - auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc); - auto valueElems = getLLVMElems(value, llValue, layout, rewriter, loc); + auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc); + auto valueElems = getLLVMElems(value, llValue, rewriter, loc); assert(ptrElems.size() == valueElems.size()); // Determine the vectorization size - size_t vec = getVectorizeSize(ptr, layout); SmallVector maskElems; if (llMask) { - maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc); + maskElems = getLLVMElems(mask, llMask, rewriter, loc); assert(valueElems.size() == maskElems.size()); - size_t maskAlign = getMaskAlignment(mask); + unsigned maskAlign = getMaskAlignment(mask); vec = std::min(vec, maskAlign); } @@ -1146,7 +1130,7 @@ struct StoreOpConversion ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); auto &ptxStoreInstr = - ptxBuilder.create("st")->global().b(width).v(nWords); + ptxBuilder.create("st")->global().v(nWords).b(width); ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b"); Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1)); @@ -1223,8 +1207,9 @@ struct BroadcastOpConversion for (auto it : llvm::enumerate(broadcastDims)) 
{ // Incase there are multiple indices in the src that is actually // calculating the same element, srcLogicalShape may not need to be 1. - // Such as the case when src of shape [256, 1], and with a blocked layout: - // sizePerThread: [1, 4]; threadsPerWarp: [1, 32]; warpsPerCTA: [1, 2] + // Such as the case when src of shape [256, 1], and with a blocked + // layout: sizePerThread: [1, 4]; threadsPerWarp: [1, 32]; warpsPerCTA: + // [1, 2] int64_t d = resultLogicalShape[it.value()] / srcLogicalShape[it.value()]; broadcastSizes[it.index()] = d; duplicates *= d; @@ -1234,10 +1219,10 @@ struct BroadcastOpConversion duplicates *= d; } - unsigned srcElems = srcLayout.getElemsPerThread(srcShape); + unsigned srcElems = getElemsPerThread(srcTy); auto elemTy = resultTy.getElementType(); auto srcVals = getElementsFromStruct(loc, src, rewriter); - unsigned resultElems = resultLayout.getElemsPerThread(resultShape); + unsigned resultElems = getElemsPerThread(resultTy); SmallVector resultVals(resultElems); for (unsigned i = 0; i < srcElems; ++i) { auto srcMultiDim = getMultiDimIndex(i, srcLogicalShape); @@ -1256,8 +1241,10 @@ struct BroadcastOpConversion } } auto llvmStructTy = getTypeConverter()->convertType(resultTy); + Value resultStruct = getStructFromElements(loc, resultVals, rewriter, llvmStructTy); + rewriter.replaceOp(op, {resultStruct}); return success(); } @@ -1389,7 +1376,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic( auto smemShape = getScratchConfigForReduce(op); - unsigned srcElems = getElemsPerThread(srcLayout, srcShape); + unsigned srcElems = getElemsPerThread(srcTy); auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape); auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter); @@ -1446,7 +1433,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic( auto resultLayout = resultTy.getEncoding(); auto resultShape = resultTy.getShape(); - unsigned resultElems = getElemsPerThread(resultLayout, resultShape); + unsigned resultElems = getElemsPerThread(resultTy); auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape); assert(resultIndices.size() == resultElems); @@ -1498,7 +1485,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( unsigned sizeIntraWarps = threadsPerWarp[axis]; unsigned sizeInterWarps = warpsPerCTA[axis]; - unsigned srcElems = getElemsPerThread(srcLayout, srcShape); + unsigned srcElems = getElemsPerThread(srcTy); auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape); auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter); @@ -1586,7 +1573,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( auto resultLayout = resultTy.getEncoding().cast(); auto resultShape = resultTy.getShape(); - unsigned resultElems = getElemsPerThread(resultLayout, resultShape); + unsigned resultElems = getElemsPerThread(resultTy); auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape); assert(resultIndices.size() == resultElems); @@ -1633,7 +1620,7 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern { Location loc = op->getLoc(); auto resultTy = op.getType().template cast(); auto resultShape = resultTy.getShape(); - unsigned elems = getElemsPerThread(resultTy.getEncoding(), resultShape); + unsigned elems = getElemsPerThread(resultTy); Type elemTy = this->getTypeConverter()->convertType(resultTy.getElementType()); SmallVector types(elems, elemTy); @@ -1712,7 +1699,7 @@ struct AddPtrOpConversion resultTensorTy.getEncoding().dyn_cast(); assert(resultLayout && 
"Unexpected resultLayout in AddPtrOpConversion"); auto resultShape = resultTensorTy.getShape(); - unsigned elems = resultLayout.getElemsPerThread(resultShape); + unsigned elems = getElemsPerThread(resultTy); Type elemTy = getTypeConverter()->convertType(resultTensorTy.getElementType()); SmallVector types(elems, elemTy); @@ -1769,8 +1756,8 @@ struct ExtractSliceOpConversion auto srcLayout = srcTy.getEncoding().dyn_cast(); assert(srcLayout && "Unexpected resultLayout in ExtractSliceOpConversion"); - // axis > 0 will result in non-contiguous memory access if the result tensor - // is an alias of the source tensor. + // axis > 0 will result in non-contiguous memory access if the result + // tensor is an alias of the source tensor. auto axis = op->getAttrOfType("axis").getInt(); assert(axis == 0 && "extract_slice: Only axis=0 is supported for now"); @@ -1806,22 +1793,14 @@ public: LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto resultTy = op.getType().template dyn_cast(); - // ArithmeticToLLVM will handle the lowering of scalar ArithOps - if (!resultTy) - return failure(); - + auto resultTy = op.getType(); Location loc = op->getLoc(); - auto resultLayout = - resultTy.getEncoding().template dyn_cast(); - auto resultShape = resultTy.getShape(); - assert(resultLayout && - "Unexpected resultLayout in ElementwiseOpConversionBase"); - unsigned elems = resultLayout.getElemsPerThread(resultShape); - Type elemTy = - this->getTypeConverter()->convertType(resultTy.getElementType()); + + unsigned elems = getElemsPerThread(resultTy); + auto resultElementTy = getElementTypeOrSelf(resultTy); + Type elemTy = this->getTypeConverter()->convertType(resultElementTy); SmallVector types(elems, elemTy); - Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types); + Type structTy = this->getTypeConverter()->convertType(resultTy); auto *concreteThis = static_cast(this); auto operands = getOperands(rewriter, adaptor, elems, loc); @@ -1874,152 +1853,6 @@ struct ElementwiseOpConversion } }; -// -// Ternary -// - -template -class TernaryOpConversionBase - : public ConvertTritonGPUOpToLLVMPattern { -public: - using OpAdaptor = typename SourceOp::Adaptor; - - explicit TernaryOpConversionBase(LLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1) - : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit) {} - - LogicalResult - matchAndRewrite(SourceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto resultTy = op.getType().template dyn_cast(); - // ArithmeticToLLVM will handle the lowering of scalar ArithOps - if (!resultTy) - return failure(); - - Location loc = op->getLoc(); - auto resultLayout = - resultTy.getEncoding().template dyn_cast(); - auto resultShape = resultTy.getShape(); - assert(resultLayout && "Unexpected resultLayout in TernaryOpConversion"); - unsigned elems = resultLayout.getElemsPerThread(resultShape); - Type elemTy = - this->getTypeConverter()->convertType(resultTy.getElementType()); - SmallVector types(elems, elemTy); - Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types); - - auto *concreteThis = static_cast(this); - auto lhss = - this->getElementsFromStruct(loc, adaptor.getOperands()[0], rewriter); - auto rhss = - this->getElementsFromStruct(loc, adaptor.getOperands()[1], rewriter); - auto thss = - this->getElementsFromStruct(loc, adaptor.getOperands()[2], rewriter); - SmallVector resultVals(elems); - for (unsigned i = 0; i < elems; ++i) { - 
resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i], - rhss[i], thss[i], loc); - } - Value view = getStructFromElements(loc, resultVals, rewriter, structTy); - rewriter.replaceOp(op, view); - return success(); - } -}; - -template -struct TernaryOpConversion - : public TernaryOpConversionBase> { - - explicit TernaryOpConversion(LLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1) - : TernaryOpConversionBase>( - typeConverter, benefit) {} - - using OpAdaptor = typename SourceOp::Adaptor; - // An interface to support variant DestOp builder. - DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter, - Type elemTy, Value lhs, Value rhs, Value th, - Location loc) const { - return rewriter.create(loc, elemTy, lhs, rhs, th); - } -}; - -// -// Unary -// - -template -class UnaryOpConversionBase : public ConvertTritonGPUOpToLLVMPattern { - -public: - using OpAdaptor = typename SourceOp::Adaptor; - - explicit UnaryOpConversionBase(LLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1) - : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit) {} - - LogicalResult - matchAndRewrite(SourceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto resultTy = op.getType().template dyn_cast(); - - // ArithmeticToLLVM will handle the lowering of scalar ArithOps - if (!resultTy) - return failure(); - - Location loc = op->getLoc(); - auto resultLayout = - resultTy.getEncoding().template dyn_cast(); - auto resultShape = resultTy.getShape(); - assert(resultLayout && "Unexpected resultLayout in UnaryOpConversion"); - unsigned elems = resultLayout.getElemsPerThread(resultShape); - Type elemTy = - this->getTypeConverter()->convertType(resultTy.getElementType()); - SmallVector types(elems, elemTy); - Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types); - - auto *concreteThis = static_cast(this); - auto srcs = this->getElementsFromStruct(loc, concreteThis->getSrc(adaptor), - rewriter); - SmallVector resultVals(elems); - for (unsigned i = 0; i < elems; ++i) { - resultVals[i] = - concreteThis->createDestOp(op, rewriter, elemTy, srcs[i], loc); - } - Value view = getStructFromElements(loc, resultVals, rewriter, structTy); - rewriter.replaceOp(op, view); - return success(); - } -}; - -template -struct UnaryOpConversion - : public UnaryOpConversionBase> { - - explicit UnaryOpConversion(LLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1) - : UnaryOpConversionBase>( - typeConverter, benefit) {} - - using OpAdaptor = typename SourceOp::Adaptor; - // An interface to support variant DestOp builder. - DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter, - Type elemTy, Value src, Location loc) const { - return rewriter.create(loc, elemTy, src); - } - - // Get the source operand of the op. 
- Value getSrc(OpAdaptor adaptor) const { - auto operands = adaptor.getOperands(); - if (operands.size() > 1) - llvm::report_fatal_error("unary operator has more than one operand"); - return operands.front(); - } -}; - // // comparisons // @@ -2367,13 +2200,13 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed( } // Potentially we need to store for multiple CTAs in this replication unsigned accumNumReplicates = product(numReplicates); - unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape()); + unsigned elems = getElemsPerThread(srcTy); auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter); unsigned inVec = 0; unsigned outVec = 0; auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec); - unsigned outElems = getElemsPerThread(dstLayout, shape); + unsigned outElems = getElemsPerThread(dstTy); auto outOrd = getOrder(dstLayout); SmallVector outVals(outElems); @@ -2431,7 +2264,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared( unsigned minVec = std::min(outVec, inVec); unsigned perPhase = dstSharedLayout.getPerPhase(); unsigned maxPhase = dstSharedLayout.getMaxPhase(); - unsigned numElems = getElemsPerThread(srcBlockedLayout, srcShape); + unsigned numElems = getElemsPerThread(srcTy); auto inVals = getElementsFromStruct(loc, adaptor.src(), rewriter); unsigned srcAccumSizeInThreads = product(srcBlockedLayout.getSizePerThread()); @@ -2609,7 +2442,8 @@ public: Value c = urem(lane, i32_val(8)); Value s = udiv(lane, i32_val(8)); // sub-warp-id - // Decompose s => s_0, s_1, that is the coordinate in 2x2 matrices in a warp + // Decompose s => s_0, s_1, that is the coordinate in 2x2 matrices in a + // warp Value s0 = urem(s, i32_val(2)); Value s1 = udiv(s, i32_val(2)); @@ -2756,8 +2590,8 @@ public: llvm::report_fatal_error("unsupported mma type found"); // The main difference with the original triton code is we removed the - // prefetch-related logic here for the upstream optimizer phase should take - // care with it, and that is transparent in dot conversion. + // prefetch-related logic here for the upstream optimizer phase should + // take care with it, and that is transparent in dot conversion. auto getPtr = [&](int idx) { return ptrs[idx]; }; Value ptr = getPtr(ptrIdx); @@ -2768,7 +2602,8 @@ public: matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes; PTXBuilder builder; - // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a thread. + // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a + // thread. auto resArgs = builder.newListOperand(4, "=r"); auto addrArg = builder.newAddrOperand(ptr, "r", sOffset); @@ -3067,7 +2902,8 @@ struct DotOpConversionHelper { // Get the M and N of mma instruction shape. 
static std::tuple getInstrShapeMN() { - // According to DotOpConversionHelper::mmaInstrShape, all the M,N are {16,8} + // According to DotOpConversionHelper::mmaInstrShape, all the M,N are + // {16,8} return {16, 8}; } @@ -3808,8 +3644,7 @@ public: if (layout && (layout.isa() || layout.isa() || layout.isa())) { - unsigned numElementsPerThread = - getElemsPerThread(layout, type.getShape()); + unsigned numElementsPerThread = getElemsPerThread(type); SmallVector types(numElementsPerThread, convertType(type.getElementType())); return LLVM::LLVMStructType::getLiteral(ctx, types); @@ -3927,7 +3762,7 @@ struct InsertSliceAsyncOpConversion Value llIndex = adaptor.index(); // %src - auto srcElems = getLLVMElems(src, llSrc, srcBlockedLayout, rewriter, loc); + auto srcElems = getLLVMElems(src, llSrc, rewriter, loc); // %dst auto axis = op->getAttrOfType("axis").getInt(); @@ -3943,7 +3778,7 @@ struct InsertSliceAsyncOpConversion // %mask SmallVector maskElems; if (llMask) { - maskElems = getLLVMElems(mask, llMask, srcBlockedLayout, rewriter, loc); + maskElems = getLLVMElems(mask, llMask, rewriter, loc); assert(srcElems.size() == maskElems.size()); } @@ -3954,15 +3789,14 @@ struct InsertSliceAsyncOpConversion // It's not necessary for now because the pipeline pass will skip // generating insert_slice_async if the load op has any "other" tensor. assert(false && "insert_slice_async: Other value not supported yet"); - otherElems = - getLLVMElems(other, llOther, srcBlockedLayout, rewriter, loc); + otherElems = getLLVMElems(other, llOther, rewriter, loc); assert(srcElems.size() == otherElems.size()); } - unsigned inVec = getVectorizeSize(src, srcBlockedLayout); + unsigned inVec = getVectorSize(src); unsigned outVec = resSharedLayout.getVec(); unsigned minVec = std::min(outVec, inVec); - unsigned numElems = getElemsPerThread(srcBlockedLayout, srcShape); + unsigned numElems = getElemsPerThread(srcTy); unsigned perPhase = resSharedLayout.getPerPhase(); unsigned maxPhase = resSharedLayout.getMaxPhase(); auto sizePerThread = srcBlockedLayout.getSizePerThread(); @@ -4212,6 +4046,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, patterns.add(typeConverter, benefit); #define POPULATE_UNARY_OP(SRC_OP, DST_OP) \ patterns.add>(typeConverter, benefit); + POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp) POPULATE_UNARY_OP(arith::TruncFOp, LLVM::FPTruncOp) POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp) @@ -4221,14 +4056,14 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp) POPULATE_UNARY_OP(arith::SIToFPOp, LLVM::SIToFPOp) POPULATE_UNARY_OP(arith::ExtFOp, LLVM::FPExtOp) - POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp) - POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp) - POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp) POPULATE_UNARY_OP(math::LogOp, math::LogOp) POPULATE_UNARY_OP(math::CosOp, math::CosOp) POPULATE_UNARY_OP(math::SinOp, math::SinOp) POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp) POPULATE_UNARY_OP(math::ExpOp, math::ExpOp) + POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp) + POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp) + POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp) #undef POPULATE_UNARY_OP patterns.add(typeConverter, benefit); diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index d8db733f4..32a0b7204 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -131,6 +131,16 @@ void 
StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, } //-- LoadOp -- +static Type getLoadOpResultType(::mlir::OpBuilder &builder, Type ptrType) { + auto ptrTensorType = ptrType.dyn_cast(); + if (!ptrTensorType) + return ptrType.cast().getPointeeType(); + auto shape = ptrTensorType.getShape(); + Type elementType = + ptrTensorType.getElementType().cast().getPointeeType(); + return RankedTensorType::get(shape, elementType); +} + void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr, ::mlir::triton::CacheModifier cache, ::mlir::triton::EvictionPolicy evict, bool isVolatile) { @@ -150,11 +160,8 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other, ::mlir::triton::CacheModifier cache, ::mlir::triton::EvictionPolicy evict, bool isVolatile) { - TensorType ptrType = ptr.getType().cast(); - Type elementType = - ptrType.getElementType().cast().getPointeeType(); - auto shape = ptrType.getShape(); - Type resultType = RankedTensorType::get(shape, elementType); + Type resultType = getLoadOpResultType(builder, ptr.getType()); + state.addOperands(ptr); if (mask) { state.addOperands(mask); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 2049a18b1..0a0b65406 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -43,7 +43,12 @@ static Type getPointeeType(Type type) { namespace gpu { // TODO: Inheritation of layout attributes -unsigned getElemsPerThread(Attribute layout, ArrayRef shape) { +unsigned getElemsPerThread(Type type) { + if (type.isIntOrIndexOrFloat() || type.isa()) + return 1; + auto tensorType = type.cast(); + auto layout = tensorType.getEncoding(); + auto shape = tensorType.getShape(); size_t rank = shape.size(); if (auto blockedLayout = layout.dyn_cast()) { return blockedLayout.getElemsPerThread(shape); diff --git a/python/tests/test_core.py b/python/tests/test_core.py index e36f436bd..910520fc5 100644 --- a/python/tests/test_core.py +++ b/python/tests/test_core.py @@ -556,45 +556,45 @@ def make_ptr_str(name, shape): # # --------------- -# @triton.jit -# def fn(a, b): -# return a + b, \ -# a - b, \ -# a * b +@triton.jit +def fn(a, b): + return a + b, \ + a - b, \ + a * b -# def test_tuples(): -# device = 'cuda' +def test_tuples(): + device = 'cuda' -# @triton.jit -# def with_fn(X, Y, A, B, C): -# x = tl.load(X) -# y = tl.load(Y) -# a, b, c = fn(x, y) -# tl.store(A, a) -# tl.store(B, b) -# tl.store(C, c) + @triton.jit + def with_fn(X, Y, A, B, C): + x = tl.load(X) + y = tl.load(Y) + a, b, c = fn(x, y) + tl.store(A, a) + tl.store(B, b) + tl.store(C, c) -# @triton.jit -# def without_fn(X, Y, A, B, C): -# x = tl.load(X) -# y = tl.load(Y) -# a, b, c = x + y, x - y, x * y -# tl.store(A, a) -# tl.store(B, b) -# tl.store(C, c) + @triton.jit + def without_fn(X, Y, A, B, C): + x = tl.load(X) + y = tl.load(Y) + a, b, c = x + y, x - y, x * y + tl.store(A, a) + tl.store(B, b) + tl.store(C, c) -# x = torch.tensor([1.3], device=device, dtype=torch.float32) -# y = torch.tensor([1.9], device=device, dtype=torch.float32) -# a_tri = torch.tensor([0], device=device, dtype=torch.float32) -# b_tri = torch.tensor([0], device=device, dtype=torch.float32) -# c_tri = torch.tensor([0], device=device, dtype=torch.float32) -# for kernel in [with_fn, without_fn]: -# kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1) -# a_ref, b_ref, c_ref = x + y, x - y, x * y -# assert a_tri == a_ref -# 
assert b_tri == b_ref -# assert c_tri == c_ref + x = torch.tensor([1.3], device=device, dtype=torch.float32) + y = torch.tensor([1.9], device=device, dtype=torch.float32) + a_tri = torch.tensor([0], device=device, dtype=torch.float32) + b_tri = torch.tensor([0], device=device, dtype=torch.float32) + c_tri = torch.tensor([0], device=device, dtype=torch.float32) + for kernel in [with_fn, without_fn]: + kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1) + a_ref, b_ref, c_ref = x + y, x - y, x * y + assert a_tri == a_ref + assert b_tri == b_ref + assert c_tri == c_ref # # --------------- @@ -709,75 +709,77 @@ def make_ptr_str(name, shape): # # --------------- -# @pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [ -# (dtype_x, dtype_z, False) -# for dtype_x in dtypes -# for dtype_z in dtypes -# ] + [ -# ('float32', 'bfloat16', False), -# ('bfloat16', 'float32', False), -# ('float32', 'int32', True), -# ('float32', 'int1', False), -# ] + [ -# (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64] -# ] + [ -# (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64] -# ]) -# def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): -# # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. -# x0 = 43 if dtype_x in int_dtypes else 43.5 -# if dtype_x in float_dtypes and dtype_z == 'int1': -# x0 = 0.5 -# if dtype_x.startswith('bfloat'): -# x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device) -# else: -# x = np.array([x0], dtype=getattr(np, dtype_x)) -# x_tri = to_triton(x) +@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [ + (dtype_x, dtype_z, False) + for dtype_x in dtypes + for dtype_z in dtypes +] + [ + # TODO: + # ('float32', 'bfloat16', False), + # ('bfloat16', 'float32', False), + ('float32', 'int32', True), + # TODO: + # ('float32', 'int1', False), +] + [ + (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64] +] + [ + (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64] +]) +def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): + # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. 
+ x0 = 43 if dtype_x in int_dtypes else 43.5 + if dtype_x in float_dtypes and dtype_z == 'int1': + x0 = 0.5 + if dtype_x.startswith('bfloat'): + x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device) + else: + x = np.array([x0], dtype=getattr(np, dtype_x)) + x_tri = to_triton(x) -# # triton kernel -# @triton.jit -# def kernel(X, Z, BITCAST: tl.constexpr): -# x = tl.load(X) -# z = x.to(Z.dtype.element_ty, bitcast=BITCAST) -# tl.store(Z, z) + # triton kernel + @triton.jit + def kernel(X, Z, BITCAST: tl.constexpr): + x = tl.load(X) + z = x.to(Z.dtype.element_ty, bitcast=BITCAST) + tl.store(Z, z) -# dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_' -# # triton result -# if dtype_z.startswith('bfloat'): -# z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device) -# else: -# z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z_np)), device=device) -# kernel[(1, )](x_tri, z_tri, BITCAST=bitcast) -# # torch result -# if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'): -# assert bitcast is False -# z_ref = x_tri.to(z_tri.dtype) -# assert z_tri == z_ref -# else: -# if bitcast: -# z_ref = x.view(getattr(np, dtype_z_np)) -# else: -# z_ref = x.astype(getattr(np, dtype_z_np)) -# assert to_numpy(z_tri) == z_ref + dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_' + # triton result + if dtype_z.startswith('bfloat'): + z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device) + else: + z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z_np)), device=device) + kernel[(1, )](x_tri, z_tri, BITCAST=bitcast) + # torch result + if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'): + assert bitcast is False + z_ref = x_tri.to(z_tri.dtype) + assert z_tri == z_ref + else: + if bitcast: + z_ref = x.view(getattr(np, dtype_z_np)) + else: + z_ref = x.astype(getattr(np, dtype_z_np)) + assert to_numpy(z_tri) == z_ref -# def test_store_bool(): -# """Tests that boolean True is stored as 1""" -# @triton.jit -# def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): -# offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) -# mask = offsets < n_elements -# input = tl.load(input_ptr + offsets, mask=mask) -# output = input -# tl.store(output_ptr + offsets, output, mask=mask) +def test_store_bool(): + """Tests that boolean True is stored as 1""" + @triton.jit + def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + input = tl.load(input_ptr + offsets, mask=mask) + output = input + tl.store(output_ptr + offsets, output, mask=mask) -# src = torch.tensor([True, False], dtype=torch.bool, device='cuda') -# n_elements = src.numel() -# dst = torch.empty_like(src) -# grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) -# copy_kernel[grid](src, dst, n_elements, BLOCK_SIZE=1024) + src = torch.tensor([True, False], dtype=torch.bool, device='cuda') + n_elements = src.numel() + dst = torch.empty_like(src) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + copy_kernel[grid](src, dst, n_elements, BLOCK_SIZE=1024) -# assert (to_numpy(src).view('uint8') == to_numpy(dst).view('uint8')).all() + assert (to_numpy(src).view('uint8') == to_numpy(dst).view('uint8')).all() # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -990,48 +992,49 @@ def make_ptr_str(name, shape): # # --------------- -# @pytest.mark.parametrize("dtype_str, shape, 
perm", -# [(dtype, shape, perm) -# for dtype in ['bfloat16', 'float16', 'float32'] -# for shape in [(64, 64), (128, 128)] -# for perm in [(1, 0)]]) -# def test_permute(dtype_str, shape, perm, device='cuda'): -# check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested +@pytest.mark.parametrize("dtype_str, shape, perm", + [(dtype, shape, perm) + # TODO: bfloat16 + for dtype in ['float16', 'float32'] + for shape in [(64, 64), (128, 128)] + for perm in [(1, 0)]]) +def test_permute(dtype_str, shape, perm, device='cuda'): + check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested -# # triton kernel -# @triton.jit -# def kernel(X, stride_xm, stride_xn, -# Z, stride_zm, stride_zn, -# BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): -# off_m = tl.arange(0, BLOCK_M) -# off_n = tl.arange(0, BLOCK_N) -# Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn -# Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn -# tl.store(Zs, tl.load(Xs)) -# # input -# x = numpy_random(shape, dtype_str=dtype_str) -# # triton result -# z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) -# z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) -# x_tri = to_triton(x, device=device, dst_type=dtype_str) -# pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), -# z_tri, z_tri.stride(1), z_tri.stride(0), -# BLOCK_M=shape[0], BLOCK_N=shape[1]) -# pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0), -# z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1), -# BLOCK_M=shape[0], BLOCK_N=shape[1]) -# # numpy result -# z_ref = x.transpose(*perm) -# # compare -# triton.testing.assert_almost_equal(z_tri, z_ref) -# triton.testing.assert_almost_equal(z_tri_contiguous, z_ref) -# # parse ptx to make sure ld/st are vectorized -# ptx = pgm.asm['ptx'] -# assert 'ld.global.v4' in ptx -# assert 'st.global.v4' in ptx -# ptx = pgm_contiguous.asm['ptx'] -# assert 'ld.global.v4' in ptx -# assert 'st.global.v4' in ptx + # triton kernel + @triton.jit + def kernel(X, stride_xm, stride_xn, + Z, stride_zm, stride_zn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn + Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn + tl.store(Zs, tl.load(Xs)) + # input + x = numpy_random(shape, dtype_str=dtype_str) + # triton result + z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + x_tri = to_triton(x, device=device, dst_type=dtype_str) + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), + z_tri, z_tri.stride(1), z_tri.stride(0), + BLOCK_M=shape[0], BLOCK_N=shape[1]) + pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0), + z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1), + BLOCK_M=shape[0], BLOCK_N=shape[1]) + # numpy result + z_ref = x.transpose(*perm) + # compare + triton.testing.assert_almost_equal(z_tri, z_ref) + triton.testing.assert_almost_equal(z_tri_contiguous, z_ref) + # parse ptx to make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx + ptx = pgm_contiguous.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx # # --------------- # # test dot diff --git a/test/Conversion/tritongpu_to_llvm.mlir 
b/test/Conversion/tritongpu_to_llvm.mlir
index 234277bd7..a1a9392d7 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -153,7 +153,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
     %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
     // Store 4 elements to global with single one vectorized store instruction
-    // CHECK: @$5 st.global.b32.v4 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
+    // CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
     tt.store %13, %11 : tensor<256xf32, #blocked0>
     return
   }
@@ -222,8 +222,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
     %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
     // Store 8 elements to global with two vectorized store instruction
-    // CHECK: @$5 st.global.b32.v4 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
-    // CHECK: @$5 st.global.b32.v4 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
+    // CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
+    // CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
     tt.store %13, %11 : tensor<256xf32, #blocked0>
     return
   }
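
The user-visible effect of the backend changes above is that tl.load / tl.store on a plain pointer argument (no offset tensor) now lowers through the same LoadOp/StoreOp path as tensors, and arithmetic on the loaded scalar goes through the unified ElementwiseOp conversion instead of failing to match. A minimal sketch of a kernel exercising that path, in the spirit of the re-enabled test_tuples / test_store_bool tests; the kernel name, buffer values, and launch parameters here are illustrative, and a CUDA device with the Python frontend from test_core.py is assumed:

import torch
import triton
import triton.language as tl


@triton.jit
def scalar_add(X, Y, Z):
    # X, Y, Z arrive as plain pointers, so these are scalar LoadOp/StoreOp
    # (and a scalar add) in the IR rather than 1-element tensor ops.
    x = tl.load(X)
    y = tl.load(Y)
    tl.store(Z, x + y)


x = torch.tensor([1.3], device='cuda', dtype=torch.float32)
y = torch.tensor([1.9], device='cuda', dtype=torch.float32)
z = torch.zeros_like(x)
scalar_add[(1,)](x, y, z, num_warps=1)
assert torch.allclose(z, x + y)

On the C++ side this works because getElemsPerThread now dispatches on the Type (returning 1 for scalar and pointer types), and getStructFromElements / getElementsFromStruct degenerate to the identity for non-struct LLVM values, so the Load/Store/Elementwise conversions share one code path instead of bailing out on non-tensor result types.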