[Triton-MLIR][BACKEND] Add argmin / argmax implementation for ReduceOp (#918)

Qingyi Liu
2022-11-28 14:59:27 +08:00
committed by GitHub
parent 04ec5deb41
commit 9d31998a9d
12 changed files with 341 additions and 75 deletions


@@ -88,25 +88,6 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
return paddedRepShape;
}
SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
ReduceOpHelper helper(op);
SmallVector<unsigned> smemShape;
auto srcShape = helper.getSrcShape();
for (auto d : srcShape)
smemShape.push_back(d);
auto axis = op.axis();
if (helper.isFastReduction()) {
smemShape[axis] = helper.getInterWarpSize();
} else {
smemShape[axis] =
std::min(smemShape[axis], helper.getThreadsReductionAxis());
}
return smemShape;
}
// TODO: extend beyond scalars
SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
SmallVector<unsigned> smemShape;
@@ -173,21 +154,9 @@ private:
/// Initializes temporary shared memory for a given operation.
void getScratchValueSize(Operation *op) {
if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
// TODO(Keren): Reduce with index is not supported yet.
auto value = op->getOperand(0);
if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
bool fastReduce = ReduceOpHelper(reduceOp).isFastReduction();
auto smemShape = getScratchConfigForReduce(reduceOp);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
if (fastReduce) {
auto mod = op->getParentOfType<ModuleOp>();
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
elems = std::max<unsigned>(elems, numWarps * 32);
}
auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
}
ReduceOpHelper helper(reduceOp);
unsigned bytes = helper.getScratchSizeInBytes();
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
auto srcTy = cvtLayout.src().getType().cast<RankedTensorType>();
auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();


@@ -37,6 +37,55 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
triton::gpu::getWarpsPerCTA(srcLayout)[axis];
}
SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
auto axis = op.axis();
auto smemShape = convertType<unsigned>(getSrcShape());
smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis());
return smemShape;
}
SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
auto axis = op.axis();
SmallVector<SmallVector<unsigned>> smemShapes(3);
/// shared memory block0
smemShapes[0] = convertType<unsigned>(getSrcShape());
smemShapes[0][axis] = getInterWarpSize();
/// FIXME(Qingyi): This size is actually larger than required.
/// shared memory block1:
auto mod = op.getOperation()->getParentOfType<ModuleOp>();
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
smemShapes[1].push_back(numWarps * 32);
/// FIXME(Qingyi): This requirement is actually not necessary, because it is
/// always smaller than smemShapes[0]
/// shared memory block2
smemShapes[2] = convertType<unsigned>(getSrcShape());
smemShapes[2].erase(smemShapes[2].begin() + axis);
return smemShapes;
}
unsigned ReduceOpHelper::getScratchSizeInBytes() {
unsigned elems = 0;
if (isFastReduction()) {
auto smemShapes = getScratchConfigsFast();
for (const auto &smemShape : smemShapes)
elems = std::max(elems, product<unsigned>(smemShape));
} else {
auto smemShape = getScratchConfigBasic();
elems = product<unsigned>(smemShape);
}
auto tensorType = op.operand().getType().cast<RankedTensorType>();
unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8;
if (triton::ReduceOp::withIndex(op.redOp()))
bytes += elems * sizeof(int32_t);
return bytes;
}
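For intuition, a minimal standalone sketch of the size computation above, under purely hypothetical numbers (fast-reduction path, source shape {8, 128} reduced along axis 1, inter-warp size 4, 4 warps, f32 elements, and an arg-reduction so the extra i32 index buffer is added):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical fast-reduction scratch blocks (values for illustration only):
  unsigned block0 = 8 * 4;   // src shape with the axis dim replaced by the inter-warp size
  unsigned block1 = 4 * 32;  // numWarps * 32
  unsigned block2 = 8;       // src shape with the axis dim erased
  unsigned elems = std::max({block0, block1, block2});
  unsigned elemBitWidth = 32;  // assumed f32 source element type
  bool withIndex = true;       // argmin / argmax also needs an i32 index per element
  unsigned bytes = elems * elemBitWidth / 8;
  if (withIndex)
    bytes += elems * sizeof(int32_t);
  std::printf("elems = %u, scratch bytes = %u\n", elems, bytes);  // 128, 1024
  return 0;
}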
bool isSharedEncoding(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {


@@ -1338,6 +1338,10 @@ private:
void accumulate(ConversionPatternRewriter &rewriter, Location loc,
RedOp redOp, Value &acc, Value cur, bool isFirst) const;
void accumulateWithIndex(ConversionPatternRewriter &rewriter, Location loc,
RedOp redOp, Value &acc, Value &accIndex, Value cur,
Value curIndex, bool isFirst) const;
Value shflSync(ConversionPatternRewriter &rewriter, Location loc, Value val,
int i) const;
@@ -1366,7 +1370,6 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
acc = cur;
return;
}
auto type = cur.getType();
switch (redOp) {
case RedOp::ADD:
acc = add(acc, cur);
@@ -1395,6 +1398,75 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
case RedOp::XOR:
acc = xor_(acc, cur);
break;
case RedOp::ARGMIN:
case RedOp::ARGMAX:
case RedOp::ARGUMIN:
case RedOp::ARGUMAX:
case RedOp::ARGFMIN:
case RedOp::ARGFMAX:
llvm::report_fatal_error(
"This accumulate implementation is not for argmin / argmax");
default:
llvm::report_fatal_error("Unsupported reduce op");
}
}
void ReduceOpConversion::accumulateWithIndex(
ConversionPatternRewriter &rewriter, Location loc, RedOp redOp, Value &acc,
Value &accIndex, Value cur, Value curIndex, bool isFirst) const {
if (isFirst) {
acc = cur;
accIndex = curIndex;
return;
}
switch (redOp) {
case RedOp::ARGMIN:
accIndex =
select(icmp_slt(acc, cur), accIndex,
select(icmp_sgt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = smin(acc, cur);
break;
case RedOp::ARGMAX:
accIndex =
select(icmp_sgt(acc, cur), accIndex,
select(icmp_slt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = smax(acc, cur);
break;
case RedOp::ARGUMIN:
accIndex =
select(icmp_ult(acc, cur), accIndex,
select(icmp_ugt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = umin(acc, cur);
break;
case RedOp::ARGUMAX:
accIndex =
select(icmp_ugt(acc, cur), accIndex,
select(icmp_ult(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = umax(acc, cur);
break;
case RedOp::ARGFMIN:
accIndex =
select(fcmp_olt(acc, cur), accIndex,
select(fcmp_ogt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = fmin(acc, cur);
break;
case RedOp::ARGFMAX:
accIndex =
select(fcmp_ogt(acc, cur), accIndex,
select(fcmp_olt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = fmax(acc, cur);
break;
case RedOp::ADD:
case RedOp::FADD:
case RedOp::MIN:
case RedOp::MAX:
case RedOp::UMIN:
case RedOp::UMAX:
case RedOp::FMIN:
case RedOp::FMAX:
case RedOp::XOR:
llvm::report_fatal_error(
"This accumulate implementation is only for argmin / argmax");
default:
llvm::report_fatal_error("Unsupported reduce op");
}
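The select chains above encode a three-way decision per element: keep the accumulator's index when its value strictly wins, take the incoming index when the incoming value strictly wins, and take the smaller of the two indices on a tie. A scalar C++ analogue of the ARGMIN case, for illustration only (the real code builds LLVM cmp/select ops instead):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Scalar analogue of the ARGMIN branch: on equal values the smaller index wins.
static void argminStep(int32_t &acc, int32_t &accIndex, int32_t cur,
                       int32_t curIndex) {
  if (acc < cur) {
    // keep acc and accIndex
  } else if (acc > cur) {
    accIndex = curIndex;
  } else {
    accIndex = std::min(accIndex, curIndex);
  }
  acc = std::min(acc, cur);
}

int main() {
  int32_t acc = 7, accIndex = 3;
  argminStep(acc, accIndex, 7, 1);  // tie on the value: index becomes 1
  argminStep(acc, accIndex, 2, 5);  // strictly smaller value: (2, 5) wins
  std::printf("acc = %d, index = %d\n", acc, accIndex);  // acc = 2, index = 5
  return 0;
}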
@@ -1433,6 +1505,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
unsigned axis = op.axis();
bool withIndex = triton::ReduceOp::withIndex(op.redOp());
auto srcTy = op.operand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
@@ -1440,11 +1513,17 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
auto srcShape = srcTy.getShape();
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto llvmIndexTy = getTypeConverter()->getIndexType();
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
smemBase = bitcast(smemBase, elemPtrTy);
auto smemShape = getScratchConfigForReduce(op);
ReduceOpHelper helper(op);
auto smemShape = helper.getScratchConfigBasic();
unsigned elems = product<unsigned>(smemShape);
Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(elems));
indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
unsigned srcElems = getElemsPerThread(srcTy);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
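The indexSmemBase setup above places the i32 index buffer immediately after the value buffer inside the same scratch allocation. A CPU-side analogue of that pointer arithmetic, with hypothetical sizes:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical layout: 64 f32 accumulator slots followed by 64 i32 index slots.
  constexpr unsigned elems = 64;
  alignas(8) static unsigned char smem[elems * sizeof(float) + elems * sizeof(int32_t)];
  float *smemBase = reinterpret_cast<float *>(smem);
  // Mirrors gep(elemPtrTy, smemBase, i32_val(elems)) followed by the bitcast:
  int32_t *indexSmemBase = reinterpret_cast<int32_t *>(smemBase + elems);
  std::printf("index buffer offset = %u bytes\n",
              unsigned(reinterpret_cast<unsigned char *>(indexSmemBase) - smem));  // 256
  return 0;
}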
@@ -1454,6 +1533,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
emitOffsetForBlockedLayout(srcLayout, srcShape);
std::map<SmallVector<unsigned>, Value> accs;
std::map<SmallVector<unsigned>, Value> accIndices;
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
// reduce within threads
@@ -1461,7 +1541,13 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
SmallVector<unsigned> key = offset[i];
key[axis] = 0;
bool isFirst = accs.find(key) == accs.end();
accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
if (!withIndex) {
accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
} else {
Value curIndex = srcIndices[i][axis];
accumulateWithIndex(rewriter, loc, op.redOp(), accs[key], accIndices[key],
srcValues[i], curIndex, isFirst);
}
if (isFirst)
indices[key] = srcIndices[i];
}
@@ -1477,12 +1563,18 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
Value acc = it.second;
Value accIndex;
if (withIndex)
accIndex = accIndices[key];
SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = udiv(writeIdx[axis], sizePerThread);
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd);
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
store(acc, writePtr);
if (withIndex)
store(accIndex, indexWritePtr);
SmallVector<Value> readIdx(writeIdx.size(), ints[0]);
for (int N = smemShape[axis] / 2; N > 0; N >>= 1) {
@@ -1493,11 +1585,24 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
ints[0]);
Value readPtr = gep(elemPtrTy, writePtr, readOffset);
barrier();
accumulate(rewriter, loc, op.redOp(), acc, load(readPtr), false);
store(acc, writePtr);
if (!withIndex) {
Value cur = load(readPtr);
accumulate(rewriter, loc, op.redOp(), acc, cur, false);
store(acc, writePtr);
} else {
Value cur = load(readPtr);
Value indexReadPtr = gep(indexPtrTy, indexWritePtr, readOffset);
Value curIndex = load(indexReadPtr);
accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, cur,
curIndex, false);
store(acc, writePtr);
store(accIndex, indexWritePtr);
}
}
}
barrier();
// set output values
if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
// nd-tensor where n >= 1
@@ -1508,25 +1613,25 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
assert(resultIndices.size() == resultElems);
barrier();
SmallVector<Value> resultVals(resultElems);
for (unsigned i = 0; i < resultElems; ++i) {
SmallVector<Value> readIdx = resultIndices[i];
readIdx.insert(readIdx.begin() + axis, ints[0]);
Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd);
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
resultVals[i] = load(readPtr);
Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
}
SmallVector<Type> resultTypes(resultElems, llvmElemTy);
SmallVector<Type> resultTypes(resultElems,
withIndex ? llvmIndexTy : llvmElemTy);
Type structTy =
LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, ret);
} else {
// 0d-tensor -> scalar
barrier();
Value resultVal = load(smemBase);
Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase);
rewriter.replaceOp(op, resultVal);
}
@@ -1538,25 +1643,35 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
unsigned axis = adaptor.axis();
bool withIndex = triton::ReduceOp::withIndex(op.redOp());
auto srcTy = op.operand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto srcRank = srcTy.getRank();
auto order = getOrder(srcLayout);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout);
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto llvmIndexTy = getTypeConverter()->getIndexType();
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
smemBase = bitcast(smemBase, elemPtrTy);
ReduceOpHelper helper(op);
auto smemShapes = helper.getScratchConfigsFast();
unsigned elems = product<unsigned>(smemShapes[0]);
unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
maxElems = std::max(maxElems, product<unsigned>(smemShapes[2]));
Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems));
indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
unsigned sizeIntraWarps = helper.getIntraWarpSize();
unsigned sizeInterWarps = helper.getInterWarpSize();
auto order = getOrder(srcLayout);
unsigned srcElems = getElemsPerThread(srcTy);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
@@ -1565,16 +1680,21 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
emitOffsetForLayout(srcLayout, srcShape);
std::map<SmallVector<unsigned>, Value> accs;
std::map<SmallVector<unsigned>, Value> accIndices;
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
auto smemShape = getScratchConfigForReduce(op);
// reduce within threads
for (unsigned i = 0; i < srcElems; ++i) {
SmallVector<unsigned> key = offset[i];
key[axis] = 0;
bool isFirst = accs.find(key) == accs.end();
accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
if (!withIndex) {
accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
} else {
Value curIndex = srcIndices[i][axis];
accumulateWithIndex(rewriter, loc, op.redOp(), accs[key], accIndices[key],
srcValues[i], curIndex, isFirst);
}
if (isFirst)
indices[key] = srcIndices[i];
}
@@ -1599,18 +1719,32 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
Value acc = it.second;
Value accIndex;
if (withIndex)
accIndex = accIndices[key];
// reduce within warps
for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
Value shfl = shflSync(rewriter, loc, acc, N);
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
if (!withIndex) {
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
} else {
Value shflIndex = shflSync(rewriter, loc, accIndex, N);
accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
shflIndex, false);
}
}
SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, order);
Value writeOffset =
linearize(rewriter, loc, writeIdx, smemShapes[0], order);
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
storeShared(rewriter, loc, writePtr, acc, laneZero);
if (withIndex) {
Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
storeShared(rewriter, loc, indexWritePtr, accIndex, laneZero);
}
}
barrier();
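A CPU-side sketch of the intra-warp step above: shflSync exchanges the (value, index) pair with another lane at halving offsets, and accumulateWithIndex combines the two. The simulation below assumes a butterfly (XOR-partner) exchange over a hypothetical 8-lane warp running ARGMAX; it is only meant to show why every lane converges to the same maximum with the smallest tying index.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical 8-lane "warp" of (value, index) pairs; lane i starts with index i.
  std::vector<int32_t> val = {3, 9, 1, 9, 5, 2, 7, 0};
  std::vector<int32_t> idx = {0, 1, 2, 3, 4, 5, 6, 7};
  const unsigned size = 8;

  for (unsigned N = size / 2; N > 0; N >>= 1) {
    // Snapshot first: a real shuffle reads the partner's pre-step registers.
    auto prevVal = val;
    auto prevIdx = idx;
    for (unsigned lane = 0; lane < size; ++lane) {
      int32_t otherVal = prevVal[lane ^ N];  // assumed XOR-partner exchange
      int32_t otherIdx = prevIdx[lane ^ N];
      if (prevVal[lane] < otherVal) {
        val[lane] = otherVal;
        idx[lane] = otherIdx;
      } else if (prevVal[lane] == otherVal) {
        idx[lane] = std::min(prevIdx[lane], otherIdx);  // smaller index wins ties
      }
      // otherwise the lane keeps its own (value, index)
    }
  }
  std::printf("every lane now holds value %d, index %d\n", val[0], idx[0]);  // 9, 1
  return 0;
}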
@@ -1622,7 +1756,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
//
// each thread needs to process:
// elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
unsigned elems = product<unsigned>(smemShape);
unsigned numThreads =
product<unsigned>(triton::gpu::getWarpsPerCTA(srcLayout)) * 32;
unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
@@ -1630,10 +1763,21 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
for (unsigned round = 0; round < elemsPerThread; ++round) {
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
Value acc = load(readPtr);
Value accIndex;
if (withIndex) {
Value readIndexPtr = gep(indexPtrTy, indexSmemBase, readOffset);
accIndex = load(readIndexPtr);
}
for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
Value shfl = shflSync(rewriter, loc, acc, N);
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
if (!withIndex) {
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
} else {
Value shflIndex = shflSync(rewriter, loc, accIndex, N);
accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
shflIndex, false);
}
}
Value writeOffset = udiv(readOffset, i32_val(sizeInterWarps));
@@ -1642,8 +1786,12 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
Value laneIdModSizeInterWarpsIsZero =
icmp_eq(laneIdModSizeInterWarps, zero);
storeShared(rewriter, loc, writePtr, acc,
and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero));
Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero);
storeShared(rewriter, loc, writePtr, acc, pred);
if (withIndex) {
Value writeIndexPtr = gep(indexPtrTy, indexSmemBase, writeOffset);
storeShared(rewriter, loc, writeIndexPtr, accIndex, pred);
}
if (round != elemsPerThread - 1) {
readOffset = add(readOffset, i32_val(numThreads));
@@ -1671,25 +1819,24 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
assert(resultIndices.size() == resultElems);
SmallVector<Value> resultVals(resultElems);
SmallVector<unsigned> resultShape;
std::copy(resultTy.getShape().begin(), resultTy.getShape().end(),
std::back_inserter(resultShape));
for (size_t i = 0; i < resultElems; ++i) {
SmallVector<Value> readIdx = resultIndices[i];
Value readOffset =
linearize(rewriter, loc, readIdx, resultShape, resultOrd);
linearize(rewriter, loc, readIdx, smemShapes[2], resultOrd);
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
resultVals[i] = load(readPtr);
Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
}
SmallVector<Type> resultTypes(resultElems, llvmElemTy);
SmallVector<Type> resultTypes(resultElems,
withIndex ? llvmIndexTy : llvmElemTy);
Type structTy =
LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, ret);
} else {
// 0d-tensor -> scalar
Value resultVal = load(smemBase);
Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase);
rewriter.replaceOp(op, resultVal);
}


@@ -60,12 +60,32 @@
rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
#define fcmp_ogt(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::ogt, lhs, rhs)
#define fcmp_olt(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::olt, lhs, rhs)
#define icmp_eq(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
#define icmp_ne(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__)
#define icmp_slt(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__)
#define icmp_sle(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sle, __VA_ARGS__)
#define icmp_sgt(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, __VA_ARGS__)
#define icmp_sge(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sge, __VA_ARGS__)
#define icmp_ult(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ult, __VA_ARGS__)
#define icmp_ule(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ule, __VA_ARGS__)
#define icmp_ugt(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ugt, __VA_ARGS__)
#define icmp_uge(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::uge, __VA_ARGS__)
#define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)


@@ -240,12 +240,16 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
Value arg = operands[0];
auto argTy = arg.getType().cast<RankedTensorType>();
auto argEltTy = argTy.getElementType();
auto i32Ty = IntegerType::get(argEltTy.getContext(), 32);
auto redOp = attributes.get("redOp").cast<mlir::triton::RedOpAttr>().getValue();
bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
auto retEltTy = withIndex ? i32Ty : argEltTy;
auto retShape = argTy.getShape().vec();
int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
retShape.erase(retShape.begin() + axis);
if (retShape.empty()) {
// 0d-tensor -> scalar
inferredReturnTypes.push_back(argEltTy);
inferredReturnTypes.push_back(retEltTy);
} else {
// nd-tensor where n >= 1
// infer encoding
@@ -264,11 +268,20 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
}
// create type
inferredReturnTypes.push_back(
RankedTensorType::get(retShape, argEltTy, retEncoding));
RankedTensorType::get(retShape, retEltTy, retEncoding));
}
return mlir::success();
}
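As a concrete illustration of the rule above: an ARGMIN over an f32 tensor of shape 8x128 along axis 1 infers an i32 result of shape 8, while an FMIN over the same operand keeps the f32 element type; a reduction that consumes the last remaining dimension still takes the scalar branch, now producing an i32 scalar for the arg-reductions.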
bool mlir::triton::ReduceOp::withIndex(mlir::triton::RedOp redOp) {
return redOp == mlir::triton::RedOp::ARGMIN ||
redOp == mlir::triton::RedOp::ARGMAX ||
redOp == mlir::triton::RedOp::ARGUMIN ||
redOp == mlir::triton::RedOp::ARGUMAX ||
redOp == mlir::triton::RedOp::ARGFMIN ||
redOp == mlir::triton::RedOp::ARGFMAX;
}
//-- SplatOp --
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = src().getDefiningOp<arith::ConstantOp>();