[Triton-MLIR][Backend] Fix the order in linear/delinear and a few bugs in reduce conversion (#851)

1. Fix the order in linearize/delinearize, which fixes the wrong dimension ordering in emitIndices;
2. Fix the selection of the fast implementation in reduce codegen;
3. Fix a redundant barrier in reduce codegen;
4. Fix the index mapping of the second round of warp shuffle in the shuffle version of reduce codegen.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Author: goostavz
Date: 2022-11-09 02:10:09 +08:00
Committed by: GitHub
Parent: 303790da88
Commit: 080b4addf8
4 changed files with 282 additions and 246 deletions


@@ -83,6 +83,11 @@ static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
} // namespace
// A helper function for using printf in LLVM conversion.
void llPrintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter);
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
@@ -338,6 +343,7 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
return llvmStruct;
}
// Delinearize on compile-time consts, assuming the order is [n, .. 2, 1, 0]
template <typename T>
static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
// shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
@@ -355,6 +361,7 @@ static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
return multiDimIndex;
}
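// As a sanity check on the convention stated in the comment (the last
// dimension is the fastest-changing one), here is a minimal standalone C++
// sketch of the same computation; getMultiDimIndexRef is a hypothetical
// stand-in for the template above, not part of this diff.
#include <vector>

static std::vector<int> getMultiDimIndexRef(int linearIndex,
                                            const std::vector<int> &shape) {
  std::vector<int> multiDim(shape.size());
  int remained = linearIndex;
  // shape {a, b, c, d} -> accMul {b*c*d, c*d, d, 1}
  for (size_t i = 0; i < shape.size(); ++i) {
    int accMul = 1;
    for (size_t j = i + 1; j < shape.size(); ++j)
      accMul *= shape[j];
    multiDim[i] = remained / accMul;
    remained %= accMul;
  }
  return multiDim;
}

// e.g. getMultiDimIndexRef(13, {2, 3, 4}) == {1, 0, 1}, since 13 = 1*12 + 0*4 + 1.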
// Linearize on compile-time consts, assuming the order is [n, .. 2, 1, 0]
template <typename T>
static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
assert(multiDimIndex.size() == shape.size());
@@ -510,12 +517,12 @@ public:
multiDim[0] = linear;
} else {
Value remained = linear;
for (auto &&en : llvm::enumerate(llvm::reverse(shape.drop_front()))) {
for (auto &&en : llvm::enumerate(shape.drop_back())) {
Value dimSize = idx_val(en.value());
multiDim[rank - 1 - en.index()] = urem(remained, dimSize);
multiDim[en.index()] = urem(remained, dimSize);
remained = udiv(remained, dimSize);
}
multiDim[0] = remained;
multiDim[rank - 1] = remained;
}
return multiDim;
}
@@ -525,9 +532,9 @@ public:
int rank = multiDim.size();
Value linear = idx_val(0);
if (rank > 0) {
linear = multiDim.front();
linear = multiDim.back();
for (auto [dim, shape] :
llvm::zip(multiDim.drop_front(), shape.drop_front())) {
llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) {
Value dimSize = idx_val(shape);
linear = add(mul(linear, dimSize), dim);
}
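// To see why the rewritten delinearize/linearize loops agree with each other,
// here is a plain-C++ model of the fixed convention (hypothetical names, a
// sketch rather than the actual helpers). It assumes the caller has already
// permuted shape and index into minor-to-major order, so element 0 is the
// fastest-changing dimension.
#include <vector>

static std::vector<unsigned> delinearizeRef(unsigned linear,
                                            const std::vector<unsigned> &shape) {
  std::vector<unsigned> multiDim(shape.size());
  unsigned remained = linear;
  for (size_t i = 0; i + 1 < shape.size(); ++i) { // shape.drop_back()
    multiDim[i] = remained % shape[i];
    remained /= shape[i];
  }
  multiDim.back() = remained;                     // multiDim[rank - 1]
  return multiDim;
}

static unsigned linearizeRef(const std::vector<unsigned> &multiDim,
                             const std::vector<unsigned> &shape) {
  unsigned linear = multiDim.back();
  for (size_t i = shape.size() - 1; i-- > 0;)     // reverse over drop_back()
    linear = linear * shape[i] + multiDim[i];
  return linear;
}

// Round trip: delinearizeRef(5, {4, 8, 2}) == {1, 1, 0} since 5 = 1 + 1*4,
// and linearizeRef({1, 1, 0}, {4, 8, 2}) == 5 again.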
@@ -566,6 +573,7 @@ public:
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
SmallVector<Value> multiDimThreadId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
SmallVector<Value> multiDimBase(rank);
for (unsigned k = 0; k < rank; ++k) {
// Wrap around multiDimWarpId/multiDimThreadId in case
@@ -1362,7 +1370,9 @@ private:
LogicalResult
ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (op.axis() == 1) // FIXME(Qingyi): The fastest-changing dimension
auto srcTy = op.operand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
if (op.axis() == srcLayout.getOrder()[0])
return matchAndRewriteFast(op, adaptor, rewriter);
return matchAndRewriteBasic(op, adaptor, rewriter);
}
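// Bug fix 2 from the commit message: the fast (warp-shuffle) reduction used
// to be selected by a hardcoded op.axis() == 1; it now checks the layout's
// actual fastest-changing dimension. Roughly speaking, only a reduction along
// order[0] gives each warp contiguous values along the reduced dimension,
// which is what the shuffle path needs. A trivial standalone sketch of the
// new predicate (hypothetical helper, for illustration only):
#include <vector>

static bool useFastReduce(unsigned axis, const std::vector<unsigned> &order) {
  return axis == order[0]; // order[0] is the fastest-changing dimension
}

// e.g. for a 2-D blocked layout with order = {0, 1}, only axis == 0 takes
// matchAndRewriteFast; everything else falls back to matchAndRewriteBasic.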
@@ -1444,6 +1454,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
auto srcTy = op.operand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto srcOrd = srcLayout.getOrder();
auto srcShape = srcTy.getShape();
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
@@ -1487,7 +1498,9 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = udiv(writeIdx[axis], sizePerThread);
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
Value writeOffset =
linearize(rewriter, loc, reorder<Value>(writeIdx, srcOrd),
reorder<unsigned>(smemShape, srcOrd));
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
store(acc, writePtr);
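// The reorder<T> helper used above is not shown in this hunk; a plausible
// sketch consistent with the new linearize convention (an assumption, not the
// actual implementation) gathers the components into minor-to-major order
// before linearizing:
#include <vector>

template <typename T>
static std::vector<T> reorderRef(const std::vector<T> &input,
                                 const std::vector<unsigned> &order) {
  std::vector<T> result(input.size());
  for (size_t i = 0; i < order.size(); ++i)
    result[i] = input[order[i]]; // result[0] is the fastest dimension
  return result;
}

// e.g. writeIdx {r, c} with srcOrd {1, 0} (dim 1 fastest) -> {c, r}, so the
// linearized offset walks dim 1 first, matching the shared-memory layout.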
@@ -1495,8 +1508,11 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
for (int N = smemShape[axis] / 2; N > 0; N >>= 1) {
readIdx[axis] = ints[N];
Value readMask = icmp_slt(writeIdx[axis], ints[N]);
Value readOffset = select(
readMask, linearize(rewriter, loc, readIdx, smemShape), ints[0]);
Value readOffset =
select(readMask,
linearize(rewriter, loc, reorder<Value>(readIdx, srcOrd),
reorder<unsigned>(smemShape, srcOrd)),
ints[0]);
Value readPtr = gep(elemPtrTy, writePtr, readOffset);
barrier();
accumulate(rewriter, loc, op.redOp(), acc, load(readPtr), false);
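// The N-halving loop above is a power-of-two tree reduction over the axis
// dimension of the shared-memory tile, with a barrier between rounds so every
// partner value is visible before it is read. A serial C++ model of the same
// schedule (hypothetical helper; addition stands in for op.redOp()):
#include <vector>

static int treeReduceRef(std::vector<int> vals) { // size must be a power of 2
  for (size_t N = vals.size() / 2; N > 0; N >>= 1)
    for (size_t i = 0; i < N; ++i) // threads whose writeIdx[axis] < N
      vals[i] += vals[i + N];      // read the partner at offset N, accumulate
  return vals[0];                  // the reduced value ends up in slot 0
}

// With 8 values the rounds use strides 4, 2, 1, leaving the total in slot 0.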
@@ -1519,7 +1535,9 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
for (unsigned i = 0; i < resultElems; ++i) {
SmallVector<Value> readIdx = resultIndices[i];
readIdx.insert(readIdx.begin() + axis, ints[0]);
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
Value readOffset =
linearize(rewriter, loc, reorder<Value>(readIdx, srcOrd),
reorder<unsigned>(smemShape, srcOrd));
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
resultVals[i] = load(readPtr);
}
@@ -1548,6 +1566,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
auto srcTy = op.operand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto srcShape = srcTy.getShape();
auto srcRank = srcTy.getRank();
auto threadsPerWarp = srcLayout.getThreadsPerWarp();
auto warpsPerCTA = srcLayout.getWarpsPerCTA();
@@ -1592,6 +1611,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
Value laneIdAxis = multiDimLaneId[axis];
Value warpIdAxis = multiDimWarpId[axis];
@@ -1609,56 +1629,77 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
}
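// The shflSync loop above is the intra-warp half of the reduction. Assuming
// shflSync lowers to a butterfly (XOR) shuffle, as NVIDIA's shfl.sync.bfly
// does, every lane ends up holding the full partial sum. A serial C++ model
// over one full warp of 32 lanes (a sketch, not the generated code):
#include <array>

static std::array<int, 32> butterflyReduceRef(std::array<int, 32> lanes) {
  for (unsigned N = 16; N > 0; N >>= 1) { // sizeIntraWarps == 32 here
    std::array<int, 32> shfl;
    for (unsigned lane = 0; lane < 32; ++lane)
      shfl[lane] = lanes[lane ^ N];       // shfl.sync.bfly with mask N
    for (unsigned lane = 0; lane < 32; ++lane)
      lanes[lane] += shfl[lane];          // accumulate(acc, shfl)
  }
  return lanes; // every lane now holds the sum of all 32 inputs
}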
if (sizeInterWarps == 1) {
SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = zero;
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
storeShared(rewriter, loc, writePtr, acc, laneZero);
} else {
SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] =
warpIdAxis; // axis must be the fastest-changing dimension
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
storeShared(rewriter, loc, writePtr, acc, laneZero);
barrier();
SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
Value writeOffset =
linearize(rewriter, loc, reorder<Value>(writeIdx, order),
reorder<unsigned>(smemShape, order));
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
storeShared(rewriter, loc, writePtr, acc, laneZero);
}
SmallVector<Value> readIdx = writeIdx;
readIdx[axis] = urem(laneId, i32_val(sizeInterWarps));
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
acc = load(readPtr);
barrier();
// reduce across warps
for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
Value shfl = shflSync(rewriter, loc, acc, N);
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
}
// the second round of shuffle reduction
// now the problem size: sizeInterWarps, s1, s2, .. , sn =>
// 1, s1, s2, .. , sn
// where sizeInterWarps is 2^m
//
// each thread needs to process:
// elemsPerThread = sizeInterWarps * s1 * s2 .. sn / numThreads
unsigned elems = product<unsigned>(smemShape);
unsigned numThreads = product<unsigned>(srcLayout.getWarpsPerCTA()) * 32;
unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
Value readOffset = threadId;
for (unsigned round = 0; round < elemsPerThread; ++round) {
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
Value acc = load(readPtr);
writeIdx[axis] = zero;
writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
writePtr = gep(elemPtrTy, smemBase, writeOffset);
storeShared(rewriter, loc, writePtr, acc, and_(laneZero, warpZero));
for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
Value shfl = shflSync(rewriter, loc, acc, N);
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
}
Value writeOffset = udiv(readOffset, i32_val(sizeInterWarps));
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));
Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
Value laneIdModSizeInterWarpsIsZero =
icmp_eq(laneIdModSizeInterWarps, zero);
storeShared(rewriter, loc, writePtr, acc,
and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero));
if (round != elemsPerThread - 1) {
readOffset = add(readOffset, i32_val(numThreads));
}
}
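// Bug fix 4 from the commit message: in the second round each thread walks
// the shared-memory partials with stride numThreads, reduces groups of
// sizeInterWarps adjacent slots via shuffles, and stores the result at
// readOffset / sizeInterWarps, guarded by laneId % sizeInterWarps == 0. A
// small host-side trace of that mapping (hypothetical sizes, illustration
// only):
#include <cstdio>

int main() {
  const unsigned elems = 256, numThreads = 128, sizeInterWarps = 4;
  const unsigned elemsPerThread = elems / numThreads; // == 2 rounds
  for (unsigned threadId = 0; threadId < 6; ++threadId) { // first few threads
    unsigned readOffset = threadId;
    for (unsigned round = 0; round < elemsPerThread; ++round) {
      unsigned writeOffset = readOffset / sizeInterWarps;
      bool stores = (threadId % 32) % sizeInterWarps == 0; // the laneId guard
      std::printf("t%u r%u: read %3u -> write %2u%s\n", threadId, round,
                  readOffset, writeOffset, stores ? " (stores)" : "");
      readOffset += numThreads; // advance to the next strip of partials
    }
  }
  return 0;
}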
// We could avoid this barrier in some of the layouts, however this is not
// the general case. TODO: optimize the barrier in case the layouts are
// accepted.
barrier();
// set output values
if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
// nd-tensor where n >= 1
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
auto resultShape = resultTy.getShape();
SmallVector<unsigned> resultOrd;
for (auto ord : order) {
if (ord != 0)
resultOrd.push_back(ord - 1);
}
unsigned resultElems = getElemsPerThread(resultTy);
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
assert(resultIndices.size() == resultElems);
barrier();
SmallVector<Value> resultVals(resultElems);
for (size_t i = 0; i < resultElems; ++i) {
SmallVector<Value> readIdx = resultIndices[i];
readIdx.insert(readIdx.begin() + axis, i32_val(0));
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
Value readOffset =
linearize(rewriter, loc, reorder<Value>(readIdx, resultOrd),
reorder<int64_t, unsigned>(resultShape, resultOrd));
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
resultVals[i] = load(readPtr);
}
@@ -1670,7 +1711,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
rewriter.replaceOp(op, ret);
} else {
// 0d-tensor -> scalar
barrier();
Value resultVal = load(smemBase);
rewriter.replaceOp(op, resultVal);
}
@@ -1707,6 +1747,191 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
}
};
struct PrintfOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::PrintfOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::PrintfOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::PrintfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
SmallVector<Value, 16> operands;
for (auto operand : adaptor.getOperands()) {
auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
for (auto elem : sub_operands) {
operands.push_back(elem);
}
}
std::string formatStr;
llvm::raw_string_ostream os(formatStr);
os << op.prefix();
if (operands.size() > 0) {
os << getFormatSubstr(operands[0]);
}
for (size_t i = 1; i < operands.size(); ++i) {
os << ", " << getFormatSubstr(operands[i]);
}
llPrintf(formatStr, operands, rewriter);
rewriter.eraseOp(op);
return success();
}
// get the format specifier for each input value
// currently support pointer, i8, i16, i32, i64, f16, bf16, f32, f64
std::string getFormatSubstr(Value value) const {
Type type = value.getType();
unsigned width = type.getIntOrFloatBitWidth();
if (type.isa<LLVM::LLVMPointerType>()) {
return "%p";
} else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
return "%f";
} else if (type.isSignedInteger()) {
return "%i";
} else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
return "%u";
}
assert(false && "not supported type");
}
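// For illustration, assuming the rules above: three operands of pointer, f16,
// and signless-i32 type after a prefix "vals: " assemble into the format
// string below. A hypothetical host-side trace of the loop in matchAndRewrite:
#include <cstdio>
#include <string>

int main() {
  std::string fmt = "vals: ";            // os << op.prefix()
  const char *specs[] = {"%p", "%f", "%u"};
  fmt += specs[0];                       // first operand, no separator
  for (int i = 1; i < 3; ++i)
    fmt += std::string(", ") + specs[i]; // remaining operands
  std::printf("%s\n", fmt.c_str());      // -> vals: %p, %f, %u
  return 0;
}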
// declare vprintf(i8*, i8*) as external function
static LLVM::LLVMFuncOp
getVprintfDeclaration(ConversionPatternRewriter &rewriter) {
auto moduleOp =
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
StringRef funcName("vprintf");
Operation *funcOp = moduleOp.lookupSymbol(funcName);
if (funcOp)
return cast<LLVM::LLVMFuncOp>(*funcOp);
auto *context = rewriter.getContext();
SmallVector<Type> argsType{ptr_ty(IntegerType::get(context, 8)),
ptr_ty(IntegerType::get(context, 8))};
auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType);
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(context), funcName,
funcType);
}
// extend integer to int32, extend float to float64
// this comes from vprintf alignment requirements.
static std::pair<Type, Value>
promoteValue(ConversionPatternRewriter &rewriter, Value value) {
auto *context = rewriter.getContext();
auto type = value.getType();
unsigned width = type.getIntOrFloatBitWidth();
Value newOp = value;
Type newType = type;
bool bUnsigned = type.isUnsignedInteger();
if (type.isIntOrIndex() && width < 32) {
if (bUnsigned) {
newType = ui32_ty;
newOp = rewriter.create<LLVM::ZExtOp>(UnknownLoc::get(context), newType,
value);
} else {
newType = i32_ty;
newOp = rewriter.create<LLVM::SExtOp>(UnknownLoc::get(context), newType,
value);
}
} else if (type.isBF16() || type.isF16() || type.isF32()) {
newType = f64_ty;
newOp = rewriter.create<LLVM::FPExtOp>(UnknownLoc::get(context), newType,
value);
}
return {newType, newOp};
}
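// vprintf consumes its arguments like a C varargs call, which is why
// promoteValue extends sub-32-bit integers to i32 and f16/bf16/f32 to f64.
// Ordinary C++ variadic calls apply the same default promotions automatically:
#include <cstdio>

int main() {
  short s = -3;   // passed through varargs as int    (cf. SExt to i32 above)
  float f = 1.5f; // passed through varargs as double (cf. FPExt to f64 above)
  std::printf("%i %f\n", s, f); // default promotions mirror promoteValue
  return 0;
}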
static void llPrintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter) {
static const char formatStringPrefix[] = "printfFormat_";
assert(!msg.empty() && "printf with empty string not supported");
Type int8Ptr = ptr_ty(i8_ty);
auto *context = rewriter.getContext();
auto moduleOp =
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
auto funcOp = getVprintfDeclaration(rewriter);
Value one = rewriter.create<LLVM::ConstantOp>(
UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(1));
Value zero = rewriter.create<LLVM::ConstantOp>(
UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(0));
unsigned stringNumber = 0;
SmallString<16> stringConstName;
do {
stringConstName.clear();
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
} while (moduleOp.lookupSymbol(stringConstName));
llvm::SmallString<64> formatString(msg);
formatString.push_back('\n');
formatString.push_back('\0');
size_t formatStringSize = formatString.size_in_bytes();
auto globalType = LLVM::LLVMArrayType::get(i8_ty, formatStringSize);
LLVM::GlobalOp global;
{
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
global = rewriter.create<LLVM::GlobalOp>(
UnknownLoc::get(context), globalType,
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
rewriter.getStringAttr(formatString));
}
Value globalPtr =
rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(context), global);
Value stringStart =
rewriter.create<LLVM::GEPOp>(UnknownLoc::get(context), int8Ptr,
globalPtr, mlir::ValueRange({zero, zero}));
Value bufferPtr =
rewriter.create<LLVM::NullOp>(UnknownLoc::get(context), int8Ptr);
SmallVector<Value, 16> newArgs;
if (args.size() >= 1) {
SmallVector<Type> argTypes;
for (auto arg : args) {
Type newType;
Value newArg;
std::tie(newType, newArg) = promoteValue(rewriter, arg);
argTypes.push_back(newType);
newArgs.push_back(newArg);
}
Type structTy = LLVM::LLVMStructType::getLiteral(context, argTypes);
auto allocated = rewriter.create<LLVM::AllocaOp>(UnknownLoc::get(context),
ptr_ty(structTy), one,
/*alignment=*/0);
for (const auto &entry : llvm::enumerate(newArgs)) {
auto index = rewriter.create<LLVM::ConstantOp>(
UnknownLoc::get(context), i32_ty,
rewriter.getI32IntegerAttr(entry.index()));
auto fieldPtr = rewriter.create<LLVM::GEPOp>(
UnknownLoc::get(context), ptr_ty(argTypes[entry.index()]),
allocated, ArrayRef<Value>{zero, index});
rewriter.create<LLVM::StoreOp>(UnknownLoc::get(context), entry.value(),
fieldPtr);
}
bufferPtr = rewriter.create<LLVM::BitcastOp>(UnknownLoc::get(context),
int8Ptr, allocated);
}
ValueRange operands{stringStart, bufferPtr};
rewriter.create<LLVM::CallOp>(UnknownLoc::get(context), funcOp, operands);
}
};
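// Putting llPrintf together: it interns the format string as a global, packs
// the promoted operands into one stack struct, and calls
// vprintf(formatString, &struct). Modeled on the host below (illustrative C++,
// not the generated LLVM IR; plain printf stands in for device-side vprintf):
#include <cstdio>

int main() {
  static const char fmt[] = "val: %i, %f\n"; // the LLVM::GlobalOp string
  struct Args {
    int v0;    // promoted integer operand
    double v1; // promoted float operand
  } buf = {42, 1.5};                         // AllocaOp + per-field GEP/stores
  // Device code would execute vprintf(fmt, (char *)&buf); here we just format:
  std::printf(fmt, buf.v0, buf.v1);
  return 0;
}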
struct MakeRangeOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
@@ -2070,17 +2295,6 @@ public:
}
private:
template <typename T>
SmallVector<T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) const {
size_t rank = order.size();
assert(input.size() == rank);
SmallVector<T> result(rank);
for (auto it : llvm::enumerate(order)) {
result[rank - 1 - it.value()] = input[it.index()];
}
return result;
};
// shared memory rd/st for blocked or mma layout with data padding
void processReplica(Location loc, ConversionPatternRewriter &rewriter,
bool stNotRd, RankedTensorType type,
@@ -4483,7 +4697,7 @@ struct InsertSliceAsyncOpConversion
auto numVecCols = std::max<unsigned>(inVec / outVec, 1);
auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcShape);
// <<tileVecIdxRow, tileVecIdxCol>, TileOffset>
DenseMap<std::pair<unsigned, unsigned>, Value> tileOffsetMap;
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
// minVec = 2, inVec = 4, outVec = 2
@@ -4674,190 +4888,6 @@ struct FDivOpConversion
}
};
struct PrintfOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::PrintfOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::PrintfOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::PrintfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
SmallVector<Value, 16> operands;
for (auto operand : adaptor.getOperands()) {
auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
for (auto elem : sub_operands) {
operands.push_back(elem);
}
}
std::string formatStr;
llvm::raw_string_ostream os(formatStr);
os << op.prefix();
if (operands.size() > 0) {
os << getFormatSubstr(operands[0]);
}
for (size_t i = 1; i < operands.size(); ++i) {
os << ", " << getFormatSubstr(operands[i]);
}
llPrintf(formatStr, operands, rewriter);
rewriter.eraseOp(op);
return success();
}
// get the format specifier for each input value
// currently support pointer, i8, i16, i32, i64, f16, bf16, f32, f64
std::string getFormatSubstr(Value value) const {
Type type = value.getType();
unsigned width = type.getIntOrFloatBitWidth();
if (type.isa<LLVM::LLVMPointerType>()) {
return "%p";
} else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
return "%f";
} else if (type.isSignedInteger()) {
return "%i";
} else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
return "%u";
}
assert(false && "not supported type");
}
// declare vprintf(i8*, i8*) as external function
LLVM::LLVMFuncOp
getVprintfDeclaration(ConversionPatternRewriter &rewriter) const {
auto moduleOp =
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
StringRef funcName("vprintf");
Operation *funcOp = moduleOp.lookupSymbol(funcName);
if (funcOp)
return cast<LLVM::LLVMFuncOp>(*funcOp);
auto *context = rewriter.getContext();
SmallVector<Type> argsType{ptr_ty(IntegerType::get(context, 8)),
ptr_ty(IntegerType::get(context, 8))};
auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType);
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(context), funcName,
funcType);
}
// extend integer to int32, extend float to float64
// this comes from vprintf alignment requirements.
std::pair<Type, Value> promoteValue(ConversionPatternRewriter &rewriter,
Value value) const {
auto *context = rewriter.getContext();
auto type = value.getType();
unsigned width = type.getIntOrFloatBitWidth();
Value newOp = value;
Type newType = type;
bool bUnsigned = type.isUnsignedInteger();
if (type.isIntOrIndex() && width < 32) {
if (bUnsigned) {
newType = ui32_ty;
newOp = rewriter.create<LLVM::ZExtOp>(UnknownLoc::get(context), newType,
value);
} else {
newType = i32_ty;
newOp = rewriter.create<LLVM::SExtOp>(UnknownLoc::get(context), newType,
value);
}
} else if (type.isBF16() || type.isF16() || type.isF32()) {
newType = f64_ty;
newOp = rewriter.create<LLVM::FPExtOp>(UnknownLoc::get(context), newType,
value);
}
return {newType, newOp};
}
void llPrintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter) const {
static const char formatStringPrefix[] = "printfFormat_";
assert(!msg.empty() && "printf with empty string not supported");
Type int8Ptr = ptr_ty(i8_ty);
auto *context = rewriter.getContext();
auto moduleOp =
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
auto funcOp = getVprintfDeclaration(rewriter);
Value one = rewriter.create<LLVM::ConstantOp>(
UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(1));
Value zero = rewriter.create<LLVM::ConstantOp>(
UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(0));
unsigned stringNumber = 0;
SmallString<16> stringConstName;
do {
stringConstName.clear();
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
} while (moduleOp.lookupSymbol(stringConstName));
llvm::SmallString<64> formatString(msg);
formatString.push_back('\n');
formatString.push_back('\0');
size_t formatStringSize = formatString.size_in_bytes();
auto globalType = LLVM::LLVMArrayType::get(i8_ty, formatStringSize);
LLVM::GlobalOp global;
{
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
global = rewriter.create<LLVM::GlobalOp>(
UnknownLoc::get(context), globalType,
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
rewriter.getStringAttr(formatString));
}
Value globalPtr =
rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(context), global);
Value stringStart =
rewriter.create<LLVM::GEPOp>(UnknownLoc::get(context), int8Ptr,
globalPtr, mlir::ValueRange({zero, zero}));
Value bufferPtr =
rewriter.create<LLVM::NullOp>(UnknownLoc::get(context), int8Ptr);
SmallVector<Value, 16> newArgs;
if (args.size() >= 1) {
SmallVector<Type> argTypes;
for (auto arg : args) {
Type newType;
Value newArg;
std::tie(newType, newArg) = promoteValue(rewriter, arg);
argTypes.push_back(newType);
newArgs.push_back(newArg);
}
Type structTy = LLVM::LLVMStructType::getLiteral(context, argTypes);
auto allocated = rewriter.create<LLVM::AllocaOp>(UnknownLoc::get(context),
ptr_ty(structTy), one,
/*alignment=*/0);
for (const auto &entry : llvm::enumerate(newArgs)) {
auto index = rewriter.create<LLVM::ConstantOp>(
UnknownLoc::get(context), i32_ty,
rewriter.getI32IntegerAttr(entry.index()));
auto fieldPtr = rewriter.create<LLVM::GEPOp>(
UnknownLoc::get(context), ptr_ty(argTypes[entry.index()]),
allocated, ArrayRef<Value>{zero, index});
rewriter.create<LLVM::StoreOp>(UnknownLoc::get(context), entry.value(),
fieldPtr);
}
bufferPtr = rewriter.create<LLVM::BitcastOp>(UnknownLoc::get(context),
int8Ptr, allocated);
}
ValueRange operands{stringStart, bufferPtr};
rewriter.create<LLVM::CallOp>(UnknownLoc::get(context), funcOp, operands);
}
};
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
@@ -5062,6 +5092,15 @@ void ConvertTritonGPUToLLVM::initSharedMemory(
namespace mlir {
namespace LLVM {
void llPrintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter) {
PrintfOpConversion::llPrintf(msg, args, rewriter);
}
} // namespace LLVM
TritonLLVMConversionTarget::TritonLLVMConversionTarget(
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
: ConversionTarget(ctx) {