From 289ff293ccc9e06200e68da9de9ff10f82e01bb2 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Tue, 4 Oct 2022 09:37:00 -0700 Subject: [PATCH] [Triton-MLIR] Generate LLVM/PTX code for async ops (#735) --- .../Conversion/TritonGPUToLLVM/PtxAsmFormat.h | 48 ++- .../Dialect/Triton/IR/TritonAttrDefs.td | 2 +- include/triton/Dialect/TritonGPU/IR/Dialect.h | 2 + .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 2 +- .../TritonGPUToLLVM/PtxAsmFormat.cpp | 3 +- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 298 +++++++++++++++--- lib/Dialect/TritonGPU/IR/Dialect.cpp | 15 + test/Conversion/tritongpu_to_llvm.mlir | 97 +++++- .../TritonGPUToLLVM/PtxAsmFormatTest.cpp | 2 +- 9 files changed, 412 insertions(+), 57 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h index c956eb899..ed051f522 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h @@ -2,6 +2,7 @@ #define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ #include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include @@ -99,8 +100,9 @@ struct PTXBuilder { std::string dump() const; }; - template INSTR *create(const std::string &name) { - instrs.emplace_back(std::make_unique(this, name)); + template + INSTR *create(Args &&...args) { + instrs.emplace_back(std::make_unique(this, args...)); return static_cast(instrs.back().get()); } @@ -188,6 +190,7 @@ struct PTXInstrCommon { using Operand = PTXBuilder::Operand; // clang-format off + PTXInstrExecution& operator()() { return call({}); } PTXInstrExecution& operator()(Operand* a) { return call({a}); } PTXInstrExecution& operator()(Operand* a, Operand* b) { return call({a, b}); } PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c) { return call({a, b, c}); } @@ -238,17 +241,17 @@ struct PTXInstr : public PTXInstrBase { // PtxIOInstr store("st"); // store.predicate(pValue).global().v(32).b(1); // @%0 st.global.v32.b1 // store.addAddr(addrValue, "l", off); -struct PtxIOInstr : public PTXInstrBase { - using PTXInstrBase::PTXInstrBase; +struct PTXIOInstr : public PTXInstrBase { + using PTXInstrBase::PTXInstrBase; // Add ".global" suffix to instruction - PtxIOInstr &global(bool predicate = true) { + PTXIOInstr &global(bool predicate = true) { o("global", predicate); return *this; } // Add ".v" suffix to instruction - PtxIOInstr &v(int vecWidth, bool predicate = true) { + PTXIOInstr &v(int vecWidth, bool predicate = true) { if (vecWidth > 1) { o("v" + std::to_string(vecWidth), predicate); } @@ -256,12 +259,43 @@ struct PtxIOInstr : public PTXInstrBase { } // Add ".b" suffix to instruction - PtxIOInstr &b(int width) { + PTXIOInstr &b(int width) { o("b" + std::to_string(width)); return *this; } }; +struct PTXCpAsyncInstrBase : public PTXInstrBase { + explicit PTXCpAsyncInstrBase(PTXBuilder *builder) + : PTXInstrBase(builder, "cp.async") {} +}; + +struct PTXCpAsyncCommitGroupInstr : public PTXCpAsyncInstrBase { + explicit PTXCpAsyncCommitGroupInstr(PTXBuilder *builder) + : PTXCpAsyncInstrBase(builder) { + o("commit_group"); + } +}; + +struct PTXCpAsyncWaitGroupInstr : public PTXCpAsyncInstrBase { + explicit PTXCpAsyncWaitGroupInstr(PTXBuilder *builder) + : PTXCpAsyncInstrBase(builder) { + o("wait_group"); + } +}; + +struct PTXCpAsyncLoadInstr : public PTXCpAsyncInstrBase { + explicit PTXCpAsyncLoadInstr(PTXBuilder *builder, + triton::CacheModifier 
modifier, + triton::EvictionPolicy policy) + : PTXCpAsyncInstrBase(builder) { + o(triton::stringifyCacheModifier(modifier).str()); + o("shared"); + o("global"); + o("L2::" + triton::stringifyEvictionPolicy(policy).str()); + } +}; + // Record the operands and context for "launching" a PtxInstr. struct PTXInstrExecution { using Operand = PTXBuilder::Operand; diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td index ac50d2d1a..cd4423041 100644 --- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td +++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -16,7 +16,7 @@ def TT_CacheModifierAttr : I32EnumAttr< def TT_EvictionPolicyAttr : I32EnumAttr< "EvictionPolicy", "", [ - I32EnumAttrCase<"NORMAL", 1, "normal">, + I32EnumAttrCase<"NORMAL", 1, "evict_normal">, I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">, I32EnumAttrCase<"EVICT_LAST", 3, "evict_last"> ]> { diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index b9aa13cae..cc62c0f3a 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -24,6 +24,8 @@ unsigned getElemsPerThread(Attribute layout, ArrayRef shape); SmallVector getSizePerThread(Attribute layout); +SmallVector getThreadsPerCTA(const Attribute &layout); + SmallVector getShapePerCTA(const Attribute &layout); SmallVector getOrder(const Attribute &layout); diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 5edd5a51e..930b35dac 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -54,7 +54,7 @@ in memory. For example, a swizzled row-major layout could store its data as follows: A_{0, 0} A_{0, 1} A_{0, 2} A_{0, 3} ... [phase 0] \ per_phase = 2 -A_{1, 0} A_{0, 1} A_{1, 2} A_{1, 3} ... [phase 0] / +A_{1, 0} A_{1, 1} A_{1, 2} A_{1, 3} ... [phase 0] / groups of vec=2 elements are stored contiguously _ _ _ _ /\_ _ _ _ diff --git a/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp b/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp index 5827b7301..2fb8c826d 100644 --- a/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp +++ b/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp @@ -141,11 +141,12 @@ PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef oprs) { std::string PTXInstrExecution::dump() const { std::string osStr; llvm::raw_string_ostream os(osStr); - if (pred) + if (pred) { if (!pred->repr) os << "@" << pred->dump() << " "; else os << pred->repr(pred->idx); + } std::string instrRepr = strJoin(instr->instrParts, "."); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index be3a48314..d70763bbe 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -32,6 +32,7 @@ using ::mlir::triton::gpu::getElemsPerThread; using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getSizePerThread; +using ::mlir::triton::gpu::getThreadsPerCTA; using ::mlir::triton::gpu::MmaEncodingAttr; using ::mlir::triton::gpu::SharedEncodingAttr; using ::mlir::triton::gpu::SliceEncodingAttr; @@ -75,7 +76,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, #define add(...) rewriter.create(loc, __VA_ARGS__) #define mul(...) 
rewriter.create(loc, __VA_ARGS__) #define xor_(...) rewriter.create(loc, __VA_ARGS__) -#define bit_cast(...) rewriter.create(loc, __VA_ARGS__) +#define bitcast(...) rewriter.create(loc, __VA_ARGS__) #define gep(...) rewriter.create(loc, __VA_ARGS__) #define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__) #define insert_val(...) rewriter.create(loc, __VA_ARGS__) @@ -86,6 +87,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, rewriter.create(loc, __VA_ARGS__) #define load(...) rewriter.create(loc, __VA_ARGS__) #define store(val, ptr) rewriter.create(loc, val, ptr) +#define select(...) rewriter.create(loc, __VA_ARGS__) #define address_of(...) rewriter.create(loc, __VA_ARGS__) #define barrier rewriter.create(loc) #define undef(...) rewriter.create(loc, __VA_ARGS__) @@ -630,7 +632,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, auto tensorTy = resType.cast(); auto layout = tensorTy.getEncoding(); auto srcType = typeConverter->convertType(elemType); - auto llSrc = bit_cast(srcType, constVal); + auto llSrc = bitcast(srcType, constVal); size_t elemsPerThread = getElemsPerThread(layout, tensorTy.getShape()); llvm::SmallVector elems(elemsPerThread, llSrc); llvm::SmallVector elemTypes(elems.size(), srcType); @@ -706,22 +708,14 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase { // Get corresponding LLVM element values of \param value. SmallVector getLLVMElems(Value value, Value llValue, const BlockedEncodingAttr &layout, - TypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc) const { if (!value) return {}; - auto ty = value.getType().cast(); - auto shape = ty.getShape(); + auto shape = value.getType().cast().getShape(); // Here, we assume that all inputs should have a blockedLayout - unsigned valueElems = layout.getElemsPerThread(shape); - - auto llvmElemTy = typeConverter->convertType(ty.getElementType()); - auto llvmElemPtrPtrTy = - LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(llvmElemTy)); - auto valueVals = getElementsFromStruct(loc, llValue, valueElems, rewriter); return valueVals; } @@ -810,16 +804,13 @@ struct StoreOpConversion auto [layout, numElems] = getLayout(ptr); - auto ptrElems = - getLLVMElems(ptr, llPtr, layout, getTypeConverter(), rewriter, loc); - auto valueElems = - getLLVMElems(value, llValue, layout, getTypeConverter(), rewriter, loc); + auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc); + auto valueElems = getLLVMElems(value, llValue, layout, rewriter, loc); assert(ptrElems.size() == valueElems.size()); SmallVector maskElems; if (llMask) { - maskElems = - getLLVMElems(mask, llMask, layout, getTypeConverter(), rewriter, loc); + maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc); assert(valueElems.size() == maskElems.size()); } @@ -848,7 +839,7 @@ struct StoreOpConversion const bool hasL2EvictPolicy = false; PTXBuilder ptxBuilder; - auto &ptxStoreInstr = *ptxBuilder.create("st"); + auto &ptxStoreInstr = *ptxBuilder.create("st"); llvm::SmallVector asmArgs; @@ -866,7 +857,7 @@ struct StoreOpConversion Value elem = valueElems[elemOffset]; if (elem.getType().isInteger(1)) elem = rewriter.create(loc, type::i8Ty(ctx), elem); - elem = bit_cast(valueElemTy, elem); + elem = bitcast(valueElemTy, elem); Type u32Ty = typeConverter->convertType(type::u32Ty(ctx)); llWord = @@ -874,7 +865,7 @@ struct StoreOpConversion rewriter.create( loc, u32Ty, IntegerAttr::get(u32Ty, elemIdx))); } - llWord = bit_cast(valArgTy, llWord); + llWord = 
bitcast(valArgTy, llWord); std::string constraint = (width == 64) ? "l" : ((width == 32) ? "r" : "c"); asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint)); @@ -1100,14 +1091,12 @@ struct LoadOpConversion auto [layout, numElems] = getLayout(ptr); - auto ptrElems = - getLLVMElems(ptr, llPtr, layout, getTypeConverter(), rewriter, loc); + auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc); assert(ptrElems.size() == numElems); SmallVector maskElems; if (llMask) { - maskElems = - getLLVMElems(mask, llMask, layout, getTypeConverter(), rewriter, loc); + maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc); assert(ptrElems.size() == maskElems.size()); } @@ -1132,8 +1121,7 @@ struct LoadOpConversion splatVal = constAttr.getSplatValue().getSExtValue(); } - auto otherElems = - getLLVMElems(other, llOther, layout, getTypeConverter(), rewriter, loc); + auto otherElems = getLLVMElems(other, llOther, layout, rewriter, loc); SmallVector loadedVals; for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { @@ -1153,7 +1141,7 @@ struct LoadOpConversion const bool hasL2EvictPolicy = false; PTXBuilder ptxBuilder; - auto &ld = *ptxBuilder.create("ld"); + auto &ld = *ptxBuilder.create("ld"); // TODO(Superjomn) Need to check masks before vectorize the load for all // the values share one predicate? Here assume all the mask values are @@ -1198,7 +1186,6 @@ struct LoadOpConversion else ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b"); - SmallVector others; if (other) { for (size_t ii = 0; ii < nWords; ++ii) { PTXInstr &mov = *ptxBuilder.create<>("mov"); @@ -1214,14 +1201,13 @@ struct LoadOpConversion rewriter, loc, this->getTypeConverter()->getIndexType(), s); v = insert_element(vecTy, v, falseVal, sVal); } - v = bit_cast(IntegerType::get(getContext(), width), v); + v = bitcast(IntegerType::get(getContext(), width), v); PTXInstr::Operand *opr{}; if (otherIsSplatConstInt) { opr = ptxBuilder.newConstantOperand(splatVal); } else { opr = ptxBuilder.newOperand(v, readConstraint); - others.push_back(v); } mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b"); @@ -1253,7 +1239,7 @@ struct LoadOpConversion } else { curr = ret; } - curr = bit_cast( + curr = bitcast( LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits), curr); rets.push_back(curr); @@ -1360,9 +1346,8 @@ struct ExtractSliceOpConversion // axis > 0 will result in non-contiguous memory access if the result tensor // is an alias of the source tensor. 
- auto axis = - op->getAttrOfType("axis").cast().getInt(); - assert(axis == 0 && "Only axis=0 is supported for now"); + auto axis = op->getAttrOfType("axis").getInt(); + assert(axis == 0 && "extract_slice: Only axis=0 is supported for now"); // Example: // %dst = extract_slice %src, %index {axis = 0} @@ -1372,12 +1357,11 @@ struct ExtractSliceOpConversion auto base = product(dstTy.getShape()); auto baseVal = createIndexAttrConstant( rewriter, loc, getTypeConverter()->getIndexType(), base); - Value offset = rewriter.create(loc, adaptor.index(), baseVal); + Value offset = mul(adaptor.index(), baseVal); auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType()); auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); - Value resultVal = - rewriter.create(loc, elemPtrTy, adaptor.src(), offset); + Value resultVal = gep(elemPtrTy, adaptor.src(), offset); rewriter.replaceOp(op, resultVal); return success(); } @@ -1581,7 +1565,7 @@ void ConvertLayoutOpConversion::processReplica( auto elemPtrTy = ptr_ty(llvmElemTy, 3); Value ptr = gep(elemPtrTy, smemBase, offset); auto vecTy = vec_ty(llvmElemTy, vec); - ptr = bit_cast(ptr_ty(vecTy, 3), ptr); + ptr = bitcast(ptr_ty(vecTy, 3), ptr); if (stNotRd) { Value valVec = undef(vecTy); for (unsigned v = 0; v < vec; ++v) { @@ -1614,7 +1598,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed( auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType()); Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); auto elemPtrTy = ptr_ty(llvmElemTy, 3); - smemBase = bit_cast(elemPtrTy, smemBase); + smemBase = bitcast(elemPtrTy, smemBase); auto shape = dstTy.getShape(); unsigned rank = dstTy.getRank(); SmallVector numReplicates(rank); @@ -1732,7 +1716,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared( Value minVecVal = idx_val(minVec); Value smemBase = getSharedMemoryBase(loc, rewriter, dst); auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3); - smemBase = bit_cast(elemPtrTy, smemBase); + smemBase = bitcast(elemPtrTy, smemBase); unsigned numWordsEachRep = product(wordsInEachRep); SmallVector wordVecs(numWordsEachRep); for (unsigned i = 0; i < numElems; ++i) { @@ -1783,7 +1767,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared( // step 3: store Value smemAddr = gep(elemPtrTy, smemBase, offset); - smemAddr = bit_cast(ptr_ty(wordTy, 3), smemAddr); + smemAddr = bitcast(ptr_ty(wordTy, 3), smemAddr); store(wordVecs[linearWordIdx], smemAddr); } } @@ -2126,7 +2110,7 @@ public: for (int e = 0; e < 4; ++e) i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m], i8Elems[m][e], i32_val(e)); - i32Elems[m] = bit_cast(i32_ty, i8v4Elems[m]); + i32Elems[m] = bitcast(i32_ty, i8v4Elems[m]); } } else { // k first Value offset = i32_val(sOffsetElem); @@ -2144,7 +2128,7 @@ public: for (int e = 0; e < 4; ++e) i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m], i8Elems[m][e], i32_val(e)); - i32Elems[m] = bit_cast(i32_ty, i8v4Elems[m]); + i32Elems[m] = bitcast(i32_ty, i8v4Elems[m]); } } @@ -2628,7 +2612,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter, Type smemPtrTy = helper.getShemPtrTy(); for (int i = 0; i < numPtrs; ++i) { ptrs[i] = - bit_cast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]}))); + bitcast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]}))); } bool needTrans = kOrder != order[0]; @@ -2777,6 +2761,229 @@ public: } }; +struct AsyncWaitOpConversion + : public 
ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::gpu::AsyncWaitOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::AsyncWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + PTXBuilder ptxBuilder; + auto &asyncWaitOp = *ptxBuilder.create(); + auto num = op->getAttrOfType("num").getInt(); + asyncWaitOp(ptxBuilder.newConstantOperand(num)); + + auto ctx = op.getContext(); + auto loc = op.getLoc(); + auto voidTy = LLVM::LLVMVoidType::get(ctx); + auto ret = ptxBuilder.launch(rewriter, loc, voidTy); + + // Safe to remove the op since it doesn't have any return value. + rewriter.eraseOp(op); + return success(); + } +}; + +struct InsertSliceAsyncOpConversion + : public ConvertTritonGPUOpToLLVMPattern, + public LoadStoreConversionBase { + using ConvertTritonGPUOpToLLVMPattern< + triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern; + + InsertSliceAsyncOpConversion(LLVMTypeConverter &converter, + const Allocation *allocation, Value smem, + AxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertTritonGPUOpToLLVMPattern( + converter, allocation, smem, benefit), + LoadStoreConversionBase(axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::gpu::InsertSliceAsyncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // insert_slice_async %src, %dst, %index, %mask, %other + auto loc = op.getLoc(); + Value src = op.src(); + Value dst = op.dst(); + Value res = op.result(); + Value mask = op.mask(); + Value other = op.other(); + assert(allocation->getBufferId(res) == Allocation::InvalidBufferId && + "Only support in-place insert_slice_async for now"); + + auto srcTy = src.getType().cast(); + auto resTy = dst.getType().cast(); + auto resElemTy = resTy.getElementType(); + auto srcBlockedLayout = srcTy.getEncoding().cast(); + auto resSharedLayout = resTy.getEncoding().cast(); + auto srcShape = srcTy.getShape(); + assert(srcShape.size() == 2 && + "insert_slice_async: Unexpected rank of %src"); + + Value llDst = adaptor.dst(); + Value llSrc = adaptor.src(); + Value llMask = adaptor.mask(); + Value llOther = adaptor.other(); + Value llIndex = adaptor.index(); + + // %src + auto srcElems = getLLVMElems(src, llSrc, srcBlockedLayout, rewriter, loc); + + // %dst + auto axis = op->getAttrOfType("axis").getInt(); + assert(axis == 0 && "insert_slice_async: Only axis=0 is supported for now"); + auto dstBase = createIndexAttrConstant(rewriter, loc, + getTypeConverter()->getIndexType(), + product(resTy.getShape())); + Value offset = mul(llIndex, dstBase); + auto dstPtrTy = LLVM::LLVMPointerType::get( + getTypeConverter()->convertType(resTy.getElementType()), 3); + Value dstPtrBase = gep(dstPtrTy, llDst, offset); + + // %mask + SmallVector maskElems; + if (llMask) { + maskElems = getLLVMElems(mask, llMask, srcBlockedLayout, rewriter, loc); + assert(srcElems.size() == maskElems.size()); + } + + // %other + SmallVector otherElems; + if (llOther) { + // TODO(Keren): support "other" tensor. + // It's not necessary for now because the pipeline pass will skip + // generating insert_slice_async if the load op has any "other" tensor. 
+ assert(false && "insert_slice_async: Other value not supported yet"); + otherElems = + getLLVMElems(other, llOther, srcBlockedLayout, rewriter, loc); + assert(srcElems.size() == otherElems.size()); + } + + unsigned inVec = getVectorizeSize(src, srcBlockedLayout); + unsigned outVec = resSharedLayout.getVec(); + unsigned minVec = std::min(outVec, inVec); + unsigned numElems = getElemsPerThread(srcBlockedLayout, srcShape); + unsigned perPhase = resSharedLayout.getPerPhase(); + unsigned maxPhase = resSharedLayout.getMaxPhase(); + auto sizePerThread = srcBlockedLayout.getSizePerThread(); + auto threadsPerWarp = srcBlockedLayout.getThreadsPerWarp(); + auto warpsPerCTA = srcBlockedLayout.getWarpsPerCTA(); + auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout); + + auto inOrder = srcBlockedLayout.getOrder(); + auto outOrder = resSharedLayout.getOrder(); + // If perPhase * maxPhase > threadsPerCTA, we need to swizzle over elements + // across phases. + // If perPhase * maxPhase == threadsPerCTA, swizzle is not allowd + auto numSwizzleRows = std::max( + (perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1); + // A sharedLayout encoding has a "vec" parameter. + // On the column dimension, if inVec > outVec, it means we have to divide + // single vector read into multiple ones + auto numVecCols = std::max(inVec / outVec, 1); + + auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcShape); + // <, TileOffset> + DenseMap, Value> tileOffsetMap; + for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) { + // minVec = 2, inVec = 4, outVec = 2 + // baseOffsetCol = 0 baseOffsetCol = 0 + // tileVecIdxCol = 0 tileVecIdxCol = 1 + // -/\- -/\- + // [|x x| |x x| x x x x x] + // [|x x| |x x| x x x x x] + // baseOffsetRow [|x x| |x x| x x x x x] + // [|x x| |x x| x x x x x] + auto vecIdx = elemIdx / minVec; + auto vecIdxCol = vecIdx % (sizePerThread[inOrder[0]] / minVec); + auto vecIdxRow = vecIdx / (sizePerThread[inOrder[0]] / minVec); + auto baseOffsetCol = + vecIdxCol / numVecCols * numVecCols * threadsPerCTA[inOrder[0]]; + auto baseOffsetRow = vecIdxRow / numSwizzleRows * numSwizzleRows * + threadsPerCTA[inOrder[1]]; + auto baseOffset = (baseOffsetRow * srcShape[inOrder[0]] + baseOffsetCol); + auto tileVecIdxCol = vecIdxCol % numVecCols; + auto tileVecIdxRow = vecIdxRow % numSwizzleRows; + + if (!tileOffsetMap.count({tileVecIdxRow, tileVecIdxCol})) { + // Swizzling + // Since the swizzling index is related to outVec, and we know minVec + // already, inVec doesn't matter + // + // (Numbers represent row indices) + // Example1: + // outVec = 2, inVec = 2, minVec = 2 + // outVec = 2, inVec = 4, minVec = 2 + // | [1 2] [3 4] ... [15 16] | + // | [3 4] [5 6] ... [1 2] | + // Example2: + // outVec = 4, inVec = 2, minVec = 2 + // | [1 2 3 4] [5 6 7 8] ... [13 14 15 16] | + // | [5 6 7 8] [9 10 11 12] ... 
[1 2 3 4] | + auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]]; + Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)), + i32_val(maxPhase)); + Value rowOffset = + mul(srcIdx[inOrder[1]], i32_val(srcShape[inOrder[0]])); + Value colOffset = + add(srcIdx[inOrder[0]], i32_val(tileVecIdxCol * minVec)); + Value swizzleIdx = udiv(colOffset, i32_val(outVec)); + Value swizzleColOffset = + add(mul(xor_(swizzleIdx, phase), i32_val(outVec)), + urem(colOffset, i32_val(outVec))); + Value tileOffset = add(rowOffset, swizzleColOffset); + tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}] = + gep(dstPtrTy, dstPtrBase, tileOffset); + } + + // 16 * 8 = 128bits + auto maxBitWidth = + std::max(128, resElemTy.getIntOrFloatBitWidth()); + auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec; + auto bitWidth = std::min(maxBitWidth, vecBitWidth); + auto numWords = vecBitWidth / bitWidth; + auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth(); + + // XXX(Keren): Tune CG and CA here. + CacheModifier srcCacheModifier = + bitWidth == 128 ? CacheModifier::CG : CacheModifier::CA; + assert(bitWidth == 128 || bitWidth == 64 || bitWidth == 32); + + for (int wordIdx = 0; wordIdx < numWords; ++wordIdx) { + PTXBuilder ptxBuilder; + auto ©AsyncOp = *ptxBuilder.create( + srcCacheModifier, op.evict()); + + auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}]; + auto *dstOperand = + ptxBuilder.newAddrOperand(tileOffset, "r", baseOffset); + auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[vecIdx], "l"); + auto *copySize = ptxBuilder.newConstantOperand(bitWidth); + auto *srcSize = copySize; + if (op.mask()) { + // We don't use predicate in this case, setting src-size to 0 + // if there's any mask. cp.async will automatically fill the + // remaining slots with 0 if cp-size > src-size. + // XXX(Keren): Always assume other = 0 for now. 
+ auto selectOp = select(maskElems[vecIdx + wordIdx * numWordElems], + i32_val(bitWidth), i32_val(0)); + srcSize = ptxBuilder.newOperand(selectOp, "r"); + } + copyAsyncOp(dstOperand, srcOperand, copySize, srcSize); + ptxBuilder.launch(rewriter, loc, LLVM::LLVMVoidType::get(getContext())); + } + } + + PTXBuilder ptxBuilder; + ptxBuilder.create()->operator()(); + auto ret = + ptxBuilder.launch(rewriter, loc, LLVM::LLVMVoidType::get(getContext())); + rewriter.replaceOp(op, ret); + return success(); + } +}; + void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, AxisInfoAnalysis &axisInfoAnalysis, @@ -2786,6 +2993,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, patterns.add(typeConverter, allocation, smem, benefit); patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); patterns.add>(typeConverter, benefit); patterns.add>(typeConverter, @@ -2800,6 +3008,8 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, patterns.add(typeConverter, allocation, smem, benefit); patterns.add(typeConverter, benefit); + patterns.add(typeConverter, allocation, smem, + axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index d531c0fca..4ee053b19 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -72,6 +72,21 @@ SmallVector getSizePerThread(Attribute layout) { } } +SmallVector getThreadsPerCTA(const Attribute &layout) { + SmallVector threads; + if (auto blockedLayout = layout.dyn_cast()) { + for (int d = 0, n = blockedLayout.getOrder().size(); d < n; ++d) + threads.push_back(blockedLayout.getThreadsPerWarp()[d] * + blockedLayout.getWarpsPerCTA()[d]); + } else if (auto mmaLayout = layout.dyn_cast()) { + assert(0 && "Unimplemented usage of MmaEncodingAttr"); + } else { + assert(0 && "Unimplemented usage of getShapePerCTA"); + } + + return threads; +} + SmallVector getShapePerCTA(const Attribute &layout) { SmallVector shape; if (auto blockedLayout = layout.dyn_cast()) { diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index f54a8d983..f22371ee8 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -333,6 +333,99 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_async_wait + func @basic_async_wait() { + // CHECK: cp.async.wait_group 0x4 + triton_gpu.async_wait {num = 4: i32} + return + } +} + +// ----- + +#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}> +#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}> +#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}> +#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#A = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_insert_slice_async_v4 + func 
@basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
+    %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #block0>
+    %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #block1>
+    %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #block0>) -> tensor<16x1xi32, #block2>
+    %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #block1>) -> tensor<1x64xi32, #block3>
+    %broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x64xi32, #block2>
+    %cst_scalar = arith.constant 64 : i32
+    %cst = tt.splat %cst_scalar : (i32) -> tensor<16x64xi32, #block2>
+    %broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x64xi32, #block2>
+    %broadcast_off1_ = tt.broadcast %off1 : (tensor<1x64xi32, #block3>) -> tensor<16x64xi32, #block3>
+    %broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : (tensor<16x64xi32, #block2>) -> tensor<16x64xi32, #AL>
+    %broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x64xi32, #block3>) -> tensor<16x64xi32, #AL>
+    %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL>
+    %a_init = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<16x64x!tt.ptr<f32>, #AL>
+    %a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f32>, #AL>
+    %tensor = triton_gpu.alloc_tensor : tensor<2x16x64xf32, #A>
+    %index = arith.constant 1 : i32
+
+    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+    // CHECK-SAME: cp.async.cg.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x80, 0x80
+    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+    // CHECK-SAME: cp.async.cg.shared.global.L2::evict_normal [ ${{.*}} + 8 ], [ ${{.*}} + 0 ], 0x80, 0x80
+    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+    // CHECK-SAME: cp.async.commit_group
+    %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f32>, #AL> -> tensor<2x16x64xf32, #A>
+    return
+  }
+}
+
+// -----
+
+#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
+#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
+#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
+#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: basic_insert_slice_async_v1
+  func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
+    %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #block0>
+    %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block1>
+    %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #block0>) -> tensor<16x1xi32, #block2>
+    %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #block1>) -> tensor<1x32xi32, #block3>
+    %broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x32xi32, #block2>
+    %cst_scalar = arith.constant 32 : i32
+    %cst = tt.splat %cst_scalar : (i32) -> tensor<16x32xi32, #block2>
+    %broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x32xi32, #block2>
+    %broadcast_off1_ = tt.broadcast %off1 : (tensor<1x32xi32, #block3>) -> tensor<16x32xi32, #block3>
+    %broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : (tensor<16x32xi32, #block2>) -> tensor<16x32xi32, #AL>
+    %broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x32xi32, #block3>) -> tensor<16x32xi32, #AL>
+    %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x32xi32, #AL>
+    %a_init = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<16x32x!tt.ptr<f32>, #AL>
+    %a_ptr = tt.addptr %a_init, %off : tensor<16x32x!tt.ptr<f32>, #AL>
+    %tensor = triton_gpu.alloc_tensor : tensor<2x16x32xf32, #A>
+    %index = arith.constant 1 : i32
+
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: cp.async.commit_group
+    %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x32x!tt.ptr<f32>, #AL> -> tensor<2x16x32xf32, #A>
+    return
+  }
+}
+
+// -----
+
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK: basic_splat
@@ -351,9 +444,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: basic_store
   func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
-    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+    // CHECK: llvm.inline_asm
     // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
-    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+    // CHECK: llvm.inline_asm
     // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
     tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
     return
diff --git a/unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp b/unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp
index 9668a92ca..efce89a67 100644
--- a/unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp
+++ b/unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp
@@ -76,7 +76,7 @@ TEST_F(PtxAsmFormatTest, complexInstruction) {
   auto &ld =
       builder
-          .create<PtxIOInstr>("ld")  //
+          .create<PTXIOInstr>("ld") //
           ->o("volatile", isVolatile)
           .global()
           .o("ca", cache == CacheModifier::CA)
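
Illustrative usage sketch (not part of the patch): the snippet below shows how the cp.async helpers introduced in PtxAsmFormat.h compose into the PTX this change emits, assuming it lived inside TritonGPUToLLVM.cpp where PTXBuilder, the MLIR conversion types, and the triton enums are already in scope; the helper name emitCpAsyncCopyExample is hypothetical.

// Emit one 128-bit cp.async copy, then commit and wait on the group.
static void emitCpAsyncCopyExample(ConversionPatternRewriter &rewriter,
                                   Location loc, MLIRContext *ctx,
                                   Value dstPtr, Value srcPtr) {
  auto voidTy = LLVM::LLVMVoidType::get(ctx);

  // Roughly: cp.async.cg.shared.global.L2::evict_normal [dst+0], [src+0], 0x80, 0x80;
  PTXBuilder copyBuilder;
  auto &copyOp = *copyBuilder.create<PTXCpAsyncLoadInstr>(
      triton::CacheModifier::CG, triton::EvictionPolicy::NORMAL);
  auto *dstOpr = copyBuilder.newAddrOperand(dstPtr, "r");
  auto *srcOpr = copyBuilder.newAddrOperand(srcPtr, "l");
  // Same convention as InsertSliceAsyncOpConversion above: pass the transfer
  // width once as cp-size and once as src-size when no mask shrinks the copy.
  auto *copySize = copyBuilder.newConstantOperand(128);
  copyOp(dstOpr, srcOpr, copySize, copySize);
  copyBuilder.launch(rewriter, loc, voidTy);

  // cp.async.commit_group; closes the group of outstanding copies.
  PTXBuilder commitBuilder;
  commitBuilder.create<PTXCpAsyncCommitGroupInstr>()->operator()();
  commitBuilder.launch(rewriter, loc, voidTy);

  // cp.async.wait_group 0; blocks until all committed groups complete,
  // mirroring what AsyncWaitOpConversion emits for triton_gpu.async_wait.
  PTXBuilder waitBuilder;
  auto &waitOp = *waitBuilder.create<PTXCpAsyncWaitGroupInstr>();
  waitOp(waitBuilder.newConstantOperand(0));
  waitBuilder.launch(rewriter, loc, voidTy);
}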