[Triton-MLIR][BACKEND] Add argmin / argmax implementation for ReduceOp (#918)
@@ -1338,6 +1338,10 @@ private:
  void accumulate(ConversionPatternRewriter &rewriter, Location loc,
                  RedOp redOp, Value &acc, Value cur, bool isFirst) const;

  void accumulateWithIndex(ConversionPatternRewriter &rewriter, Location loc,
                           RedOp redOp, Value &acc, Value &accIndex, Value cur,
                           Value curIndex, bool isFirst) const;

  Value shflSync(ConversionPatternRewriter &rewriter, Location loc, Value val,
                 int i) const;

@@ -1366,7 +1370,6 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
    acc = cur;
    return;
  }
  auto type = cur.getType();
  switch (redOp) {
  case RedOp::ADD:
    acc = add(acc, cur);
@@ -1395,6 +1398,75 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
  case RedOp::XOR:
    acc = xor_(acc, cur);
    break;
  case RedOp::ARGMIN:
  case RedOp::ARGMAX:
  case RedOp::ARGUMIN:
  case RedOp::ARGUMAX:
  case RedOp::ARGFMIN:
  case RedOp::ARGFMAX:
    llvm::report_fatal_error(
        "This accumulate implementation is not for argmin / argmax");
  default:
    llvm::report_fatal_error("Unsupported reduce op");
  }
}

void ReduceOpConversion::accumulateWithIndex(
    ConversionPatternRewriter &rewriter, Location loc, RedOp redOp, Value &acc,
    Value &accIndex, Value cur, Value curIndex, bool isFirst) const {
  if (isFirst) {
    acc = cur;
    accIndex = curIndex;
    return;
  }
  switch (redOp) {
  case RedOp::ARGMIN:
    accIndex =
        select(icmp_slt(acc, cur), accIndex,
               select(icmp_sgt(acc, cur), curIndex, smin(accIndex, curIndex)));
    acc = smin(acc, cur);
    break;
  case RedOp::ARGMAX:
    accIndex =
        select(icmp_sgt(acc, cur), accIndex,
               select(icmp_slt(acc, cur), curIndex, smin(accIndex, curIndex)));
    acc = smax(acc, cur);
    break;
  case RedOp::ARGUMIN:
    accIndex =
        select(icmp_ult(acc, cur), accIndex,
               select(icmp_ugt(acc, cur), curIndex, smin(accIndex, curIndex)));
    acc = umin(acc, cur);
    break;
  case RedOp::ARGUMAX:
    accIndex =
        select(icmp_ugt(acc, cur), accIndex,
               select(icmp_ult(acc, cur), curIndex, smin(accIndex, curIndex)));
    acc = umax(acc, cur);
    break;
  case RedOp::ARGFMIN:
    accIndex =
        select(fcmp_olt(acc, cur), accIndex,
               select(fcmp_ogt(acc, cur), curIndex, smin(accIndex, curIndex)));
    acc = fmin(acc, cur);
    break;
  case RedOp::ARGFMAX:
    accIndex =
        select(fcmp_ogt(acc, cur), accIndex,
               select(fcmp_olt(acc, cur), curIndex, smin(accIndex, curIndex)));
    acc = fmax(acc, cur);
    break;
  case RedOp::ADD:
  case RedOp::FADD:
  case RedOp::MIN:
  case RedOp::MAX:
  case RedOp::UMIN:
  case RedOp::UMAX:
  case RedOp::FMIN:
  case RedOp::FMAX:
  case RedOp::XOR:
    llvm::report_fatal_error(
        "This accumulate implementation is only for argmin / argmax");
  default:
    llvm::report_fatal_error("Unsupported reduce op");
  }
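The select nests in each ARG* case above encode one rule; as a reading aid, here is a minimal scalar sketch of that rule for ARGMIN (illustration only, not part of the commit): keep the running index while the running value strictly wins, take the candidate's index when the candidate strictly wins, and break ties toward the smaller index.

// ---- illustration only, not part of this commit ----
#include <algorithm>
#include <cstdint>

// Scalar model of one accumulateWithIndex step for RedOp::ARGMIN on signed
// integers; acc/accIndex are the running winner, cur/curIndex the candidate.
static void argminStep(int64_t &acc, int64_t &accIndex, int64_t cur,
                       int64_t curIndex) {
  if (acc < cur) {
    // acc still wins: keep accIndex (outer select's true branch).
  } else if (acc > cur) {
    accIndex = curIndex; // cur wins: inner select's true branch.
  } else {
    accIndex = std::min(accIndex, curIndex); // tie: smin(accIndex, curIndex)
  }
  acc = std::min(acc, cur); // acc = smin(acc, cur)
}
// ---- end illustration ----

The ARGMAX / ARGUMIN / ARGUMAX / ARGFMIN / ARGFMAX cases have the same shape with the comparison predicate and the value min/max swapped; only the index tie-break (smin) stays the same.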
@@ -1433,6 +1505,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
    ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();
  unsigned axis = op.axis();
  bool withIndex = triton::ReduceOp::withIndex(op.redOp());

  auto srcTy = op.operand().getType().cast<RankedTensorType>();
  auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
@@ -1440,11 +1513,17 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
  auto srcShape = srcTy.getShape();

  auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
  auto llvmIndexTy = getTypeConverter()->getIndexType();
  auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
  auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
  Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
  smemBase = bitcast(smemBase, elemPtrTy);

  auto smemShape = getScratchConfigForReduce(op);
  ReduceOpHelper helper(op);
  auto smemShape = helper.getScratchConfigBasic();
  unsigned elems = product<unsigned>(smemShape);
  Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(elems));
  indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
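A note on the two lines above: the index scratch region is carved out of the same shared-memory block as the element scratch, starting right after it; gep(elemPtrTy, smemBase, i32_val(elems)) advances by elems elements of the element type before the pointer is re-cast to the index pointer type. A rough address-arithmetic sketch in plain C++ (hypothetical names, illustration only):

// ---- illustration only, not part of this commit ----
#include <cstddef>

// Reduce scratch modeled as raw bytes:
// [ elems * sizeof(elemTy) bytes of partial values | index scratch ... ]
static char *indexScratchBase(char *smemBase, std::size_t elems,
                              std::size_t elemSize) {
  // Same arithmetic as gep(elemPtrTy, smemBase, i32_val(elems)) followed by
  // the bitcast to indexPtrTy.
  return smemBase + elems * elemSize;
}
// ---- end illustration ----

The fast path later in this diff does the same thing but offsets by maxElems, the largest of its three scratch configurations, so the index region sits past all of them.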

  unsigned srcElems = getElemsPerThread(srcTy);
  auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
@@ -1454,6 +1533,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
      emitOffsetForBlockedLayout(srcLayout, srcShape);

  std::map<SmallVector<unsigned>, Value> accs;
  std::map<SmallVector<unsigned>, Value> accIndices;
  std::map<SmallVector<unsigned>, SmallVector<Value>> indices;

  // reduce within threads
@@ -1461,7 +1541,13 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
    SmallVector<unsigned> key = offset[i];
    key[axis] = 0;
    bool isFirst = accs.find(key) == accs.end();
    accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
    if (!withIndex) {
      accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
    } else {
      Value curIndex = srcIndices[i][axis];
      accumulateWithIndex(rewriter, loc, op.redOp(), accs[key], accIndices[key],
                          srcValues[i], curIndex, isFirst);
    }
    if (isFirst)
      indices[key] = srcIndices[i];
  }
@@ -1477,12 +1563,18 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
  for (auto it : accs) {
    const SmallVector<unsigned> &key = it.first;
    Value acc = it.second;
    Value accIndex;
    if (withIndex)
      accIndex = accIndices[key];
    SmallVector<Value> writeIdx = indices[key];

    writeIdx[axis] = udiv(writeIdx[axis], sizePerThread);
    Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd);
    Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
    Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
    store(acc, writePtr);
    if (withIndex)
      store(accIndex, indexWritePtr);

    SmallVector<Value> readIdx(writeIdx.size(), ints[0]);
    for (int N = smemShape[axis] / 2; N > 0; N >>= 1) {
@@ -1493,11 +1585,24 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
                          ints[0]);
      Value readPtr = gep(elemPtrTy, writePtr, readOffset);
      barrier();
      accumulate(rewriter, loc, op.redOp(), acc, load(readPtr), false);
      store(acc, writePtr);
      if (!withIndex) {
        Value cur = load(readPtr);
        accumulate(rewriter, loc, op.redOp(), acc, cur, false);
        store(acc, writePtr);
      } else {
        Value cur = load(readPtr);
        Value indexReadPtr = gep(indexPtrTy, indexWritePtr, readOffset);
        Value curIndex = load(indexReadPtr);
        accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, cur,
                            curIndex, false);
        store(acc, writePtr);
        store(accIndex, indexWritePtr);
      }
    }
  }
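The loop above is the usual halving tree over the reduced axis in shared memory, now loading and storing the index buffer alongside the value buffer whenever withIndex is set. A serial C++ model of that pattern for an argmin (illustration only; assumes a power-of-two axis length, which is what the N >>= 1 schedule expects):

// ---- illustration only, not part of this commit ----
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Serial model of the shared-memory tree reduction with index tracking.
static std::pair<int64_t, int64_t> treeArgmin(std::vector<int64_t> vals,
                                              std::vector<int64_t> idxs) {
  for (std::size_t n = vals.size() / 2; n > 0; n >>= 1) {
    for (std::size_t i = 0; i < n; ++i) {
      // Same decision as accumulateWithIndex(ARGMIN): a strict winner keeps
      // or takes the index, a tie keeps the smaller index.
      if (vals[i + n] < vals[i] ||
          (vals[i + n] == vals[i] && idxs[i + n] < idxs[i]))
        idxs[i] = idxs[i + n];
      vals[i] = std::min(vals[i], vals[i + n]);
    }
  }
  return {vals[0], idxs[0]}; // slot 0 holds the reduced (value, index) pair
}
// ---- end illustration ----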

  barrier();

  // set output values
  if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
    // nd-tensor where n >= 1
@@ -1508,25 +1613,25 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
    auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
    assert(resultIndices.size() == resultElems);

    barrier();
    SmallVector<Value> resultVals(resultElems);
    for (unsigned i = 0; i < resultElems; ++i) {
      SmallVector<Value> readIdx = resultIndices[i];
      readIdx.insert(readIdx.begin() + axis, ints[0]);
      Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd);
      Value readPtr = gep(elemPtrTy, smemBase, readOffset);
      resultVals[i] = load(readPtr);
      Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
      resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
    }

    SmallVector<Type> resultTypes(resultElems, llvmElemTy);
    SmallVector<Type> resultTypes(resultElems,
                                  withIndex ? llvmIndexTy : llvmElemTy);
    Type structTy =
        LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
    Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
    rewriter.replaceOp(op, ret);
  } else {
    // 0d-tensor -> scalar
    barrier();
    Value resultVal = load(smemBase);
    Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase);
    rewriter.replaceOp(op, resultVal);
  }

@@ -1538,25 +1643,35 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
    ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();
  unsigned axis = adaptor.axis();
  bool withIndex = triton::ReduceOp::withIndex(op.redOp());

  auto srcTy = op.operand().getType().cast<RankedTensorType>();
  auto srcLayout = srcTy.getEncoding();
  auto srcShape = srcTy.getShape();
  auto srcRank = srcTy.getRank();
  auto order = getOrder(srcLayout);

  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout);

  auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
  auto llvmIndexTy = getTypeConverter()->getIndexType();
  auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
  auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
  Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
  smemBase = bitcast(smemBase, elemPtrTy);

  ReduceOpHelper helper(op);
  auto smemShapes = helper.getScratchConfigsFast();
  unsigned elems = product<unsigned>(smemShapes[0]);
  unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
  maxElems = std::max(maxElems, product<unsigned>(smemShapes[2]));
  Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems));
  indexSmemBase = bitcast(indexSmemBase, indexPtrTy);

  unsigned sizeIntraWarps = helper.getIntraWarpSize();
  unsigned sizeInterWarps = helper.getInterWarpSize();

  auto order = getOrder(srcLayout);
  unsigned srcElems = getElemsPerThread(srcTy);
  auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
  auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
@@ -1565,16 +1680,21 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
      emitOffsetForLayout(srcLayout, srcShape);

  std::map<SmallVector<unsigned>, Value> accs;
  std::map<SmallVector<unsigned>, Value> accIndices;
  std::map<SmallVector<unsigned>, SmallVector<Value>> indices;

  auto smemShape = getScratchConfigForReduce(op);

  // reduce within threads
  for (unsigned i = 0; i < srcElems; ++i) {
    SmallVector<unsigned> key = offset[i];
    key[axis] = 0;
    bool isFirst = accs.find(key) == accs.end();
    accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
    if (!withIndex) {
      accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
    } else {
      Value curIndex = srcIndices[i][axis];
      accumulateWithIndex(rewriter, loc, op.redOp(), accs[key], accIndices[key],
                          srcValues[i], curIndex, isFirst);
    }
    if (isFirst)
      indices[key] = srcIndices[i];
  }
@@ -1599,18 +1719,32 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
  for (auto it : accs) {
    const SmallVector<unsigned> &key = it.first;
    Value acc = it.second;
    Value accIndex;
    if (withIndex)
      accIndex = accIndices[key];

    // reduce within warps
    for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
      Value shfl = shflSync(rewriter, loc, acc, N);
      accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
      if (!withIndex) {
        accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
      } else {
        Value shflIndex = shflSync(rewriter, loc, accIndex, N);
        accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
                            shflIndex, false);
      }
    }

    SmallVector<Value> writeIdx = indices[key];
    writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
    Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, order);
    Value writeOffset =
        linearize(rewriter, loc, writeIdx, smemShapes[0], order);
    Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
    storeShared(rewriter, loc, writePtr, acc, laneZero);
    if (withIndex) {
      Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
      storeShared(rewriter, loc, indexWritePtr, accIndex, laneZero);
    }
  }

  barrier();
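In the warp-level phase above, the partner's value, and for arg ops also its index, arrives through shflSync instead of shared memory; each lane folds in a lane N away and N halves until the warp is reduced, after which one lane per warp (laneZero) publishes its (value, index) pair to the scratch buffers. A host-side model of that exchange, written assuming the shuffle pairs each lane with its xor partner (a butterfly pattern, an assumption about shflSync's lowering rather than something shown in this diff):

// ---- illustration only, not part of this commit ----
#include <algorithm>
#include <array>
#include <cstdint>

constexpr int kWarpSize = 32;

// Host-side model of the intra-warp reduction for ARGMIN: every lane combines
// with the lane n positions away (modeled as an xor partner), halving n.
static void warpArgmin(std::array<int64_t, kWarpSize> &val,
                       std::array<int64_t, kWarpSize> &idx) {
  for (int n = kWarpSize / 2; n > 0; n >>= 1) {
    auto partnerVal = val; // snapshot: shuffles read pre-step registers
    auto partnerIdx = idx;
    for (int lane = 0; lane < kWarpSize; ++lane) {
      int src = lane ^ n;
      int64_t cur = partnerVal[src];
      int64_t curIdx = partnerIdx[src];
      if (cur < val[lane] || (cur == val[lane] && curIdx < idx[lane]))
        idx[lane] = curIdx;
      val[lane] = std::min(val[lane], cur);
    }
  }
}
// ---- end illustration ----

In the real lowering the trip count is sizeIntraWarps / 2 (and sizeInterWarps / 2 in the cross-warp pass below), not necessarily a full warp.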
@@ -1622,7 +1756,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
  //
  // each thread needs to process:
  //   elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
  unsigned elems = product<unsigned>(smemShape);
  unsigned numThreads =
      product<unsigned>(triton::gpu::getWarpsPerCTA(srcLayout)) * 32;
  unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
@@ -1630,10 +1763,21 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
  for (unsigned round = 0; round < elemsPerThread; ++round) {
    Value readPtr = gep(elemPtrTy, smemBase, readOffset);
    Value acc = load(readPtr);
    Value accIndex;
    if (withIndex) {
      Value readIndexPtr = gep(indexPtrTy, indexSmemBase, readOffset);
      accIndex = load(readIndexPtr);
    }

    for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
      Value shfl = shflSync(rewriter, loc, acc, N);
      accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
      if (!withIndex) {
        accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
      } else {
        Value shflIndex = shflSync(rewriter, loc, accIndex, N);
        accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
                            shflIndex, false);
      }
    }

    Value writeOffset = udiv(readOffset, i32_val(sizeInterWarps));
@@ -1642,8 +1786,12 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
    Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
    Value laneIdModSizeInterWarpsIsZero =
        icmp_eq(laneIdModSizeInterWarps, zero);
    storeShared(rewriter, loc, writePtr, acc,
                and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero));
    Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero);
    storeShared(rewriter, loc, writePtr, acc, pred);
    if (withIndex) {
      Value writeIndexPtr = gep(indexPtrTy, indexSmemBase, writeOffset);
      storeShared(rewriter, loc, writeIndexPtr, accIndex, pred);
    }

    if (round != elemsPerThread - 1) {
      readOffset = add(readOffset, i32_val(numThreads));
@@ -1671,25 +1819,24 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
    assert(resultIndices.size() == resultElems);

    SmallVector<Value> resultVals(resultElems);
    SmallVector<unsigned> resultShape;
    std::copy(resultTy.getShape().begin(), resultTy.getShape().end(),
              std::back_inserter(resultShape));
    for (size_t i = 0; i < resultElems; ++i) {
      SmallVector<Value> readIdx = resultIndices[i];
      Value readOffset =
          linearize(rewriter, loc, readIdx, resultShape, resultOrd);
          linearize(rewriter, loc, readIdx, smemShapes[2], resultOrd);
      Value readPtr = gep(elemPtrTy, smemBase, readOffset);
      resultVals[i] = load(readPtr);
      Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
      resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
    }

    SmallVector<Type> resultTypes(resultElems, llvmElemTy);
    SmallVector<Type> resultTypes(resultElems,
                                  withIndex ? llvmIndexTy : llvmElemTy);
    Type structTy =
        LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
    Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
    rewriter.replaceOp(op, ret);
  } else {
    // 0d-tensor -> scalar
    Value resultVal = load(smemBase);
    Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase);
    rewriter.replaceOp(op, resultVal);
  }

@@ -60,12 +60,32 @@
  rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
#define fcmp_ogt(lhs, rhs) \
  rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
                                LLVM::FCmpPredicate::ogt, lhs, rhs)
#define fcmp_olt(lhs, rhs) \
  rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
                                LLVM::FCmpPredicate::olt, lhs, rhs)
#define icmp_eq(...) \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
#define icmp_ne(...) \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__)
#define icmp_slt(...) \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__)
#define icmp_sle(...) \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sle, __VA_ARGS__)
#define icmp_sgt(...) \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, __VA_ARGS__)
#define icmp_sge(...) \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sge, __VA_ARGS__)
#define icmp_ult(...) \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ult, __VA_ARGS__)
#define icmp_ule(...) \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ule, __VA_ARGS__)
#define icmp_ugt(...) \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ugt, __VA_ARGS__)
#define icmp_uge(...) \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::uge, __VA_ARGS__)
#define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
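These comparison and select macros are what the new accumulateWithIndex cases compose. Expanding one expression by hand shows the builder calls that actually get emitted (written out for reading only; acc, cur, accIndex, and nestedSelect are stand-ins for the real operands, and the whitespace is mine):

// ---- illustration only, not part of this commit ----
// select(fcmp_olt(acc, cur), accIndex, nestedSelect) expands to:
rewriter.create<LLVM::SelectOp>(
    loc,
    rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(),
                                  LLVM::FCmpPredicate::olt, acc, cur),
    accIndex, nestedSelect);
// ---- end illustration ----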