[Triton-MLIR] Generate LLVM/PTX code for async ops (#735)
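This change adds LLVM/PTX lowering for the Triton GPU async ops: AsyncWaitOpConversion emits a cp.async.wait_group instruction through PTXBuilder, and InsertSliceAsyncOpConversion lowers insert_slice_async into per-word cp.async copies to swizzled shared-memory addresses, followed by a cp.async.commit_group. It also renames the bit_cast convenience macro to bitcast, drops the TypeConverter argument from getLLVMElems, and adds a getThreadsPerCTA layout helper. Roughly, and only as a hand-written sketch (the addresses, the 16-byte copy size, and the wait depth below are placeholders, not taken from this diff), the PTX emitted per vectorized word has this shape:

    // one cp.async per word; .cg is chosen for 128-bit copies, .ca otherwise
    cp.async.cg.shared.global [ %smem_addr ], [ %global_ptr ], 16, %src_size;
    // after every word of the slice has been issued:
    cp.async.commit_group;
    // later, lowered from the async_wait op; N comes from its "num" attribute:
    cp.async.wait_group N;
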
@@ -141,11 +141,12 @@ PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs) {
std::string PTXInstrExecution::dump() const {
std::string osStr;
llvm::raw_string_ostream os(osStr);
if (pred)
if (pred) {
if (!pred->repr)
os << "@" << pred->dump() << " ";
else
os << pred->repr(pred->idx);
}

std::string instrRepr = strJoin(instr->instrParts, ".");

@@ -32,6 +32,7 @@ using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::getThreadsPerCTA;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;

@@ -75,7 +76,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
#define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
#define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
#define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
#define bit_cast(...) rewriter.create<LLVM::BitcastOp>(loc, __VA_ARGS__)
#define bitcast(...) rewriter.create<LLVM::BitcastOp>(loc, __VA_ARGS__)
#define gep(...) rewriter.create<LLVM::GEPOp>(loc, __VA_ARGS__)
#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__)
#define insert_val(...) rewriter.create<LLVM::InsertValueOp>(loc, __VA_ARGS__)

@@ -86,6 +87,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
#define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
#define barrier rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)

@@ -630,7 +632,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
auto tensorTy = resType.cast<RankedTensorType>();
auto layout = tensorTy.getEncoding();
auto srcType = typeConverter->convertType(elemType);
auto llSrc = bit_cast(srcType, constVal);
auto llSrc = bitcast(srcType, constVal);
size_t elemsPerThread = getElemsPerThread(layout, tensorTy.getShape());
llvm::SmallVector<Value, 4> elems(elemsPerThread, llSrc);
llvm::SmallVector<Type, 4> elemTypes(elems.size(), srcType);

@@ -706,22 +708,14 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
// Get corresponding LLVM element values of \param value.
SmallVector<Value> getLLVMElems(Value value, Value llValue,
const BlockedEncodingAttr &layout,
TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
Location loc) const {
if (!value)
return {};

auto ty = value.getType().cast<RankedTensorType>();
auto shape = ty.getShape();
auto shape = value.getType().cast<RankedTensorType>().getShape();
// Here, we assume that all inputs should have a blockedLayout

unsigned valueElems = layout.getElemsPerThread(shape);

auto llvmElemTy = typeConverter->convertType(ty.getElementType());
auto llvmElemPtrPtrTy =
LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(llvmElemTy));

auto valueVals = getElementsFromStruct(loc, llValue, valueElems, rewriter);
return valueVals;
}

@@ -810,16 +804,13 @@ struct StoreOpConversion

auto [layout, numElems] = getLayout(ptr);

auto ptrElems =
getLLVMElems(ptr, llPtr, layout, getTypeConverter(), rewriter, loc);
auto valueElems =
getLLVMElems(value, llValue, layout, getTypeConverter(), rewriter, loc);
auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
auto valueElems = getLLVMElems(value, llValue, layout, rewriter, loc);
assert(ptrElems.size() == valueElems.size());

SmallVector<Value> maskElems;
if (llMask) {
maskElems =
getLLVMElems(mask, llMask, layout, getTypeConverter(), rewriter, loc);
maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
assert(valueElems.size() == maskElems.size());
}

@@ -848,7 +839,7 @@ struct StoreOpConversion
const bool hasL2EvictPolicy = false;

PTXBuilder ptxBuilder;
auto &ptxStoreInstr = *ptxBuilder.create<PtxIOInstr>("st");
auto &ptxStoreInstr = *ptxBuilder.create<PTXIOInstr>("st");

llvm::SmallVector<std::string> asmArgs;

@@ -866,7 +857,7 @@ struct StoreOpConversion
Value elem = valueElems[elemOffset];
if (elem.getType().isInteger(1))
elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
elem = bit_cast(valueElemTy, elem);
elem = bitcast(valueElemTy, elem);

Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
llWord =

@@ -874,7 +865,7 @@ struct StoreOpConversion
rewriter.create<LLVM::ConstantOp>(
loc, u32Ty, IntegerAttr::get(u32Ty, elemIdx)));
}
llWord = bit_cast(valArgTy, llWord);
llWord = bitcast(valArgTy, llWord);
std::string constraint =
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));

@@ -1100,14 +1091,12 @@ struct LoadOpConversion

auto [layout, numElems] = getLayout(ptr);

auto ptrElems =
getLLVMElems(ptr, llPtr, layout, getTypeConverter(), rewriter, loc);
auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
assert(ptrElems.size() == numElems);

SmallVector<Value> maskElems;
if (llMask) {
maskElems =
getLLVMElems(mask, llMask, layout, getTypeConverter(), rewriter, loc);
maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
assert(ptrElems.size() == maskElems.size());
}

@@ -1132,8 +1121,7 @@ struct LoadOpConversion
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
}

auto otherElems =
getLLVMElems(other, llOther, layout, getTypeConverter(), rewriter, loc);
auto otherElems = getLLVMElems(other, llOther, layout, rewriter, loc);

SmallVector<Value> loadedVals;
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {

@@ -1153,7 +1141,7 @@ struct LoadOpConversion
const bool hasL2EvictPolicy = false;

PTXBuilder ptxBuilder;
auto &ld = *ptxBuilder.create<PtxIOInstr>("ld");
auto &ld = *ptxBuilder.create<PTXIOInstr>("ld");

// TODO(Superjomn) Need to check masks before vectorize the load for all
// the values share one predicate? Here assume all the mask values are

@@ -1198,7 +1186,6 @@ struct LoadOpConversion
else
ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");

SmallVector<Value> others;
if (other) {
for (size_t ii = 0; ii < nWords; ++ii) {
PTXInstr &mov = *ptxBuilder.create<>("mov");

@@ -1214,14 +1201,13 @@ struct LoadOpConversion
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
v = insert_element(vecTy, v, falseVal, sVal);
}
v = bit_cast(IntegerType::get(getContext(), width), v);
v = bitcast(IntegerType::get(getContext(), width), v);

PTXInstr::Operand *opr{};
if (otherIsSplatConstInt) {
opr = ptxBuilder.newConstantOperand(splatVal);
} else {
opr = ptxBuilder.newOperand(v, readConstraint);
others.push_back(v);
}

mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b");

@@ -1253,7 +1239,7 @@ struct LoadOpConversion
} else {
curr = ret;
}
curr = bit_cast(
curr = bitcast(
LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits),
curr);
rets.push_back(curr);

@@ -1360,9 +1346,8 @@ struct ExtractSliceOpConversion

// axis > 0 will result in non-contiguous memory access if the result tensor
// is an alias of the source tensor.
auto axis =
op->getAttrOfType<IntegerAttr>("axis").cast<IntegerAttr>().getInt();
assert(axis == 0 && "Only axis=0 is supported for now");
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
assert(axis == 0 && "extract_slice: Only axis=0 is supported for now");

// Example:
// %dst = extract_slice %src, %index {axis = 0}

@@ -1372,12 +1357,11 @@ struct ExtractSliceOpConversion
auto base = product<int64_t>(dstTy.getShape());
auto baseVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), base);
Value offset = rewriter.create<LLVM::MulOp>(loc, adaptor.index(), baseVal);
Value offset = mul(adaptor.index(), baseVal);

auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
Value resultVal =
rewriter.create<LLVM::GEPOp>(loc, elemPtrTy, adaptor.src(), offset);
Value resultVal = gep(elemPtrTy, adaptor.src(), offset);
rewriter.replaceOp(op, resultVal);
return success();
}

@@ -1581,7 +1565,7 @@ void ConvertLayoutOpConversion::processReplica(
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
ptr = bit_cast(ptr_ty(vecTy, 3), ptr);
ptr = bitcast(ptr_ty(vecTy, 3), ptr);
if (stNotRd) {
Value valVec = undef(vecTy);
for (unsigned v = 0; v < vec; ++v) {

@@ -1614,7 +1598,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
smemBase = bit_cast(elemPtrTy, smemBase);
smemBase = bitcast(elemPtrTy, smemBase);
auto shape = dstTy.getShape();
unsigned rank = dstTy.getRank();
SmallVector<unsigned> numReplicates(rank);

@@ -1732,7 +1716,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
Value minVecVal = idx_val(minVec);
Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
smemBase = bit_cast(elemPtrTy, smemBase);
smemBase = bitcast(elemPtrTy, smemBase);
unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
SmallVector<Value> wordVecs(numWordsEachRep);
for (unsigned i = 0; i < numElems; ++i) {

@@ -1783,7 +1767,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(

// step 3: store
Value smemAddr = gep(elemPtrTy, smemBase, offset);
smemAddr = bit_cast(ptr_ty(wordTy, 3), smemAddr);
smemAddr = bitcast(ptr_ty(wordTy, 3), smemAddr);
store(wordVecs[linearWordIdx], smemAddr);
}
}

@@ -2126,7 +2110,7 @@ public:
for (int e = 0; e < 4; ++e)
i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
i8Elems[m][e], i32_val(e));
i32Elems[m] = bit_cast(i32_ty, i8v4Elems[m]);
i32Elems[m] = bitcast(i32_ty, i8v4Elems[m]);
}
} else { // k first
Value offset = i32_val(sOffsetElem);

@@ -2144,7 +2128,7 @@ public:
for (int e = 0; e < 4; ++e)
i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
i8Elems[m][e], i32_val(e));
i32Elems[m] = bit_cast(i32_ty, i8v4Elems[m]);
i32Elems[m] = bitcast(i32_ty, i8v4Elems[m]);
}
}

@@ -2628,7 +2612,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
Type smemPtrTy = helper.getShemPtrTy();
for (int i = 0; i < numPtrs; ++i) {
ptrs[i] =
bit_cast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]})));
bitcast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]})));
}

bool needTrans = kOrder != order[0];

@@ -2777,6 +2761,229 @@ public:
}
};

struct AsyncWaitOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AsyncWaitOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::AsyncWaitOp>::ConvertTritonGPUOpToLLVMPattern;

LogicalResult
matchAndRewrite(triton::gpu::AsyncWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
PTXBuilder ptxBuilder;
auto &asyncWaitOp = *ptxBuilder.create<PTXCpAsyncWaitGroupInstr>();
auto num = op->getAttrOfType<IntegerAttr>("num").getInt();
asyncWaitOp(ptxBuilder.newConstantOperand(num));

auto ctx = op.getContext();
auto loc = op.getLoc();
auto voidTy = LLVM::LLVMVoidType::get(ctx);
auto ret = ptxBuilder.launch(rewriter, loc, voidTy);

// Safe to remove the op since it doesn't have any return value.
rewriter.eraseOp(op);
return success();
}
};

struct InsertSliceAsyncOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>,
public LoadStoreConversionBase {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;

InsertSliceAsyncOpConversion(LLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
AxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
converter, allocation, smem, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}

LogicalResult
matchAndRewrite(triton::gpu::InsertSliceAsyncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// insert_slice_async %src, %dst, %index, %mask, %other
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.dst();
Value res = op.result();
Value mask = op.mask();
Value other = op.other();
assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
"Only support in-place insert_slice_async for now");

auto srcTy = src.getType().cast<RankedTensorType>();
auto resTy = dst.getType().cast<RankedTensorType>();
auto resElemTy = resTy.getElementType();
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto resSharedLayout = resTy.getEncoding().cast<SharedEncodingAttr>();
auto srcShape = srcTy.getShape();
assert(srcShape.size() == 2 &&
"insert_slice_async: Unexpected rank of %src");

Value llDst = adaptor.dst();
Value llSrc = adaptor.src();
Value llMask = adaptor.mask();
Value llOther = adaptor.other();
Value llIndex = adaptor.index();

// %src
auto srcElems = getLLVMElems(src, llSrc, srcBlockedLayout, rewriter, loc);

// %dst
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
assert(axis == 0 && "insert_slice_async: Only axis=0 is supported for now");
auto dstBase = createIndexAttrConstant(rewriter, loc,
getTypeConverter()->getIndexType(),
product<int64_t>(resTy.getShape()));
Value offset = mul(llIndex, dstBase);
auto dstPtrTy = LLVM::LLVMPointerType::get(
getTypeConverter()->convertType(resTy.getElementType()), 3);
Value dstPtrBase = gep(dstPtrTy, llDst, offset);

// %mask
SmallVector<Value> maskElems;
if (llMask) {
maskElems = getLLVMElems(mask, llMask, srcBlockedLayout, rewriter, loc);
assert(srcElems.size() == maskElems.size());
}

// %other
SmallVector<Value> otherElems;
if (llOther) {
// TODO(Keren): support "other" tensor.
// It's not necessary for now because the pipeline pass will skip
// generating insert_slice_async if the load op has any "other" tensor.
assert(false && "insert_slice_async: Other value not supported yet");
otherElems =
getLLVMElems(other, llOther, srcBlockedLayout, rewriter, loc);
assert(srcElems.size() == otherElems.size());
}

unsigned inVec = getVectorizeSize(src, srcBlockedLayout);
unsigned outVec = resSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
unsigned numElems = getElemsPerThread(srcBlockedLayout, srcShape);
unsigned perPhase = resSharedLayout.getPerPhase();
unsigned maxPhase = resSharedLayout.getMaxPhase();
auto sizePerThread = srcBlockedLayout.getSizePerThread();
auto threadsPerWarp = srcBlockedLayout.getThreadsPerWarp();
auto warpsPerCTA = srcBlockedLayout.getWarpsPerCTA();
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);

auto inOrder = srcBlockedLayout.getOrder();
auto outOrder = resSharedLayout.getOrder();
// If perPhase * maxPhase > threadsPerCTA, we need to swizzle over elements
// across phases.
// If perPhase * maxPhase == threadsPerCTA, swizzling is not allowed.
auto numSwizzleRows = std::max<unsigned>(
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
// A sharedLayout encoding has a "vec" parameter.
// On the column dimension, if inVec > outVec, it means we have to divide
// single vector read into multiple ones
auto numVecCols = std::max<unsigned>(inVec / outVec, 1);

auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcShape);
// <<tileVecIdxRow, tileVecIdxCol>, TileOffset>
DenseMap<std::pair<unsigned, unsigned>, Value> tileOffsetMap;
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
// minVec = 2, inVec = 4, outVec = 2
// baseOffsetCol = 0 baseOffsetCol = 0
// tileVecIdxCol = 0 tileVecIdxCol = 1
// -/\- -/\-
// [|x x| |x x| x x x x x]
// [|x x| |x x| x x x x x]
// baseOffsetRow [|x x| |x x| x x x x x]
// [|x x| |x x| x x x x x]
auto vecIdx = elemIdx / minVec;
auto vecIdxCol = vecIdx % (sizePerThread[inOrder[0]] / minVec);
auto vecIdxRow = vecIdx / (sizePerThread[inOrder[0]] / minVec);
auto baseOffsetCol =
vecIdxCol / numVecCols * numVecCols * threadsPerCTA[inOrder[0]];
auto baseOffsetRow = vecIdxRow / numSwizzleRows * numSwizzleRows *
threadsPerCTA[inOrder[1]];
auto baseOffset = (baseOffsetRow * srcShape[inOrder[0]] + baseOffsetCol);
auto tileVecIdxCol = vecIdxCol % numVecCols;
auto tileVecIdxRow = vecIdxRow % numSwizzleRows;

if (!tileOffsetMap.count({tileVecIdxRow, tileVecIdxCol})) {
// Swizzling
// Since the swizzling index is related to outVec, and we know minVec
// already, inVec doesn't matter
//
// (Numbers represent row indices)
// Example1:
// outVec = 2, inVec = 2, minVec = 2
// outVec = 2, inVec = 4, minVec = 2
// | [1 2] [3 4] ... [15 16] |
// | [3 4] [5 6] ... [1 2] |
// Example2:
// outVec = 4, inVec = 2, minVec = 2
// | [1 2 3 4] [5 6 7 8] ... [13 14 15 16] |
// | [5 6 7 8] [9 10 11 12] ... [1 2 3 4] |
auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]];
Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)),
i32_val(maxPhase));
Value rowOffset =
mul(srcIdx[inOrder[1]], i32_val(srcShape[inOrder[0]]));
Value colOffset =
add(srcIdx[inOrder[0]], i32_val(tileVecIdxCol * minVec));
Value swizzleIdx = udiv(colOffset, i32_val(outVec));
Value swizzleColOffset =
add(mul(xor_(swizzleIdx, phase), i32_val(outVec)),
urem(colOffset, i32_val(outVec)));
Value tileOffset = add(rowOffset, swizzleColOffset);
tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}] =
gep(dstPtrTy, dstPtrBase, tileOffset);
}

// 16 * 8 = 128bits
auto maxBitWidth =
std::max<unsigned>(128, resElemTy.getIntOrFloatBitWidth());
auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec;
auto bitWidth = std::min<unsigned>(maxBitWidth, vecBitWidth);
auto numWords = vecBitWidth / bitWidth;
auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth();

// XXX(Keren): Tune CG and CA here.
CacheModifier srcCacheModifier =
bitWidth == 128 ? CacheModifier::CG : CacheModifier::CA;
assert(bitWidth == 128 || bitWidth == 64 || bitWidth == 32);

for (int wordIdx = 0; wordIdx < numWords; ++wordIdx) {
PTXBuilder ptxBuilder;
auto &copyAsyncOp = *ptxBuilder.create<PTXCpAsyncLoadInstr>(
srcCacheModifier, op.evict());

auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
auto *dstOperand =
ptxBuilder.newAddrOperand(tileOffset, "r", baseOffset);
auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[vecIdx], "l");
auto *copySize = ptxBuilder.newConstantOperand(bitWidth);
auto *srcSize = copySize;
if (op.mask()) {
// We don't use predicate in this case, setting src-size to 0
// if there's any mask. cp.async will automatically fill the
// remaining slots with 0 if cp-size > src-size.
// XXX(Keren): Always assume other = 0 for now.
auto selectOp = select(maskElems[vecIdx + wordIdx * numWordElems],
i32_val(bitWidth), i32_val(0));
srcSize = ptxBuilder.newOperand(selectOp, "r");
}
copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
ptxBuilder.launch(rewriter, loc, LLVM::LLVMVoidType::get(getContext()));
}
}

PTXBuilder ptxBuilder;
ptxBuilder.create<PTXCpAsyncCommitGroupInstr>()->operator()();
auto ret =
ptxBuilder.launch(rewriter, loc, LLVM::LLVMVoidType::get(getContext()));
rewriter.replaceOp(op, ret);
return success();
}
};

void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,

@@ -2786,6 +2993,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
patterns.add<AllocTensorOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
patterns.add<BinaryOpConversion<arith::AddIOp, LLVM::AddOp>>(typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::AddFOp, LLVM::FAddOp>>(typeConverter,

@@ -2800,6 +3008,8 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
axisInfoAnalysis, benefit);
patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);

@@ -72,6 +72,21 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
}
}

SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout) {
SmallVector<unsigned> threads;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
for (int d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
threads.push_back(blockedLayout.getThreadsPerWarp()[d] *
blockedLayout.getWarpsPerCTA()[d]);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(0 && "Unimplemented usage of MmaEncodingAttr");
} else {
assert(0 && "Unimplemented usage of getShapePerCTA");
}

return threads;
}

SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
SmallVector<unsigned> shape;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {