[Triton-MLIR] Generate LLVM/PTX code for async ops (#735)

Author: Keren Zhou (committed by GitHub)
Date: 2022-10-04 09:37:00 -07:00
Parent: f9d7f2f126
Commit: 289ff293cc
9 changed files with 412 additions and 57 deletions


@@ -141,11 +141,12 @@ PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs) {
std::string PTXInstrExecution::dump() const {
std::string osStr;
llvm::raw_string_ostream os(osStr);
if (pred)
if (pred) {
if (!pred->repr)
os << "@" << pred->dump() << " ";
else
os << pred->repr(pred->idx);
}
std::string instrRepr = strJoin(instr->instrParts, ".");
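
A minimal standalone sketch of the predicate-printing branch above, using toy stand-ins for Triton's Operand and PTXInstrExecution types (the names and rendered strings here are illustrative assumptions, not the real PTXBuilder output):

#include <functional>
#include <iostream>
#include <sstream>
#include <string>

// Toy operand: when `repr` is set, it overrides the default "@<operand> " guard.
struct ToyOperand {
  int idx;
  std::function<std::string(int)> repr;
  std::string dump() const { return "$" + std::to_string(idx); }
};

// Mirrors the `if (pred)` block in PTXInstrExecution::dump() above.
std::string dumpPred(const ToyOperand *pred) {
  std::ostringstream os;
  if (pred) {
    if (!pred->repr)
      os << "@" << pred->dump() << " ";
    else
      os << pred->repr(pred->idx);
  }
  return os.str();
}

int main() {
  ToyOperand plain{0, nullptr};
  ToyOperand negated{0, [](int i) { return "@!$" + std::to_string(i) + " "; }};
  std::cout << dumpPred(&plain) << "ld.global.b32 ...\n";  // @$0 ld.global.b32 ...
  std::cout << dumpPred(&negated) << "mov.u32 ...\n";      // @!$0 mov.u32 ...
  return 0;
}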


@@ -32,6 +32,7 @@ using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::getThreadsPerCTA;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -75,7 +76,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
#define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
#define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
#define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
#define bit_cast(...) rewriter.create<LLVM::BitcastOp>(loc, __VA_ARGS__)
#define bitcast(...) rewriter.create<LLVM::BitcastOp>(loc, __VA_ARGS__)
#define gep(...) rewriter.create<LLVM::GEPOp>(loc, __VA_ARGS__)
#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__)
#define insert_val(...) rewriter.create<LLVM::InsertValueOp>(loc, __VA_ARGS__)
@@ -86,6 +87,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
#define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
#define barrier rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
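// Note: each shorthand above expands against a `rewriter` and a `loc` that must already
// be in scope at the point of use; e.g. `bitcast(srcType, constVal)` below expands to
// `rewriter.create<LLVM::BitcastOp>(loc, srcType, constVal)`.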
@@ -630,7 +632,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
auto tensorTy = resType.cast<RankedTensorType>();
auto layout = tensorTy.getEncoding();
auto srcType = typeConverter->convertType(elemType);
auto llSrc = bit_cast(srcType, constVal);
auto llSrc = bitcast(srcType, constVal);
size_t elemsPerThread = getElemsPerThread(layout, tensorTy.getShape());
llvm::SmallVector<Value, 4> elems(elemsPerThread, llSrc);
llvm::SmallVector<Type, 4> elemTypes(elems.size(), srcType);
@@ -706,22 +708,14 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
// Get corresponding LLVM element values of \param value.
SmallVector<Value> getLLVMElems(Value value, Value llValue,
const BlockedEncodingAttr &layout,
TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
Location loc) const {
if (!value)
return {};
auto ty = value.getType().cast<RankedTensorType>();
auto shape = ty.getShape();
auto shape = value.getType().cast<RankedTensorType>().getShape();
// Here, we assume that all inputs have a blocked layout
unsigned valueElems = layout.getElemsPerThread(shape);
auto llvmElemTy = typeConverter->convertType(ty.getElementType());
auto llvmElemPtrPtrTy =
LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(llvmElemTy));
auto valueVals = getElementsFromStruct(loc, llValue, valueElems, rewriter);
return valueVals;
}
@@ -810,16 +804,13 @@ struct StoreOpConversion
auto [layout, numElems] = getLayout(ptr);
auto ptrElems =
getLLVMElems(ptr, llPtr, layout, getTypeConverter(), rewriter, loc);
auto valueElems =
getLLVMElems(value, llValue, layout, getTypeConverter(), rewriter, loc);
auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
auto valueElems = getLLVMElems(value, llValue, layout, rewriter, loc);
assert(ptrElems.size() == valueElems.size());
SmallVector<Value> maskElems;
if (llMask) {
maskElems =
getLLVMElems(mask, llMask, layout, getTypeConverter(), rewriter, loc);
maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
assert(valueElems.size() == maskElems.size());
}
@@ -848,7 +839,7 @@ struct StoreOpConversion
const bool hasL2EvictPolicy = false;
PTXBuilder ptxBuilder;
auto &ptxStoreInstr = *ptxBuilder.create<PtxIOInstr>("st");
auto &ptxStoreInstr = *ptxBuilder.create<PTXIOInstr>("st");
llvm::SmallVector<std::string> asmArgs;
@@ -866,7 +857,7 @@ struct StoreOpConversion
Value elem = valueElems[elemOffset];
if (elem.getType().isInteger(1))
elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
elem = bit_cast(valueElemTy, elem);
elem = bitcast(valueElemTy, elem);
Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
llWord =
@@ -874,7 +865,7 @@ struct StoreOpConversion
rewriter.create<LLVM::ConstantOp>(
loc, u32Ty, IntegerAttr::get(u32Ty, elemIdx)));
}
llWord = bit_cast(valArgTy, llWord);
llWord = bitcast(valArgTy, llWord);
std::string constraint =
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));
@@ -1100,14 +1091,12 @@ struct LoadOpConversion
auto [layout, numElems] = getLayout(ptr);
auto ptrElems =
getLLVMElems(ptr, llPtr, layout, getTypeConverter(), rewriter, loc);
auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
assert(ptrElems.size() == numElems);
SmallVector<Value> maskElems;
if (llMask) {
maskElems =
getLLVMElems(mask, llMask, layout, getTypeConverter(), rewriter, loc);
maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
assert(ptrElems.size() == maskElems.size());
}
@@ -1132,8 +1121,7 @@ struct LoadOpConversion
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
}
auto otherElems =
getLLVMElems(other, llOther, layout, getTypeConverter(), rewriter, loc);
auto otherElems = getLLVMElems(other, llOther, layout, rewriter, loc);
SmallVector<Value> loadedVals;
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
@@ -1153,7 +1141,7 @@ struct LoadOpConversion
const bool hasL2EvictPolicy = false;
PTXBuilder ptxBuilder;
auto &ld = *ptxBuilder.create<PtxIOInstr>("ld");
auto &ld = *ptxBuilder.create<PTXIOInstr>("ld");
// TODO(Superjomn) Do we need to check the masks before vectorizing the load, since all
// the values share one predicate? Here we assume all the mask values are
@@ -1198,7 +1186,6 @@ struct LoadOpConversion
else
ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");
SmallVector<Value> others;
if (other) {
for (size_t ii = 0; ii < nWords; ++ii) {
PTXInstr &mov = *ptxBuilder.create<>("mov");
@@ -1214,14 +1201,13 @@ struct LoadOpConversion
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
v = insert_element(vecTy, v, falseVal, sVal);
}
v = bit_cast(IntegerType::get(getContext(), width), v);
v = bitcast(IntegerType::get(getContext(), width), v);
PTXInstr::Operand *opr{};
if (otherIsSplatConstInt) {
opr = ptxBuilder.newConstantOperand(splatVal);
} else {
opr = ptxBuilder.newOperand(v, readConstraint);
others.push_back(v);
}
mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b");
@@ -1253,7 +1239,7 @@ struct LoadOpConversion
} else {
curr = ret;
}
curr = bit_cast(
curr = bitcast(
LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits),
curr);
rets.push_back(curr);
@@ -1360,9 +1346,8 @@ struct ExtractSliceOpConversion
// axis > 0 will result in non-contiguous memory access if the result tensor
// is an alias of the source tensor.
auto axis =
op->getAttrOfType<IntegerAttr>("axis").cast<IntegerAttr>().getInt();
assert(axis == 0 && "Only axis=0 is supported for now");
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
assert(axis == 0 && "extract_slice: Only axis=0 is supported for now");
// Example:
// %dst = extract_slice %src, %index {axis = 0}
@@ -1372,12 +1357,11 @@ struct ExtractSliceOpConversion
auto base = product<int64_t>(dstTy.getShape());
auto baseVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), base);
Value offset = rewriter.create<LLVM::MulOp>(loc, adaptor.index(), baseVal);
Value offset = mul(adaptor.index(), baseVal);
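// Illustrative numbers (hypothetical shapes): for a 32x32 %dst, base = 1024 elements,
// so %index = 2 advances the shared-memory pointer by 2048 elements.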
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
Value resultVal =
rewriter.create<LLVM::GEPOp>(loc, elemPtrTy, adaptor.src(), offset);
Value resultVal = gep(elemPtrTy, adaptor.src(), offset);
rewriter.replaceOp(op, resultVal);
return success();
}
@@ -1581,7 +1565,7 @@ void ConvertLayoutOpConversion::processReplica(
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
ptr = bit_cast(ptr_ty(vecTy, 3), ptr);
ptr = bitcast(ptr_ty(vecTy, 3), ptr);
if (stNotRd) {
Value valVec = undef(vecTy);
for (unsigned v = 0; v < vec; ++v) {
@@ -1614,7 +1598,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
smemBase = bit_cast(elemPtrTy, smemBase);
smemBase = bitcast(elemPtrTy, smemBase);
auto shape = dstTy.getShape();
unsigned rank = dstTy.getRank();
SmallVector<unsigned> numReplicates(rank);
@@ -1732,7 +1716,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
Value minVecVal = idx_val(minVec);
Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
smemBase = bit_cast(elemPtrTy, smemBase);
smemBase = bitcast(elemPtrTy, smemBase);
unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
SmallVector<Value> wordVecs(numWordsEachRep);
for (unsigned i = 0; i < numElems; ++i) {
@@ -1783,7 +1767,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
// step 3: store
Value smemAddr = gep(elemPtrTy, smemBase, offset);
smemAddr = bit_cast(ptr_ty(wordTy, 3), smemAddr);
smemAddr = bitcast(ptr_ty(wordTy, 3), smemAddr);
store(wordVecs[linearWordIdx], smemAddr);
}
}
@@ -2126,7 +2110,7 @@ public:
for (int e = 0; e < 4; ++e)
i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
i8Elems[m][e], i32_val(e));
i32Elems[m] = bit_cast(i32_ty, i8v4Elems[m]);
i32Elems[m] = bitcast(i32_ty, i8v4Elems[m]);
}
} else { // k first
Value offset = i32_val(sOffsetElem);
@@ -2144,7 +2128,7 @@ public:
for (int e = 0; e < 4; ++e)
i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
i8Elems[m][e], i32_val(e));
i32Elems[m] = bit_cast(i32_ty, i8v4Elems[m]);
i32Elems[m] = bitcast(i32_ty, i8v4Elems[m]);
}
}
@@ -2628,7 +2612,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
Type smemPtrTy = helper.getShemPtrTy();
for (int i = 0; i < numPtrs; ++i) {
ptrs[i] =
bit_cast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]})));
bitcast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]})));
}
bool needTrans = kOrder != order[0];
@@ -2777,6 +2761,229 @@ public:
}
};
struct AsyncWaitOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AsyncWaitOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::AsyncWaitOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::gpu::AsyncWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
PTXBuilder ptxBuilder;
auto &asyncWaitOp = *ptxBuilder.create<PTXCpAsyncWaitGroupInstr>();
auto num = op->getAttrOfType<IntegerAttr>("num").getInt();
asyncWaitOp(ptxBuilder.newConstantOperand(num));
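// A hedged sketch of the PTX this should launch (the exact rendering comes from
// PTXCpAsyncWaitGroupInstr, so treat the string as an assumption): for
// `triton_gpu.async_wait {num = 4 : i32}` the inline asm is essentially
//   cp.async.wait_group 4;
// which blocks until at most 4 of the committed cp.async groups are still pending.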
auto ctx = op.getContext();
auto loc = op.getLoc();
auto voidTy = LLVM::LLVMVoidType::get(ctx);
auto ret = ptxBuilder.launch(rewriter, loc, voidTy);
// Safe to remove the op since it doesn't have any return value.
rewriter.eraseOp(op);
return success();
}
};
struct InsertSliceAsyncOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>,
public LoadStoreConversionBase {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
InsertSliceAsyncOpConversion(LLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
AxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
converter, allocation, smem, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult
matchAndRewrite(triton::gpu::InsertSliceAsyncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// insert_slice_async %src, %dst, %index, %mask, %other
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.dst();
Value res = op.result();
Value mask = op.mask();
Value other = op.other();
assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
"Only support in-place insert_slice_async for now");
auto srcTy = src.getType().cast<RankedTensorType>();
auto resTy = dst.getType().cast<RankedTensorType>();
auto resElemTy = resTy.getElementType();
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto resSharedLayout = resTy.getEncoding().cast<SharedEncodingAttr>();
auto srcShape = srcTy.getShape();
assert(srcShape.size() == 2 &&
"insert_slice_async: Unexpected rank of %src");
Value llDst = adaptor.dst();
Value llSrc = adaptor.src();
Value llMask = adaptor.mask();
Value llOther = adaptor.other();
Value llIndex = adaptor.index();
// %src
auto srcElems = getLLVMElems(src, llSrc, srcBlockedLayout, rewriter, loc);
// %dst
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
assert(axis == 0 && "insert_slice_async: Only axis=0 is supported for now");
auto dstBase = createIndexAttrConstant(rewriter, loc,
getTypeConverter()->getIndexType(),
product<int64_t>(resTy.getShape()));
Value offset = mul(llIndex, dstBase);
auto dstPtrTy = LLVM::LLVMPointerType::get(
getTypeConverter()->convertType(resTy.getElementType()), 3);
Value dstPtrBase = gep(dstPtrTy, llDst, offset);
// %mask
SmallVector<Value> maskElems;
if (llMask) {
maskElems = getLLVMElems(mask, llMask, srcBlockedLayout, rewriter, loc);
assert(srcElems.size() == maskElems.size());
}
// %other
SmallVector<Value> otherElems;
if (llOther) {
// TODO(Keren): support "other" tensor.
// It's not necessary for now because the pipeline pass will skip
// generating insert_slice_async if the load op has any "other" tensor.
assert(false && "insert_slice_async: Other value not supported yet");
otherElems =
getLLVMElems(other, llOther, srcBlockedLayout, rewriter, loc);
assert(srcElems.size() == otherElems.size());
}
unsigned inVec = getVectorizeSize(src, srcBlockedLayout);
unsigned outVec = resSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
unsigned numElems = getElemsPerThread(srcBlockedLayout, srcShape);
unsigned perPhase = resSharedLayout.getPerPhase();
unsigned maxPhase = resSharedLayout.getMaxPhase();
auto sizePerThread = srcBlockedLayout.getSizePerThread();
auto threadsPerWarp = srcBlockedLayout.getThreadsPerWarp();
auto warpsPerCTA = srcBlockedLayout.getWarpsPerCTA();
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
auto inOrder = srcBlockedLayout.getOrder();
auto outOrder = resSharedLayout.getOrder();
// If perPhase * maxPhase > threadsPerCTA, we need to swizzle elements across
// phases.
// If perPhase * maxPhase == threadsPerCTA, swizzling is not allowed
auto numSwizzleRows = std::max<unsigned>(
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
// A sharedLayout encoding has a "vec" parameter.
// On the column dimension, if inVec > outVec, we have to split a single vector
// read into multiple reads
auto numVecCols = std::max<unsigned>(inVec / outVec, 1);
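// Illustrative numbers (assumed layout parameters): perPhase = 2, maxPhase = 8,
// threadsPerCTA[inOrder[1]] = 8 gives numSwizzleRows = (2 * 8) / 8 = 2; inVec = 4,
// outVec = 2 gives numVecCols = 4 / 2 = 2.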
auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcShape);
// <<tileVecIdxRow, tileVecIdxCol>, TileOffset>
DenseMap<std::pair<unsigned, unsigned>, Value> tileOffsetMap;
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
// minVec = 2, inVec = 4, outVec = 2
// baseOffsetCol = 0 baseOffsetCol = 0
// tileVecIdxCol = 0 tileVecIdxCol = 1
// -/\- -/\-
// [|x x| |x x| x x x x x]
// [|x x| |x x| x x x x x]
// baseOffsetRow [|x x| |x x| x x x x x]
// [|x x| |x x| x x x x x]
auto vecIdx = elemIdx / minVec;
auto vecIdxCol = vecIdx % (sizePerThread[inOrder[0]] / minVec);
auto vecIdxRow = vecIdx / (sizePerThread[inOrder[0]] / minVec);
auto baseOffsetCol =
vecIdxCol / numVecCols * numVecCols * threadsPerCTA[inOrder[0]];
auto baseOffsetRow = vecIdxRow / numSwizzleRows * numSwizzleRows *
threadsPerCTA[inOrder[1]];
auto baseOffset = (baseOffsetRow * srcShape[inOrder[0]] + baseOffsetCol);
auto tileVecIdxCol = vecIdxCol % numVecCols;
auto tileVecIdxRow = vecIdxRow % numSwizzleRows;
if (!tileOffsetMap.count({tileVecIdxRow, tileVecIdxCol})) {
// Swizzling
// Since the swizzling index is related to outVec, and we know minVec
// already, inVec doesn't matter
//
// (Numbers represent row indices)
// Example1:
// outVec = 2, inVec = 2, minVec = 2
// outVec = 2, inVec = 4, minVec = 2
// | [1 2] [3 4] ... [15 16] |
// | [3 4] [5 6] ... [1 2] |
// Example2:
// outVec = 4, inVec = 2, minVec = 2
// | [1 2 3 4] [5 6 7 8] ... [13 14 15 16] |
// | [5 6 7 8] [9 10 11 12] ... [1 2 3 4] |
auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]];
Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)),
i32_val(maxPhase));
Value rowOffset =
mul(srcIdx[inOrder[1]], i32_val(srcShape[inOrder[0]]));
Value colOffset =
add(srcIdx[inOrder[0]], i32_val(tileVecIdxCol * minVec));
Value swizzleIdx = udiv(colOffset, i32_val(outVec));
Value swizzleColOffset =
add(mul(xor_(swizzleIdx, phase), i32_val(outVec)),
urem(colOffset, i32_val(outVec)));
Value tileOffset = add(rowOffset, swizzleColOffset);
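// Worked example (assumed parameters): perPhase = 2, maxPhase = 4, outVec = 2,
// minVec = 2, srcShape[inOrder[0]] = 64, tileVecIdxCol = 0,
// srcIdx[inOrder[1]] = 5 (row), srcIdx[inOrder[0]] = 6 (col):
//   phase            = (5 / 2) % 4           = 2
//   rowOffset        = 5 * 64                = 320
//   colOffset        = 6 + 0 * 2             = 6
//   swizzleIdx       = 6 / 2                 = 3
//   swizzleColOffset = (3 ^ 2) * 2 + (6 % 2) = 2
//   tileOffset       = 320 + 2               = 322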
tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}] =
gep(dstPtrTy, dstPtrBase, tileOffset);
}
// cp.async copies at most 16 bytes (16 * 8 = 128 bits) at a time
auto maxBitWidth =
std::max<unsigned>(128, resElemTy.getIntOrFloatBitWidth());
auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec;
auto bitWidth = std::min<unsigned>(maxBitWidth, vecBitWidth);
auto numWords = vecBitWidth / bitWidth;
auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth();
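// Illustrative numbers (assumed f16 elements): minVec = 8 gives vecBitWidth = 128,
// bitWidth = min(128, 128) = 128, numWords = 1, numWordElems = 8; minVec = 16 gives
// vecBitWidth = 256, bitWidth = 128, numWords = 2, numWordElems = 8.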
// XXX(Keren): Tune CG and CA here.
CacheModifier srcCacheModifier =
bitWidth == 128 ? CacheModifier::CG : CacheModifier::CA;
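// .cg ("cache global") caches only in L2 and, per the PTX ISA, supports only 16-byte
// copies; .ca ("cache at all levels") also caches in L1 and allows 4/8/16-byte copies.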
assert(bitWidth == 128 || bitWidth == 64 || bitWidth == 32);
for (int wordIdx = 0; wordIdx < numWords; ++wordIdx) {
PTXBuilder ptxBuilder;
auto &copyAsyncOp = *ptxBuilder.create<PTXCpAsyncLoadInstr>(
srcCacheModifier, op.evict());
auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
auto *dstOperand =
ptxBuilder.newAddrOperand(tileOffset, "r", baseOffset);
auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[vecIdx], "l");
auto *copySize = ptxBuilder.newConstantOperand(bitWidth);
auto *srcSize = copySize;
if (op.mask()) {
// We don't use a predicate here; instead, when the mask is off we set src-size
// to 0. cp.async automatically fills the remaining slots with 0 whenever
// cp-size > src-size.
// XXX(Keren): Always assume other = 0 for now.
auto selectOp = select(maskElems[vecIdx + wordIdx * numWordElems],
i32_val(bitWidth), i32_val(0));
srcSize = ptxBuilder.newOperand(selectOp, "r");
}
copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
ptxBuilder.launch(rewriter, loc, LLVM::LLVMVoidType::get(getContext()));
}
}
PTXBuilder ptxBuilder;
ptxBuilder.create<PTXCpAsyncCommitGroupInstr>()->operator()();
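// This should emit `cp.async.commit_group;` (instruction name assumed from the PTX
// ISA), which closes the group of cp.async copies issued above so that
// AsyncWaitOpConversion's cp.async.wait_group can later wait on it.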
auto ret =
ptxBuilder.launch(rewriter, loc, LLVM::LLVMVoidType::get(getContext()));
rewriter.replaceOp(op, ret);
return success();
}
};
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
@@ -2786,6 +2993,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
patterns.add<AllocTensorOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
patterns.add<BinaryOpConversion<arith::AddIOp, LLVM::AddOp>>(typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::AddFOp, LLVM::FAddOp>>(typeConverter,
@@ -2800,6 +3008,8 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
axisInfoAnalysis, benefit);
patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);