|
|
|
@@ -83,6 +83,11 @@ static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
|
|
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
// A helper function for using printf in LLVM conversion.
|
|
|
|
|
void llPrintf(StringRef msg, ValueRange args,
|
|
|
|
|
ConversionPatternRewriter &rewriter);
|
|
|
|
|
|
|
|
|
|
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive//
|
|
|
|
|
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
|
|
|
|
|
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
|
|
|
|
|
#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
|
|
|
|
@@ -338,6 +343,7 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
|
|
|
|
|
return llvmStruct;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Delinearize on compile-time consts, assuming the order is [n, .. 2, 1, 0]
|
|
|
|
|
template <typename T>
|
|
|
|
|
static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
|
|
|
|
|
// shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1}
|
|
|
|
@@ -355,6 +361,7 @@ static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape) {
|
|
|
|
|
return multiDimIndex;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Linearize on compile-time consts, assuming the order is [n, .. 2, 1, 0]
|
|
|
|
|
template <typename T>
|
|
|
|
|
static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
|
|
|
|
|
assert(multiDimIndex.size() == shape.size());
|
|
|
|
@@ -510,12 +517,12 @@ public:
|
|
|
|
|
multiDim[0] = linear;
|
|
|
|
|
} else {
|
|
|
|
|
Value remained = linear;
|
|
|
|
|
for (auto &&en : llvm::enumerate(llvm::reverse(shape.drop_front()))) {
|
|
|
|
|
for (auto &&en : llvm::enumerate(shape.drop_back())) {
|
|
|
|
|
Value dimSize = idx_val(en.value());
|
|
|
|
|
multiDim[rank - 1 - en.index()] = urem(remained, dimSize);
|
|
|
|
|
multiDim[en.index()] = urem(remained, dimSize);
|
|
|
|
|
remained = udiv(remained, dimSize);
|
|
|
|
|
}
|
|
|
|
|
multiDim[0] = remained;
|
|
|
|
|
multiDim[rank - 1] = remained;
|
|
|
|
|
}
|
|
|
|
|
return multiDim;
|
|
|
|
|
}
|
|
|
|
@@ -525,9 +532,9 @@ public:
|
|
|
|
|
int rank = multiDim.size();
|
|
|
|
|
Value linear = idx_val(0);
|
|
|
|
|
if (rank > 0) {
|
|
|
|
|
linear = multiDim.front();
|
|
|
|
|
linear = multiDim.back();
|
|
|
|
|
for (auto [dim, shape] :
|
|
|
|
|
llvm::zip(multiDim.drop_front(), shape.drop_front())) {
|
|
|
|
|
llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) {
|
|
|
|
|
Value dimSize = idx_val(shape);
|
|
|
|
|
linear = add(mul(linear, dimSize), dim);
|
|
|
|
|
}
|
|
|
|
@@ -566,6 +573,7 @@ public:
|
|
|
|
|
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
|
|
|
|
|
SmallVector<Value> multiDimThreadId =
|
|
|
|
|
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> multiDimBase(rank);
|
|
|
|
|
for (unsigned k = 0; k < rank; ++k) {
|
|
|
|
|
// Wrap around multiDimWarpId/multiDimThreadId incase
|
|
|
|
@@ -1362,7 +1370,9 @@ private:
|
|
|
|
|
LogicalResult
|
|
|
|
|
ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
if (op.axis() == 1) // FIXME(Qingyi): The fastest-changing dimension
|
|
|
|
|
auto srcTy = op.operand().getType().cast<RankedTensorType>();
|
|
|
|
|
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
|
|
|
|
|
if (op.axis() == srcLayout.getOrder()[0])
|
|
|
|
|
return matchAndRewriteFast(op, adaptor, rewriter);
|
|
|
|
|
return matchAndRewriteBasic(op, adaptor, rewriter);
|
|
|
|
|
}
|
|
|
|
@@ -1444,6 +1454,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
|
|
|
|
|
|
|
|
|
|
auto srcTy = op.operand().getType().cast<RankedTensorType>();
|
|
|
|
|
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
|
|
|
|
|
auto srcOrd = srcLayout.getOrder();
|
|
|
|
|
auto srcShape = srcTy.getShape();
|
|
|
|
|
|
|
|
|
|
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
|
|
|
|
@@ -1487,7 +1498,9 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
|
|
|
|
|
SmallVector<Value> writeIdx = indices[key];
|
|
|
|
|
|
|
|
|
|
writeIdx[axis] = udiv(writeIdx[axis], sizePerThread);
|
|
|
|
|
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
|
|
|
|
|
Value writeOffset =
|
|
|
|
|
linearize(rewriter, loc, reorder<Value>(writeIdx, srcOrd),
|
|
|
|
|
reorder<unsigned>(smemShape, srcOrd));
|
|
|
|
|
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
|
|
|
|
store(acc, writePtr);
|
|
|
|
|
|
|
|
|
@@ -1495,8 +1508,11 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
|
|
|
|
|
for (int N = smemShape[axis] / 2; N > 0; N >>= 1) {
|
|
|
|
|
readIdx[axis] = ints[N];
|
|
|
|
|
Value readMask = icmp_slt(writeIdx[axis], ints[N]);
|
|
|
|
|
Value readOffset = select(
|
|
|
|
|
readMask, linearize(rewriter, loc, readIdx, smemShape), ints[0]);
|
|
|
|
|
Value readOffset =
|
|
|
|
|
select(readMask,
|
|
|
|
|
linearize(rewriter, loc, reorder<Value>(readIdx, srcOrd),
|
|
|
|
|
reorder<unsigned>(smemShape, srcOrd)),
|
|
|
|
|
ints[0]);
|
|
|
|
|
Value readPtr = gep(elemPtrTy, writePtr, readOffset);
|
|
|
|
|
barrier();
|
|
|
|
|
accumulate(rewriter, loc, op.redOp(), acc, load(readPtr), false);
|
|
|
|
@@ -1519,7 +1535,9 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
|
|
|
|
|
for (unsigned i = 0; i < resultElems; ++i) {
|
|
|
|
|
SmallVector<Value> readIdx = resultIndices[i];
|
|
|
|
|
readIdx.insert(readIdx.begin() + axis, ints[0]);
|
|
|
|
|
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
|
|
|
|
|
Value readOffset =
|
|
|
|
|
linearize(rewriter, loc, reorder<Value>(readIdx, srcOrd),
|
|
|
|
|
reorder<unsigned>(smemShape, srcOrd));
|
|
|
|
|
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
|
|
|
|
resultVals[i] = load(readPtr);
|
|
|
|
|
}
|
|
|
|
@@ -1548,6 +1566,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
|
|
|
|
auto srcTy = op.operand().getType().cast<RankedTensorType>();
|
|
|
|
|
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
|
|
|
|
|
auto srcShape = srcTy.getShape();
|
|
|
|
|
auto srcRank = srcTy.getRank();
|
|
|
|
|
|
|
|
|
|
auto threadsPerWarp = srcLayout.getThreadsPerWarp();
|
|
|
|
|
auto warpsPerCTA = srcLayout.getWarpsPerCTA();
|
|
|
|
@@ -1592,6 +1611,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
|
|
|
|
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
|
|
|
|
|
SmallVector<Value> multiDimWarpId =
|
|
|
|
|
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
|
|
|
|
|
|
|
|
|
|
Value laneIdAxis = multiDimLaneId[axis];
|
|
|
|
|
Value warpIdAxis = multiDimWarpId[axis];
|
|
|
|
|
|
|
|
|
@@ -1609,56 +1629,77 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
|
|
|
|
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (sizeInterWarps == 1) {
|
|
|
|
|
SmallVector<Value> writeIdx = indices[key];
|
|
|
|
|
writeIdx[axis] = zero;
|
|
|
|
|
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
|
|
|
|
|
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
|
|
|
|
storeShared(rewriter, loc, writePtr, acc, laneZero);
|
|
|
|
|
} else {
|
|
|
|
|
SmallVector<Value> writeIdx = indices[key];
|
|
|
|
|
writeIdx[axis] =
|
|
|
|
|
warpIdAxis; // axis must be the fastest-changing dimension
|
|
|
|
|
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
|
|
|
|
|
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
|
|
|
|
storeShared(rewriter, loc, writePtr, acc, laneZero);
|
|
|
|
|
barrier();
|
|
|
|
|
SmallVector<Value> writeIdx = indices[key];
|
|
|
|
|
writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
|
|
|
|
|
Value writeOffset =
|
|
|
|
|
linearize(rewriter, loc, reorder<Value>(writeIdx, order),
|
|
|
|
|
reorder<unsigned>(smemShape, order));
|
|
|
|
|
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
|
|
|
|
storeShared(rewriter, loc, writePtr, acc, laneZero);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> readIdx = writeIdx;
|
|
|
|
|
readIdx[axis] = urem(laneId, i32_val(sizeInterWarps));
|
|
|
|
|
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
|
|
|
|
|
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
|
|
|
|
acc = load(readPtr);
|
|
|
|
|
barrier();
|
|
|
|
|
|
|
|
|
|
// reduce across warps
|
|
|
|
|
for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
|
|
|
|
|
Value shfl = shflSync(rewriter, loc, acc, N);
|
|
|
|
|
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
|
|
|
|
|
}
|
|
|
|
|
// the second round of shuffle reduction
|
|
|
|
|
// now the problem size: sizeInterWarps, s1, s2, .. , sn =>
|
|
|
|
|
// 1, s1, s2, .. , sn
|
|
|
|
|
// where sizeInterWarps is 2^m
|
|
|
|
|
//
|
|
|
|
|
// each thread needs to process:
|
|
|
|
|
// elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
|
|
|
|
|
unsigned elems = product<unsigned>(smemShape);
|
|
|
|
|
unsigned numThreads = product<unsigned>(srcLayout.getWarpsPerCTA()) * 32;
|
|
|
|
|
unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
|
|
|
|
|
Value readOffset = threadId;
|
|
|
|
|
for (unsigned round = 0; round < elemsPerThread; ++round) {
|
|
|
|
|
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
|
|
|
|
Value acc = load(readPtr);
|
|
|
|
|
|
|
|
|
|
writeIdx[axis] = zero;
|
|
|
|
|
writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
|
|
|
|
|
writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
|
|
|
|
storeShared(rewriter, loc, writePtr, acc, and_(laneZero, warpZero));
|
|
|
|
|
for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
|
|
|
|
|
Value shfl = shflSync(rewriter, loc, acc, N);
|
|
|
|
|
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value writeOffset = udiv(readOffset, i32_val(sizeInterWarps));
|
|
|
|
|
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
|
|
|
|
Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));
|
|
|
|
|
Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
|
|
|
|
|
Value laneIdModSizeInterWarpsIsZero =
|
|
|
|
|
icmp_eq(laneIdModSizeInterWarps, zero);
|
|
|
|
|
storeShared(rewriter, loc, writePtr, acc,
|
|
|
|
|
and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero));
|
|
|
|
|
|
|
|
|
|
if (round != elemsPerThread - 1) {
|
|
|
|
|
readOffset = add(readOffset, i32_val(numThreads));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// We could avoid this barrier in some of the layouts, however this is not
|
|
|
|
|
// the general case. TODO: optimize the barrier incase the layouts are
|
|
|
|
|
// accepted.
|
|
|
|
|
barrier();
|
|
|
|
|
|
|
|
|
|
// set output values
|
|
|
|
|
if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
|
|
|
|
|
// nd-tensor where n >= 1
|
|
|
|
|
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
|
|
|
|
|
auto resultShape = resultTy.getShape();
|
|
|
|
|
SmallVector<unsigned> resultOrd;
|
|
|
|
|
for (auto ord : order) {
|
|
|
|
|
if (ord != 0)
|
|
|
|
|
resultOrd.push_back(ord - 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
unsigned resultElems = getElemsPerThread(resultTy);
|
|
|
|
|
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
|
|
|
|
|
assert(resultIndices.size() == resultElems);
|
|
|
|
|
|
|
|
|
|
barrier();
|
|
|
|
|
SmallVector<Value> resultVals(resultElems);
|
|
|
|
|
for (size_t i = 0; i < resultElems; ++i) {
|
|
|
|
|
SmallVector<Value> readIdx = resultIndices[i];
|
|
|
|
|
readIdx.insert(readIdx.begin() + axis, i32_val(0));
|
|
|
|
|
Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
|
|
|
|
|
Value readOffset =
|
|
|
|
|
linearize(rewriter, loc, reorder<Value>(readIdx, resultOrd),
|
|
|
|
|
reorder<int64_t, unsigned>(resultShape, resultOrd));
|
|
|
|
|
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
|
|
|
|
resultVals[i] = load(readPtr);
|
|
|
|
|
}
|
|
|
|
@@ -1670,7 +1711,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
|
|
|
|
rewriter.replaceOp(op, ret);
|
|
|
|
|
} else {
|
|
|
|
|
// 0d-tensor -> scalar
|
|
|
|
|
barrier();
|
|
|
|
|
Value resultVal = load(smemBase);
|
|
|
|
|
rewriter.replaceOp(op, resultVal);
|
|
|
|
|
}
|
|
|
|
@@ -1707,6 +1747,191 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct PrintfOpConversion
|
|
|
|
|
: public ConvertTritonGPUOpToLLVMPattern<triton::PrintfOp> {
|
|
|
|
|
using ConvertTritonGPUOpToLLVMPattern<
|
|
|
|
|
triton::PrintfOp>::ConvertTritonGPUOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(triton::PrintfOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto loc = op->getLoc();
|
|
|
|
|
SmallVector<Value, 16> operands;
|
|
|
|
|
for (auto operand : adaptor.getOperands()) {
|
|
|
|
|
auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
|
|
|
|
|
for (auto elem : sub_operands) {
|
|
|
|
|
operands.push_back(elem);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::string formatStr;
|
|
|
|
|
llvm::raw_string_ostream os(formatStr);
|
|
|
|
|
os << op.prefix();
|
|
|
|
|
if (operands.size() > 0) {
|
|
|
|
|
os << getFormatSubstr(operands[0]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 1; i < operands.size(); ++i) {
|
|
|
|
|
os << ", " << getFormatSubstr(operands[i]);
|
|
|
|
|
}
|
|
|
|
|
llPrintf(formatStr, operands, rewriter);
|
|
|
|
|
rewriter.eraseOp(op);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
// get format specific for each input value
|
|
|
|
|
// currently support pointer, i8, i16, i32, i64, f16, bf16, f32, f64
|
|
|
|
|
std::string getFormatSubstr(Value value) const {
|
|
|
|
|
Type type = value.getType();
|
|
|
|
|
unsigned width = type.getIntOrFloatBitWidth();
|
|
|
|
|
|
|
|
|
|
if (type.isa<LLVM::LLVMPointerType>()) {
|
|
|
|
|
return "%p";
|
|
|
|
|
} else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
|
|
|
|
|
return "%f";
|
|
|
|
|
} else if (type.isSignedInteger()) {
|
|
|
|
|
return "%i";
|
|
|
|
|
} else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
|
|
|
|
|
return "%u";
|
|
|
|
|
}
|
|
|
|
|
assert(false && "not supported type");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// declare vprintf(i8*, i8*) as external function
|
|
|
|
|
static LLVM::LLVMFuncOp
|
|
|
|
|
getVprintfDeclaration(ConversionPatternRewriter &rewriter) {
|
|
|
|
|
auto moduleOp =
|
|
|
|
|
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
|
|
|
|
|
StringRef funcName("vprintf");
|
|
|
|
|
Operation *funcOp = moduleOp.lookupSymbol(funcName);
|
|
|
|
|
if (funcOp)
|
|
|
|
|
return cast<LLVM::LLVMFuncOp>(*funcOp);
|
|
|
|
|
|
|
|
|
|
auto *context = rewriter.getContext();
|
|
|
|
|
|
|
|
|
|
SmallVector<Type> argsType{ptr_ty(IntegerType::get(context, 8)),
|
|
|
|
|
ptr_ty(IntegerType::get(context, 8))};
|
|
|
|
|
auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType);
|
|
|
|
|
|
|
|
|
|
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
|
|
|
|
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
|
|
|
|
|
|
|
|
|
return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(context), funcName,
|
|
|
|
|
funcType);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// extend integer to int32, extend float to float64
|
|
|
|
|
// this comes from vprintf alignment requirements.
|
|
|
|
|
static std::pair<Type, Value>
|
|
|
|
|
promoteValue(ConversionPatternRewriter &rewriter, Value value) {
|
|
|
|
|
auto *context = rewriter.getContext();
|
|
|
|
|
auto type = value.getType();
|
|
|
|
|
type.dump();
|
|
|
|
|
unsigned width = type.getIntOrFloatBitWidth();
|
|
|
|
|
Value newOp = value;
|
|
|
|
|
Type newType = type;
|
|
|
|
|
|
|
|
|
|
bool bUnsigned = type.isUnsignedInteger();
|
|
|
|
|
if (type.isIntOrIndex() && width < 32) {
|
|
|
|
|
if (bUnsigned) {
|
|
|
|
|
newType = ui32_ty;
|
|
|
|
|
newOp = rewriter.create<LLVM::ZExtOp>(UnknownLoc::get(context), newType,
|
|
|
|
|
value);
|
|
|
|
|
} else {
|
|
|
|
|
newType = i32_ty;
|
|
|
|
|
newOp = rewriter.create<LLVM::SExtOp>(UnknownLoc::get(context), newType,
|
|
|
|
|
value);
|
|
|
|
|
}
|
|
|
|
|
} else if (type.isBF16() || type.isF16() || type.isF32()) {
|
|
|
|
|
newType = f64_ty;
|
|
|
|
|
newOp = rewriter.create<LLVM::FPExtOp>(UnknownLoc::get(context), newType,
|
|
|
|
|
value);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return {newType, newOp};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void llPrintf(StringRef msg, ValueRange args,
|
|
|
|
|
ConversionPatternRewriter &rewriter) {
|
|
|
|
|
static const char formatStringPrefix[] = "printfFormat_";
|
|
|
|
|
assert(!msg.empty() && "printf with empty string not support");
|
|
|
|
|
Type int8Ptr = ptr_ty(i8_ty);
|
|
|
|
|
|
|
|
|
|
auto *context = rewriter.getContext();
|
|
|
|
|
auto moduleOp =
|
|
|
|
|
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
|
|
|
|
|
auto funcOp = getVprintfDeclaration(rewriter);
|
|
|
|
|
|
|
|
|
|
Value one = rewriter.create<LLVM::ConstantOp>(
|
|
|
|
|
UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(1));
|
|
|
|
|
Value zero = rewriter.create<LLVM::ConstantOp>(
|
|
|
|
|
UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(0));
|
|
|
|
|
|
|
|
|
|
unsigned stringNumber = 0;
|
|
|
|
|
SmallString<16> stringConstName;
|
|
|
|
|
do {
|
|
|
|
|
stringConstName.clear();
|
|
|
|
|
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
|
|
|
|
|
} while (moduleOp.lookupSymbol(stringConstName));
|
|
|
|
|
|
|
|
|
|
llvm::SmallString<64> formatString(msg);
|
|
|
|
|
formatString.push_back('\n');
|
|
|
|
|
formatString.push_back('\0');
|
|
|
|
|
size_t formatStringSize = formatString.size_in_bytes();
|
|
|
|
|
auto globalType = LLVM::LLVMArrayType::get(i8_ty, formatStringSize);
|
|
|
|
|
|
|
|
|
|
LLVM::GlobalOp global;
|
|
|
|
|
{
|
|
|
|
|
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
|
|
|
|
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
|
|
|
|
global = rewriter.create<LLVM::GlobalOp>(
|
|
|
|
|
UnknownLoc::get(context), globalType,
|
|
|
|
|
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
|
|
|
|
|
rewriter.getStringAttr(formatString));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value globalPtr =
|
|
|
|
|
rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(context), global);
|
|
|
|
|
Value stringStart =
|
|
|
|
|
rewriter.create<LLVM::GEPOp>(UnknownLoc::get(context), int8Ptr,
|
|
|
|
|
globalPtr, mlir::ValueRange({zero, zero}));
|
|
|
|
|
|
|
|
|
|
Value bufferPtr =
|
|
|
|
|
rewriter.create<LLVM::NullOp>(UnknownLoc::get(context), int8Ptr);
|
|
|
|
|
|
|
|
|
|
SmallVector<Value, 16> newArgs;
|
|
|
|
|
if (args.size() >= 1) {
|
|
|
|
|
SmallVector<Type> argTypes;
|
|
|
|
|
for (auto arg : args) {
|
|
|
|
|
Type newType;
|
|
|
|
|
Value newArg;
|
|
|
|
|
std::tie(newType, newArg) = promoteValue(rewriter, arg);
|
|
|
|
|
argTypes.push_back(newType);
|
|
|
|
|
newArgs.push_back(newArg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Type structTy = LLVM::LLVMStructType::getLiteral(context, argTypes);
|
|
|
|
|
auto allocated = rewriter.create<LLVM::AllocaOp>(UnknownLoc::get(context),
|
|
|
|
|
ptr_ty(structTy), one,
|
|
|
|
|
/*alignment=*/0);
|
|
|
|
|
|
|
|
|
|
for (const auto &entry : llvm::enumerate(newArgs)) {
|
|
|
|
|
auto index = rewriter.create<LLVM::ConstantOp>(
|
|
|
|
|
UnknownLoc::get(context), i32_ty,
|
|
|
|
|
rewriter.getI32IntegerAttr(entry.index()));
|
|
|
|
|
auto fieldPtr = rewriter.create<LLVM::GEPOp>(
|
|
|
|
|
UnknownLoc::get(context), ptr_ty(argTypes[entry.index()]),
|
|
|
|
|
allocated, ArrayRef<Value>{zero, index});
|
|
|
|
|
rewriter.create<LLVM::StoreOp>(UnknownLoc::get(context), entry.value(),
|
|
|
|
|
fieldPtr);
|
|
|
|
|
}
|
|
|
|
|
bufferPtr = rewriter.create<LLVM::BitcastOp>(UnknownLoc::get(context),
|
|
|
|
|
int8Ptr, allocated);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ValueRange operands{stringStart, bufferPtr};
|
|
|
|
|
rewriter.create<LLVM::CallOp>(UnknownLoc::get(context), funcOp, operands);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct MakeRangeOpConversion
|
|
|
|
|
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
|
|
|
|
|
|
|
|
|
@@ -2070,17 +2295,6 @@ public:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
template <typename T>
|
|
|
|
|
SmallVector<T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) const {
|
|
|
|
|
size_t rank = order.size();
|
|
|
|
|
assert(input.size() == rank);
|
|
|
|
|
SmallVector<T> result(rank);
|
|
|
|
|
for (auto it : llvm::enumerate(order)) {
|
|
|
|
|
result[rank - 1 - it.value()] = input[it.index()];
|
|
|
|
|
}
|
|
|
|
|
return result;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// shared memory rd/st for blocked or mma layout with data padding
|
|
|
|
|
void processReplica(Location loc, ConversionPatternRewriter &rewriter,
|
|
|
|
|
bool stNotRd, RankedTensorType type,
|
|
|
|
@@ -4483,7 +4697,7 @@ struct InsertSliceAsyncOpConversion
|
|
|
|
|
auto numVecCols = std::max<unsigned>(inVec / outVec, 1);
|
|
|
|
|
|
|
|
|
|
auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcShape);
|
|
|
|
|
// <<tileVecIdxRow, tileVecIdxCol>, TileOffset>
|
|
|
|
|
// <<tileVecIdxRow, tileVecIdxCol>, TileOffset>
|
|
|
|
|
DenseMap<std::pair<unsigned, unsigned>, Value> tileOffsetMap;
|
|
|
|
|
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
|
|
|
|
|
// minVec = 2, inVec = 4, outVec = 2
|
|
|
|
@@ -4674,190 +4888,6 @@ struct FDivOpConversion
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct PrintfOpConversion
|
|
|
|
|
: public ConvertTritonGPUOpToLLVMPattern<triton::PrintfOp> {
|
|
|
|
|
using ConvertTritonGPUOpToLLVMPattern<
|
|
|
|
|
triton::PrintfOp>::ConvertTritonGPUOpToLLVMPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(triton::PrintfOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto loc = op->getLoc();
|
|
|
|
|
SmallVector<Value, 16> operands;
|
|
|
|
|
for (auto operand : adaptor.getOperands()) {
|
|
|
|
|
auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
|
|
|
|
|
for (auto elem : sub_operands) {
|
|
|
|
|
operands.push_back(elem);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::string formatStr;
|
|
|
|
|
llvm::raw_string_ostream os(formatStr);
|
|
|
|
|
os << op.prefix();
|
|
|
|
|
if (operands.size() > 0) {
|
|
|
|
|
os << getFormatSubstr(operands[0]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 1; i < operands.size(); ++i) {
|
|
|
|
|
os << ", " << getFormatSubstr(operands[i]);
|
|
|
|
|
}
|
|
|
|
|
llPrintf(formatStr, operands, rewriter);
|
|
|
|
|
rewriter.eraseOp(op);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
// get format specific for each input value
|
|
|
|
|
// currently support pointer, i8, i16, i32, i64, f16, bf16, f32, f64
|
|
|
|
|
std::string getFormatSubstr(Value value) const {
|
|
|
|
|
Type type = value.getType();
|
|
|
|
|
unsigned width = type.getIntOrFloatBitWidth();
|
|
|
|
|
|
|
|
|
|
if (type.isa<LLVM::LLVMPointerType>()) {
|
|
|
|
|
return "%p";
|
|
|
|
|
} else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
|
|
|
|
|
return "%f";
|
|
|
|
|
} else if (type.isSignedInteger()) {
|
|
|
|
|
return "%i";
|
|
|
|
|
} else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
|
|
|
|
|
return "%u";
|
|
|
|
|
}
|
|
|
|
|
assert(false && "not supported type");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// declare vprintf(i8*, i8*) as external function
|
|
|
|
|
LLVM::LLVMFuncOp
|
|
|
|
|
getVprintfDeclaration(ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
auto moduleOp =
|
|
|
|
|
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
|
|
|
|
|
StringRef funcName("vprintf");
|
|
|
|
|
Operation *funcOp = moduleOp.lookupSymbol(funcName);
|
|
|
|
|
if (funcOp)
|
|
|
|
|
return cast<LLVM::LLVMFuncOp>(*funcOp);
|
|
|
|
|
|
|
|
|
|
auto *context = rewriter.getContext();
|
|
|
|
|
|
|
|
|
|
SmallVector<Type> argsType{ptr_ty(IntegerType::get(context, 8)),
|
|
|
|
|
ptr_ty(IntegerType::get(context, 8))};
|
|
|
|
|
auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType);
|
|
|
|
|
|
|
|
|
|
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
|
|
|
|
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
|
|
|
|
|
|
|
|
|
return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(context), funcName,
|
|
|
|
|
funcType);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// extend integer to int32, extend float to float64
|
|
|
|
|
// this comes from vprintf alignment requirements.
|
|
|
|
|
std::pair<Type, Value> promoteValue(ConversionPatternRewriter &rewriter,
|
|
|
|
|
Value value) const {
|
|
|
|
|
auto *context = rewriter.getContext();
|
|
|
|
|
auto type = value.getType();
|
|
|
|
|
unsigned width = type.getIntOrFloatBitWidth();
|
|
|
|
|
Value newOp = value;
|
|
|
|
|
Type newType = type;
|
|
|
|
|
|
|
|
|
|
bool bUnsigned = type.isUnsignedInteger();
|
|
|
|
|
if (type.isIntOrIndex() && width < 32) {
|
|
|
|
|
if (bUnsigned) {
|
|
|
|
|
newType = ui32_ty;
|
|
|
|
|
newOp = rewriter.create<LLVM::ZExtOp>(UnknownLoc::get(context), newType,
|
|
|
|
|
value);
|
|
|
|
|
} else {
|
|
|
|
|
newType = i32_ty;
|
|
|
|
|
newOp = rewriter.create<LLVM::SExtOp>(UnknownLoc::get(context), newType,
|
|
|
|
|
value);
|
|
|
|
|
}
|
|
|
|
|
} else if (type.isBF16() || type.isF16() || type.isF32()) {
|
|
|
|
|
newType = f64_ty;
|
|
|
|
|
newOp = rewriter.create<LLVM::FPExtOp>(UnknownLoc::get(context), newType,
|
|
|
|
|
value);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return {newType, newOp};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void llPrintf(StringRef msg, ValueRange args,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
static const char formatStringPrefix[] = "printfFormat_";
|
|
|
|
|
assert(!msg.empty() && "printf with empty string not support");
|
|
|
|
|
Type int8Ptr = ptr_ty(i8_ty);
|
|
|
|
|
|
|
|
|
|
auto *context = rewriter.getContext();
|
|
|
|
|
auto moduleOp =
|
|
|
|
|
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
|
|
|
|
|
auto funcOp = getVprintfDeclaration(rewriter);
|
|
|
|
|
|
|
|
|
|
Value one = rewriter.create<LLVM::ConstantOp>(
|
|
|
|
|
UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(1));
|
|
|
|
|
Value zero = rewriter.create<LLVM::ConstantOp>(
|
|
|
|
|
UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(0));
|
|
|
|
|
|
|
|
|
|
unsigned stringNumber = 0;
|
|
|
|
|
SmallString<16> stringConstName;
|
|
|
|
|
do {
|
|
|
|
|
stringConstName.clear();
|
|
|
|
|
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
|
|
|
|
|
} while (moduleOp.lookupSymbol(stringConstName));
|
|
|
|
|
|
|
|
|
|
llvm::SmallString<64> formatString(msg);
|
|
|
|
|
formatString.push_back('\n');
|
|
|
|
|
formatString.push_back('\0');
|
|
|
|
|
size_t formatStringSize = formatString.size_in_bytes();
|
|
|
|
|
auto globalType = LLVM::LLVMArrayType::get(i8_ty, formatStringSize);
|
|
|
|
|
|
|
|
|
|
LLVM::GlobalOp global;
|
|
|
|
|
{
|
|
|
|
|
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
|
|
|
|
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
|
|
|
|
global = rewriter.create<LLVM::GlobalOp>(
|
|
|
|
|
UnknownLoc::get(context), globalType,
|
|
|
|
|
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
|
|
|
|
|
rewriter.getStringAttr(formatString));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value globalPtr =
|
|
|
|
|
rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(context), global);
|
|
|
|
|
Value stringStart =
|
|
|
|
|
rewriter.create<LLVM::GEPOp>(UnknownLoc::get(context), int8Ptr,
|
|
|
|
|
globalPtr, mlir::ValueRange({zero, zero}));
|
|
|
|
|
|
|
|
|
|
Value bufferPtr =
|
|
|
|
|
rewriter.create<LLVM::NullOp>(UnknownLoc::get(context), int8Ptr);
|
|
|
|
|
|
|
|
|
|
SmallVector<Value, 16> newArgs;
|
|
|
|
|
if (args.size() >= 1) {
|
|
|
|
|
SmallVector<Type> argTypes;
|
|
|
|
|
for (auto arg : args) {
|
|
|
|
|
Type newType;
|
|
|
|
|
Value newArg;
|
|
|
|
|
std::tie(newType, newArg) = promoteValue(rewriter, arg);
|
|
|
|
|
argTypes.push_back(newType);
|
|
|
|
|
newArgs.push_back(newArg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Type structTy = LLVM::LLVMStructType::getLiteral(context, argTypes);
|
|
|
|
|
auto allocated = rewriter.create<LLVM::AllocaOp>(UnknownLoc::get(context),
|
|
|
|
|
ptr_ty(structTy), one,
|
|
|
|
|
/*alignment=*/0);
|
|
|
|
|
|
|
|
|
|
for (const auto &entry : llvm::enumerate(newArgs)) {
|
|
|
|
|
auto index = rewriter.create<LLVM::ConstantOp>(
|
|
|
|
|
UnknownLoc::get(context), i32_ty,
|
|
|
|
|
rewriter.getI32IntegerAttr(entry.index()));
|
|
|
|
|
auto fieldPtr = rewriter.create<LLVM::GEPOp>(
|
|
|
|
|
UnknownLoc::get(context), ptr_ty(argTypes[entry.index()]),
|
|
|
|
|
allocated, ArrayRef<Value>{zero, index});
|
|
|
|
|
rewriter.create<LLVM::StoreOp>(UnknownLoc::get(context), entry.value(),
|
|
|
|
|
fieldPtr);
|
|
|
|
|
}
|
|
|
|
|
bufferPtr = rewriter.create<LLVM::BitcastOp>(UnknownLoc::get(context),
|
|
|
|
|
int8Ptr, allocated);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ValueRange operands{stringStart, bufferPtr};
|
|
|
|
|
rewriter.create<LLVM::CallOp>(UnknownLoc::get(context), funcOp, operands);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
|
|
|
|
RewritePatternSet &patterns, int numWarps,
|
|
|
|
|
AxisInfoAnalysis &axisInfoAnalysis,
|
|
|
|
@@ -5062,6 +5092,15 @@ void ConvertTritonGPUToLLVM::initSharedMemory(
|
|
|
|
|
|
|
|
|
|
namespace mlir {
|
|
|
|
|
|
|
|
|
|
namespace LLVM {
|
|
|
|
|
|
|
|
|
|
void llPrintf(StringRef msg, ValueRange args,
|
|
|
|
|
ConversionPatternRewriter &rewriter) {
|
|
|
|
|
PrintfOpConversion::llPrintf(msg, args, rewriter);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace LLVM
|
|
|
|
|
|
|
|
|
|
TritonLLVMConversionTarget::TritonLLVMConversionTarget(
|
|
|
|
|
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
|
|
|
|
|
: ConversionTarget(ctx) {
|
|
|
|
|