[Triton-MLIR][Backend] Fix the order in linear/delinear and a few bugs in reduce conversion (#851)
This commit: 1. fixes the dimension order in linearize/delinearize, which fixes the ordering error in emitIndices; 2. fixes the selection of the fast implementation in reduce codegen; 3. removes the redundant barrier in reduce codegen; 4. fixes the index mapping of the second round of warp shuffle in the shuffle version of reduce codegen. Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
@@ -83,6 +83,11 @@ static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
|
||||
|
||||
} // namespace
|
||||
|
||||
// A helper function for using printf in LLVM conversion.
|
||||
void llPrintf(StringRef msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
|
||||
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive//
|
||||
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
|
||||
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
|
||||
#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
|
||||
@@ -338,6 +343,7 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
|
||||
return llvmStruct;
|
||||
}
|
||||
|
||||
// Delinearize on compile-time consts, assuming the order is [n, .. 2, 1, 0]
|
||||
template <typename T>
|
||||
static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
|
||||
// shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
|
||||
@@ -355,6 +361,7 @@ static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
|
||||
return multiDimIndex;
|
||||
}
|
||||
|
||||
// Linearize on compile-time consts, assuming the order is [n, .. 2, 1, 0]
|
||||
template <typename T>
|
||||
static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
|
||||
assert(multiDimIndex.size() == shape.size());
|
||||
@@ -510,12 +517,12 @@ public:
|
||||
multiDim[0] = linear;
|
||||
} else {
|
||||
Value remained = linear;
|
||||
for (auto &&en : llvm::enumerate(llvm::reverse(shape.drop_front()))) {
|
||||
for (auto &&en : llvm::enumerate(shape.drop_back())) {
|
||||
Value dimSize = idx_val(en.value());
|
||||
multiDim[rank - 1 - en.index()] = urem(remained, dimSize);
|
||||
multiDim[en.index()] = urem(remained, dimSize);
|
||||
remained = udiv(remained, dimSize);
|
||||
}
|
||||
multiDim[0] = remained;
|
||||
multiDim[rank - 1] = remained;
|
||||
}
|
||||
return multiDim;
|
||||
}
|
||||
@@ -525,9 +532,9 @@ public:
|
||||
int rank = multiDim.size();
|
||||
Value linear = idx_val(0);
|
||||
if (rank > 0) {
|
||||
linear = multiDim.front();
|
||||
linear = multiDim.back();
|
||||
for (auto [dim, shape] :
|
||||
llvm::zip(multiDim.drop_front(), shape.drop_front())) {
|
||||
llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) {
|
||||
Value dimSize = idx_val(shape);
|
||||
linear = add(mul(linear, dimSize), dim);
|
||||
}
|
||||
@@ -566,6 +573,7 @@ public:
|
||||
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
|
||||
SmallVector<Value> multiDimThreadId =
|
||||
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
|
||||
|
||||
SmallVector<Value> multiDimBase(rank);
|
||||
for (unsigned k = 0; k < rank; ++k) {
|
||||
// Wrap around multiDimWarpId/multiDimThreadId in case
|
||||
@@ -1362,7 +1370,9 @@ private:
|
||||
// Entry point of the reduce lowering: dispatches to the fast or the basic
// implementation depending on which axis is being reduced.
//
// The source operand is expected to be a RankedTensorType carrying a
// BlockedEncodingAttr (both casts assert otherwise).
LogicalResult
ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  auto srcTy = op.operand().getType().cast<RankedTensorType>();
  auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
  // Select the fast path from the layout's order instead of a hard-coded axis
  // index. NOTE(review): getOrder()[0] is presumably the fastest-changing
  // dimension of the blocked layout -- confirm against the encoding docs.
  if (op.axis() == srcLayout.getOrder()[0])
    return matchAndRewriteFast(op, adaptor, rewriter);
  return matchAndRewriteBasic(op, adaptor, rewriter);
}
|
||||
@@ -1444,6 +1454,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
|
||||
|
||||
auto srcTy = op.operand().getType().cast<RankedTensorType>();
|
||||
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
auto srcOrd = srcLayout.getOrder();
|
||||
auto srcShape = srcTy.getShape();
|
||||
|
||||
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
|
||||
@@ -1487,7 +1498,9 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
|
||||
SmallVector<Value> writeIdx = indices[key];
|
||||
|
||||
writeIdx[axis] = udiv(writeIdx[axis], sizePerThread);
|
||||
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
|
||||
Value writeOffset =
|
||||
linearize(rewriter, loc, reorder<Value>(writeIdx, srcOrd),
|
||||
reorder<unsigned>(smemShape, srcOrd));
|
||||
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
||||
store(acc, writePtr);
|
||||
|
||||
@@ -1495,8 +1508,11 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
|
||||
for (int N = smemShape[axis] / 2; N > 0; N >>= 1) {
|
||||
readIdx[axis] = ints[N];
|
||||
Value readMask = icmp_slt(writeIdx[axis], ints[N]);
|
||||
Value readOffset = select(
|
||||
readMask, linearize(rewriter, loc, readIdx, smemShape), ints[0]);
|
||||
Value readOffset =
|
||||
select(readMask,
|
||||
linearize(rewriter, loc, reorder<Value>(readIdx, srcOrd),
|
||||
reorder<unsigned>(smemShape, srcOrd)),
|
||||
ints[0]);
|
||||
Value readPtr = gep(elemPtrTy, writePtr, readOffset);
|
||||
barrier();
|
||||
accumulate(rewriter, loc, op.redOp(), acc, load(readPtr), false);
|
||||
@@ -1519,7 +1535,9 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
|
||||
for (unsigned i = 0; i < resultElems; ++i) {
|
||||
SmallVector<Value> readIdx = resultIndices[i];
|
||||
readIdx.insert(readIdx.begin() + axis, ints[0]);
|
||||
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
|
||||
Value readOffset =
|
||||
linearize(rewriter, loc, reorder<Value>(readIdx, srcOrd),
|
||||
reorder<unsigned>(smemShape, srcOrd));
|
||||
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
||||
resultVals[i] = load(readPtr);
|
||||
}
|
||||
@@ -1548,6 +1566,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
||||
auto srcTy = op.operand().getType().cast<RankedTensorType>();
|
||||
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto srcRank = srcTy.getRank();
|
||||
|
||||
auto threadsPerWarp = srcLayout.getThreadsPerWarp();
|
||||
auto warpsPerCTA = srcLayout.getWarpsPerCTA();
|
||||
@@ -1592,6 +1611,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
||||
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
|
||||
|
||||
Value laneIdAxis = multiDimLaneId[axis];
|
||||
Value warpIdAxis = multiDimWarpId[axis];
|
||||
|
||||
@@ -1609,56 +1629,77 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
||||
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
|
||||
}
|
||||
|
||||
if (sizeInterWarps == 1) {
|
||||
SmallVector<Value> writeIdx = indices[key];
|
||||
writeIdx[axis] = zero;
|
||||
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
|
||||
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
||||
storeShared(rewriter, loc, writePtr, acc, laneZero);
|
||||
} else {
|
||||
SmallVector<Value> writeIdx = indices[key];
|
||||
writeIdx[axis] =
|
||||
warpIdAxis; // axis must be the fastest-changing dimension
|
||||
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
|
||||
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
||||
storeShared(rewriter, loc, writePtr, acc, laneZero);
|
||||
barrier();
|
||||
SmallVector<Value> writeIdx = indices[key];
|
||||
writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
|
||||
Value writeOffset =
|
||||
linearize(rewriter, loc, reorder<Value>(writeIdx, order),
|
||||
reorder<unsigned>(smemShape, order));
|
||||
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
||||
storeShared(rewriter, loc, writePtr, acc, laneZero);
|
||||
}
|
||||
|
||||
SmallVector<Value> readIdx = writeIdx;
|
||||
readIdx[axis] = urem(laneId, i32_val(sizeInterWarps));
|
||||
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
|
||||
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
||||
acc = load(readPtr);
|
||||
barrier();
|
||||
|
||||
// reduce across warps
|
||||
for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
|
||||
Value shfl = shflSync(rewriter, loc, acc, N);
|
||||
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
|
||||
}
|
||||
// the second round of shuffle reduction
|
||||
// now the problem size: sizeInterWarps, s1, s2, .. , sn =>
|
||||
// 1, s1, s2, .. , sn
|
||||
// where sizeInterWarps is 2^m
|
||||
//
|
||||
// each thread needs to process:
|
||||
// elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
|
||||
unsigned elems = product<unsigned>(smemShape);
|
||||
unsigned numThreads = product<unsigned>(srcLayout.getWarpsPerCTA()) * 32;
|
||||
unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
|
||||
Value readOffset = threadId;
|
||||
for (unsigned round = 0; round < elemsPerThread; ++round) {
|
||||
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
||||
Value acc = load(readPtr);
|
||||
|
||||
writeIdx[axis] = zero;
|
||||
writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
|
||||
writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
||||
storeShared(rewriter, loc, writePtr, acc, and_(laneZero, warpZero));
|
||||
for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
|
||||
Value shfl = shflSync(rewriter, loc, acc, N);
|
||||
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
|
||||
}
|
||||
|
||||
Value writeOffset = udiv(readOffset, i32_val(sizeInterWarps));
|
||||
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
||||
Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));
|
||||
Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
|
||||
Value laneIdModSizeInterWarpsIsZero =
|
||||
icmp_eq(laneIdModSizeInterWarps, zero);
|
||||
storeShared(rewriter, loc, writePtr, acc,
|
||||
and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero));
|
||||
|
||||
if (round != elemsPerThread - 1) {
|
||||
readOffset = add(readOffset, i32_val(numThreads));
|
||||
}
|
||||
}
|
||||
|
||||
// We could avoid this barrier in some of the layouts, however this is not
|
||||
// the general case. TODO: optimize the barrier in case the layouts are
|
||||
// accepted.
|
||||
barrier();
|
||||
|
||||
// set output values
|
||||
if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
|
||||
// nd-tensor where n >= 1
|
||||
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
|
||||
auto resultShape = resultTy.getShape();
|
||||
SmallVector<unsigned> resultOrd;
|
||||
for (auto ord : order) {
|
||||
if (ord != 0)
|
||||
resultOrd.push_back(ord - 1);
|
||||
}
|
||||
|
||||
unsigned resultElems = getElemsPerThread(resultTy);
|
||||
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
|
||||
assert(resultIndices.size() == resultElems);
|
||||
|
||||
barrier();
|
||||
SmallVector<Value> resultVals(resultElems);
|
||||
for (size_t i = 0; i < resultElems; ++i) {
|
||||
SmallVector<Value> readIdx = resultIndices[i];
|
||||
readIdx.insert(readIdx.begin() + axis, i32_val(0));
|
||||
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
|
||||
Value readOffset =
|
||||
linearize(rewriter, loc, reorder<Value>(readIdx, resultOrd),
|
||||
reorder<int64_t, unsigned>(resultShape, resultOrd));
|
||||
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
||||
resultVals[i] = load(readPtr);
|
||||
}
|
||||
@@ -1670,7 +1711,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
||||
rewriter.replaceOp(op, ret);
|
||||
} else {
|
||||
// 0d-tensor -> scalar
|
||||
barrier();
|
||||
Value resultVal = load(smemBase);
|
||||
rewriter.replaceOp(op, resultVal);
|
||||
}
|
||||
@@ -1707,6 +1747,191 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
|
||||
}
|
||||
};
|
||||
|
||||
struct PrintfOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::PrintfOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::PrintfOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::PrintfOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
SmallVector<Value, 16> operands;
|
||||
for (auto operand : adaptor.getOperands()) {
|
||||
auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
|
||||
for (auto elem : sub_operands) {
|
||||
operands.push_back(elem);
|
||||
}
|
||||
}
|
||||
std::string formatStr;
|
||||
llvm::raw_string_ostream os(formatStr);
|
||||
os << op.prefix();
|
||||
if (operands.size() > 0) {
|
||||
os << getFormatSubstr(operands[0]);
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < operands.size(); ++i) {
|
||||
os << ", " << getFormatSubstr(operands[i]);
|
||||
}
|
||||
llPrintf(formatStr, operands, rewriter);
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
// get format specific for each input value
|
||||
// currently support pointer, i8, i16, i32, i64, f16, bf16, f32, f64
|
||||
std::string getFormatSubstr(Value value) const {
|
||||
Type type = value.getType();
|
||||
unsigned width = type.getIntOrFloatBitWidth();
|
||||
|
||||
if (type.isa<LLVM::LLVMPointerType>()) {
|
||||
return "%p";
|
||||
} else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
|
||||
return "%f";
|
||||
} else if (type.isSignedInteger()) {
|
||||
return "%i";
|
||||
} else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
|
||||
return "%u";
|
||||
}
|
||||
assert(false && "not supported type");
|
||||
}
|
||||
|
||||
// declare vprintf(i8*, i8*) as external function
|
||||
static LLVM::LLVMFuncOp
|
||||
getVprintfDeclaration(ConversionPatternRewriter &rewriter) {
|
||||
auto moduleOp =
|
||||
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
|
||||
StringRef funcName("vprintf");
|
||||
Operation *funcOp = moduleOp.lookupSymbol(funcName);
|
||||
if (funcOp)
|
||||
return cast<LLVM::LLVMFuncOp>(*funcOp);
|
||||
|
||||
auto *context = rewriter.getContext();
|
||||
|
||||
SmallVector<Type> argsType{ptr_ty(IntegerType::get(context, 8)),
|
||||
ptr_ty(IntegerType::get(context, 8))};
|
||||
auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType);
|
||||
|
||||
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
||||
|
||||
return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(context), funcName,
|
||||
funcType);
|
||||
}
|
||||
|
||||
// extend integer to int32, extend float to float64
|
||||
// this comes from vprintf alignment requirements.
|
||||
static std::pair<Type, Value>
|
||||
promoteValue(ConversionPatternRewriter &rewriter, Value value) {
|
||||
auto *context = rewriter.getContext();
|
||||
auto type = value.getType();
|
||||
type.dump();
|
||||
unsigned width = type.getIntOrFloatBitWidth();
|
||||
Value newOp = value;
|
||||
Type newType = type;
|
||||
|
||||
bool bUnsigned = type.isUnsignedInteger();
|
||||
if (type.isIntOrIndex() && width < 32) {
|
||||
if (bUnsigned) {
|
||||
newType = ui32_ty;
|
||||
newOp = rewriter.create<LLVM::ZExtOp>(UnknownLoc::get(context), newType,
|
||||
value);
|
||||
} else {
|
||||
newType = i32_ty;
|
||||
newOp = rewriter.create<LLVM::SExtOp>(UnknownLoc::get(context), newType,
|
||||
value);
|
||||
}
|
||||
} else if (type.isBF16() || type.isF16() || type.isF32()) {
|
||||
newType = f64_ty;
|
||||
newOp = rewriter.create<LLVM::FPExtOp>(UnknownLoc::get(context), newType,
|
||||
value);
|
||||
}
|
||||
|
||||
return {newType, newOp};
|
||||
}
|
||||
|
||||
static void llPrintf(StringRef msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
static const char formatStringPrefix[] = "printfFormat_";
|
||||
assert(!msg.empty() && "printf with empty string not support");
|
||||
Type int8Ptr = ptr_ty(i8_ty);
|
||||
|
||||
auto *context = rewriter.getContext();
|
||||
auto moduleOp =
|
||||
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
|
||||
auto funcOp = getVprintfDeclaration(rewriter);
|
||||
|
||||
Value one = rewriter.create<LLVM::ConstantOp>(
|
||||
UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(1));
|
||||
Value zero = rewriter.create<LLVM::ConstantOp>(
|
||||
UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(0));
|
||||
|
||||
unsigned stringNumber = 0;
|
||||
SmallString<16> stringConstName;
|
||||
do {
|
||||
stringConstName.clear();
|
||||
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
|
||||
} while (moduleOp.lookupSymbol(stringConstName));
|
||||
|
||||
llvm::SmallString<64> formatString(msg);
|
||||
formatString.push_back('\n');
|
||||
formatString.push_back('\0');
|
||||
size_t formatStringSize = formatString.size_in_bytes();
|
||||
auto globalType = LLVM::LLVMArrayType::get(i8_ty, formatStringSize);
|
||||
|
||||
LLVM::GlobalOp global;
|
||||
{
|
||||
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
||||
global = rewriter.create<LLVM::GlobalOp>(
|
||||
UnknownLoc::get(context), globalType,
|
||||
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
|
||||
rewriter.getStringAttr(formatString));
|
||||
}
|
||||
|
||||
Value globalPtr =
|
||||
rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(context), global);
|
||||
Value stringStart =
|
||||
rewriter.create<LLVM::GEPOp>(UnknownLoc::get(context), int8Ptr,
|
||||
globalPtr, mlir::ValueRange({zero, zero}));
|
||||
|
||||
Value bufferPtr =
|
||||
rewriter.create<LLVM::NullOp>(UnknownLoc::get(context), int8Ptr);
|
||||
|
||||
SmallVector<Value, 16> newArgs;
|
||||
if (args.size() >= 1) {
|
||||
SmallVector<Type> argTypes;
|
||||
for (auto arg : args) {
|
||||
Type newType;
|
||||
Value newArg;
|
||||
std::tie(newType, newArg) = promoteValue(rewriter, arg);
|
||||
argTypes.push_back(newType);
|
||||
newArgs.push_back(newArg);
|
||||
}
|
||||
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(context, argTypes);
|
||||
auto allocated = rewriter.create<LLVM::AllocaOp>(UnknownLoc::get(context),
|
||||
ptr_ty(structTy), one,
|
||||
/*alignment=*/0);
|
||||
|
||||
for (const auto &entry : llvm::enumerate(newArgs)) {
|
||||
auto index = rewriter.create<LLVM::ConstantOp>(
|
||||
UnknownLoc::get(context), i32_ty,
|
||||
rewriter.getI32IntegerAttr(entry.index()));
|
||||
auto fieldPtr = rewriter.create<LLVM::GEPOp>(
|
||||
UnknownLoc::get(context), ptr_ty(argTypes[entry.index()]),
|
||||
allocated, ArrayRef<Value>{zero, index});
|
||||
rewriter.create<LLVM::StoreOp>(UnknownLoc::get(context), entry.value(),
|
||||
fieldPtr);
|
||||
}
|
||||
bufferPtr = rewriter.create<LLVM::BitcastOp>(UnknownLoc::get(context),
|
||||
int8Ptr, allocated);
|
||||
}
|
||||
|
||||
ValueRange operands{stringStart, bufferPtr};
|
||||
rewriter.create<LLVM::CallOp>(UnknownLoc::get(context), funcOp, operands);
|
||||
}
|
||||
};
|
||||
|
||||
struct MakeRangeOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
|
||||
|
||||
@@ -2070,17 +2295,6 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
SmallVector<T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) const {
|
||||
size_t rank = order.size();
|
||||
assert(input.size() == rank);
|
||||
SmallVector<T> result(rank);
|
||||
for (auto it : llvm::enumerate(order)) {
|
||||
result[rank - 1 - it.value()] = input[it.index()];
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
// shared memory rd/st for blocked or mma layout with data padding
|
||||
void processReplica(Location loc, ConversionPatternRewriter &rewriter,
|
||||
bool stNotRd, RankedTensorType type,
|
||||
@@ -4483,7 +4697,7 @@ struct InsertSliceAsyncOpConversion
|
||||
auto numVecCols = std::max<unsigned>(inVec / outVec, 1);
|
||||
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcShape);
|
||||
// <<tileVecIdxRow, tileVecIdxCol>, TileOffset>
|
||||
// <<tileVecIdxRow, tileVecIdxCol>, TileOffset>
|
||||
DenseMap<std::pair<unsigned, unsigned>, Value> tileOffsetMap;
|
||||
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
|
||||
// minVec = 2, inVec = 4, outVec = 2
|
||||
@@ -4674,190 +4888,6 @@ struct FDivOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
struct PrintfOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::PrintfOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::PrintfOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::PrintfOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
SmallVector<Value, 16> operands;
|
||||
for (auto operand : adaptor.getOperands()) {
|
||||
auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
|
||||
for (auto elem : sub_operands) {
|
||||
operands.push_back(elem);
|
||||
}
|
||||
}
|
||||
std::string formatStr;
|
||||
llvm::raw_string_ostream os(formatStr);
|
||||
os << op.prefix();
|
||||
if (operands.size() > 0) {
|
||||
os << getFormatSubstr(operands[0]);
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < operands.size(); ++i) {
|
||||
os << ", " << getFormatSubstr(operands[i]);
|
||||
}
|
||||
llPrintf(formatStr, operands, rewriter);
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
// get format specific for each input value
|
||||
// currently support pointer, i8, i16, i32, i64, f16, bf16, f32, f64
|
||||
std::string getFormatSubstr(Value value) const {
|
||||
Type type = value.getType();
|
||||
unsigned width = type.getIntOrFloatBitWidth();
|
||||
|
||||
if (type.isa<LLVM::LLVMPointerType>()) {
|
||||
return "%p";
|
||||
} else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
|
||||
return "%f";
|
||||
} else if (type.isSignedInteger()) {
|
||||
return "%i";
|
||||
} else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
|
||||
return "%u";
|
||||
}
|
||||
assert(false && "not supported type");
|
||||
}
|
||||
|
||||
// declare vprintf(i8*, i8*) as external function
|
||||
LLVM::LLVMFuncOp
|
||||
getVprintfDeclaration(ConversionPatternRewriter &rewriter) const {
|
||||
auto moduleOp =
|
||||
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
|
||||
StringRef funcName("vprintf");
|
||||
Operation *funcOp = moduleOp.lookupSymbol(funcName);
|
||||
if (funcOp)
|
||||
return cast<LLVM::LLVMFuncOp>(*funcOp);
|
||||
|
||||
auto *context = rewriter.getContext();
|
||||
|
||||
SmallVector<Type> argsType{ptr_ty(IntegerType::get(context, 8)),
|
||||
ptr_ty(IntegerType::get(context, 8))};
|
||||
auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType);
|
||||
|
||||
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
||||
|
||||
return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(context), funcName,
|
||||
funcType);
|
||||
}
|
||||
|
||||
// extend integer to int32, extend float to float64
|
||||
// this comes from vprintf alignment requirements.
|
||||
std::pair<Type, Value> promoteValue(ConversionPatternRewriter &rewriter,
|
||||
Value value) const {
|
||||
auto *context = rewriter.getContext();
|
||||
auto type = value.getType();
|
||||
unsigned width = type.getIntOrFloatBitWidth();
|
||||
Value newOp = value;
|
||||
Type newType = type;
|
||||
|
||||
bool bUnsigned = type.isUnsignedInteger();
|
||||
if (type.isIntOrIndex() && width < 32) {
|
||||
if (bUnsigned) {
|
||||
newType = ui32_ty;
|
||||
newOp = rewriter.create<LLVM::ZExtOp>(UnknownLoc::get(context), newType,
|
||||
value);
|
||||
} else {
|
||||
newType = i32_ty;
|
||||
newOp = rewriter.create<LLVM::SExtOp>(UnknownLoc::get(context), newType,
|
||||
value);
|
||||
}
|
||||
} else if (type.isBF16() || type.isF16() || type.isF32()) {
|
||||
newType = f64_ty;
|
||||
newOp = rewriter.create<LLVM::FPExtOp>(UnknownLoc::get(context), newType,
|
||||
value);
|
||||
}
|
||||
|
||||
return {newType, newOp};
|
||||
}
|
||||
|
||||
void llPrintf(StringRef msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
static const char formatStringPrefix[] = "printfFormat_";
|
||||
assert(!msg.empty() && "printf with empty string not support");
|
||||
Type int8Ptr = ptr_ty(i8_ty);
|
||||
|
||||
auto *context = rewriter.getContext();
|
||||
auto moduleOp =
|
||||
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
|
||||
auto funcOp = getVprintfDeclaration(rewriter);
|
||||
|
||||
Value one = rewriter.create<LLVM::ConstantOp>(
|
||||
UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(1));
|
||||
Value zero = rewriter.create<LLVM::ConstantOp>(
|
||||
UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(0));
|
||||
|
||||
unsigned stringNumber = 0;
|
||||
SmallString<16> stringConstName;
|
||||
do {
|
||||
stringConstName.clear();
|
||||
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
|
||||
} while (moduleOp.lookupSymbol(stringConstName));
|
||||
|
||||
llvm::SmallString<64> formatString(msg);
|
||||
formatString.push_back('\n');
|
||||
formatString.push_back('\0');
|
||||
size_t formatStringSize = formatString.size_in_bytes();
|
||||
auto globalType = LLVM::LLVMArrayType::get(i8_ty, formatStringSize);
|
||||
|
||||
LLVM::GlobalOp global;
|
||||
{
|
||||
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
||||
global = rewriter.create<LLVM::GlobalOp>(
|
||||
UnknownLoc::get(context), globalType,
|
||||
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
|
||||
rewriter.getStringAttr(formatString));
|
||||
}
|
||||
|
||||
Value globalPtr =
|
||||
rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(context), global);
|
||||
Value stringStart =
|
||||
rewriter.create<LLVM::GEPOp>(UnknownLoc::get(context), int8Ptr,
|
||||
globalPtr, mlir::ValueRange({zero, zero}));
|
||||
|
||||
Value bufferPtr =
|
||||
rewriter.create<LLVM::NullOp>(UnknownLoc::get(context), int8Ptr);
|
||||
|
||||
SmallVector<Value, 16> newArgs;
|
||||
if (args.size() >= 1) {
|
||||
SmallVector<Type> argTypes;
|
||||
for (auto arg : args) {
|
||||
Type newType;
|
||||
Value newArg;
|
||||
std::tie(newType, newArg) = promoteValue(rewriter, arg);
|
||||
argTypes.push_back(newType);
|
||||
newArgs.push_back(newArg);
|
||||
}
|
||||
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(context, argTypes);
|
||||
auto allocated = rewriter.create<LLVM::AllocaOp>(UnknownLoc::get(context),
|
||||
ptr_ty(structTy), one,
|
||||
/*alignment=*/0);
|
||||
|
||||
for (const auto &entry : llvm::enumerate(newArgs)) {
|
||||
auto index = rewriter.create<LLVM::ConstantOp>(
|
||||
UnknownLoc::get(context), i32_ty,
|
||||
rewriter.getI32IntegerAttr(entry.index()));
|
||||
auto fieldPtr = rewriter.create<LLVM::GEPOp>(
|
||||
UnknownLoc::get(context), ptr_ty(argTypes[entry.index()]),
|
||||
allocated, ArrayRef<Value>{zero, index});
|
||||
rewriter.create<LLVM::StoreOp>(UnknownLoc::get(context), entry.value(),
|
||||
fieldPtr);
|
||||
}
|
||||
bufferPtr = rewriter.create<LLVM::BitcastOp>(UnknownLoc::get(context),
|
||||
int8Ptr, allocated);
|
||||
}
|
||||
|
||||
ValueRange operands{stringStart, bufferPtr};
|
||||
rewriter.create<LLVM::CallOp>(UnknownLoc::get(context), funcOp, operands);
|
||||
}
|
||||
};
|
||||
|
||||
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns, int numWarps,
|
||||
AxisInfoAnalysis &axisInfoAnalysis,
|
||||
@@ -5062,6 +5092,15 @@ void ConvertTritonGPUToLLVM::initSharedMemory(
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace LLVM {
|
||||
|
||||
// Free-function printf helper for LLVM conversion code: forwards the message,
// its arguments and the rewriter unchanged to PrintfOpConversion::llPrintf.
void llPrintf(StringRef msg, ValueRange args,
              ConversionPatternRewriter &rewriter) {
  return PrintfOpConversion::llPrintf(msg, args, rewriter);
}
|
||||
|
||||
} // namespace LLVM
|
||||
|
||||
TritonLLVMConversionTarget::TritonLLVMConversionTarget(
|
||||
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
|
||||
: ConversionTarget(ctx) {
|
||||
|
Reference in New Issue
Block a user