[Triton-MLIR][BACKEND] Add argmin / argmax implementation for ReduceOp (#918)
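This adds argmin / argmax reductions end to end: new ARG* variants of
RedOp, an index-carrying accumulate path in the LLVM lowering, scratch
space for the int32 indices, and `tl.argmin` / `tl.argmax` frontend
builtins. A minimal usage sketch of the new builtins (the kernel and
buffer names are illustrative, not part of this diff; with random
floats, ties are vanishingly unlikely):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def argmax_kernel(x_ptr, z_ptr, BLOCK: tl.constexpr):
        x = tl.load(x_ptr + tl.arange(0, BLOCK))
        # tl.argmax returns the int32 index of the max along `axis`
        tl.store(z_ptr, tl.argmax(x, axis=0))

    x = torch.randn(1024, device='cuda')
    z = torch.empty((), dtype=torch.int32, device='cuda')
    argmax_kernel[(1,)](x, z, BLOCK=1024)
    assert z.item() == torch.argmax(x).item()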
@@ -20,8 +20,6 @@ SmallVector<unsigned>
 getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
                              unsigned &outVec);
 
-SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op);
-
 } // namespace triton
 
 /// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
@@ -26,6 +26,12 @@ public:
 
   unsigned getThreadsReductionAxis();
 
+  SmallVector<unsigned> getScratchConfigBasic();
+
+  SmallVector<SmallVector<unsigned>> getScratchConfigsFast();
+
+  unsigned getScratchSizeInBytes();
+
 private:
   triton::ReduceOp op;
   RankedTensorType srcTy{};
@@ -39,6 +45,14 @@ bool maybeAliasOp(Operation *op);
 
 std::string getValueOperandName(Value value, AsmState &state);
 
+template <typename T_OUT, typename T_IN>
+inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
+  SmallVector<T_OUT> out;
+  for (const T_IN &i : in)
+    out.push_back(T_OUT(i));
+  return out;
+}
+
 template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
   return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
 }
@@ -351,6 +351,11 @@ def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
 
    let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
 
+   let extraClassDeclaration = [{
+     // This member function is marked static because we need to call it before the ReduceOp
+     // is constructed, see the implementation of create_reduce in triton.cc.
+     static bool withIndex(mlir::triton::RedOp redOp);
+   }];
 }
 
 //
@@ -88,25 +88,6 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
   return paddedRepShape;
 }
 
-SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
-  ReduceOpHelper helper(op);
-
-  SmallVector<unsigned> smemShape;
-  auto srcShape = helper.getSrcShape();
-  for (auto d : srcShape)
-    smemShape.push_back(d);
-
-  auto axis = op.axis();
-  if (helper.isFastReduction()) {
-    smemShape[axis] = helper.getInterWarpSize();
-  } else {
-    smemShape[axis] =
-        std::min(smemShape[axis], helper.getThreadsReductionAxis());
-  }
-
-  return smemShape;
-}
-
 // TODO: extend beyond scalars
 SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
   SmallVector<unsigned> smemShape;
@@ -173,21 +154,9 @@ private:
   /// Initializes temporary shared memory for a given operation.
   void getScratchValueSize(Operation *op) {
     if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
-      // TODO(Keren): Reduce with index is not supported yet.
-      auto value = op->getOperand(0);
-      if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
-        bool fastReduce = ReduceOpHelper(reduceOp).isFastReduction();
-        auto smemShape = getScratchConfigForReduce(reduceOp);
-        unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
-                                         std::multiplies{});
-        if (fastReduce) {
-          auto mod = op->getParentOfType<ModuleOp>();
-          unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
-          elems = std::max<unsigned>(elems, numWarps * 32);
-        }
-        auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
-        allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
-      }
+      ReduceOpHelper helper(reduceOp);
+      unsigned bytes = helper.getScratchSizeInBytes();
+      allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
    } else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
      auto srcTy = cvtLayout.src().getType().cast<RankedTensorType>();
      auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();
@@ -37,6 +37,55 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
          triton::gpu::getWarpsPerCTA(srcLayout)[axis];
 }
 
+SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
+  auto axis = op.axis();
+  auto smemShape = convertType<unsigned>(getSrcShape());
+  smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis());
+  return smemShape;
+}
+
+SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
+  auto axis = op.axis();
+  SmallVector<SmallVector<unsigned>> smemShapes(3);
+
+  /// shared memory block0
+  smemShapes[0] = convertType<unsigned>(getSrcShape());
+  smemShapes[0][axis] = getInterWarpSize();
+
+  /// FIXME(Qingyi): This size is actually larger than required.
+  /// shared memory block1:
+  auto mod = op.getOperation()->getParentOfType<ModuleOp>();
+  unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
+  smemShapes[1].push_back(numWarps * 32);
+
+  /// FIXME(Qingyi): This requirement is actually not necessary, because it is
+  /// always smaller than smemShapes[0] shared memory block2
+  smemShapes[2] = convertType<unsigned>(getSrcShape());
+  smemShapes[2].erase(smemShapes[2].begin() + axis);
+
+  return smemShapes;
+}
+
+unsigned ReduceOpHelper::getScratchSizeInBytes() {
+  unsigned elems = 0;
+  if (isFastReduction()) {
+    auto smemShapes = getScratchConfigsFast();
+    for (const auto &smemShape : smemShapes)
+      elems = std::max(elems, product<unsigned>(smemShape));
+  } else {
+    auto smemShape = getScratchConfigBasic();
+    elems = product<unsigned>(smemShape);
+  }
+
+  auto tensorType = op.operand().getType().cast<RankedTensorType>();
+  unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8;
+
+  if (triton::ReduceOp::withIndex(op.redOp()))
+    bytes += elems * sizeof(int32_t);
+
+  return bytes;
+}
+
 bool isSharedEncoding(Value value) {
   auto type = value.getType();
   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
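The scratch sizing above is easy to sanity-check by hand. A Python
sketch of the same arithmetic (illustrative only; the example shapes
and the assumed interWarpSize are made-up inputs, not from this diff):

    import math

    def scratch_size_in_bytes(smem_shapes, elem_bitwidth, with_index):
        # fast path: the buffer must hold the largest of the three blocks
        elems = max(math.prod(s) for s in smem_shapes)
        nbytes = elems * elem_bitwidth // 8
        if with_index:
            nbytes += elems * 4  # one int32 index per element
        return nbytes

    # e.g. a 32x64 f32 reduction along axis 1, 4 warps, interWarpSize 2:
    # block0 = [32, 2], block1 = [4 * 32], block2 = [32]
    print(scratch_size_in_bytes([[32, 2], [128], [32]], 32, True))  # 1024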
@@ -1338,6 +1338,10 @@ private:
   void accumulate(ConversionPatternRewriter &rewriter, Location loc,
                   RedOp redOp, Value &acc, Value cur, bool isFirst) const;
 
+  void accumulateWithIndex(ConversionPatternRewriter &rewriter, Location loc,
+                           RedOp redOp, Value &acc, Value &accIndex, Value cur,
+                           Value curIndex, bool isFirst) const;
+
   Value shflSync(ConversionPatternRewriter &rewriter, Location loc, Value val,
                  int i) const;
 
@@ -1366,7 +1370,6 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
     acc = cur;
     return;
   }
-  auto type = cur.getType();
   switch (redOp) {
   case RedOp::ADD:
     acc = add(acc, cur);
@@ -1395,6 +1398,75 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
   case RedOp::XOR:
     acc = xor_(acc, cur);
     break;
+  case RedOp::ARGMIN:
+  case RedOp::ARGMAX:
+  case RedOp::ARGUMIN:
+  case RedOp::ARGUMAX:
+  case RedOp::ARGFMIN:
+  case RedOp::ARGFMAX:
+    llvm::report_fatal_error(
+        "This accumulate implementation is not for argmin / argmax");
+  default:
+    llvm::report_fatal_error("Unsupported reduce op");
+  }
+}
+
+void ReduceOpConversion::accumulateWithIndex(
+    ConversionPatternRewriter &rewriter, Location loc, RedOp redOp, Value &acc,
+    Value &accIndex, Value cur, Value curIndex, bool isFirst) const {
+  if (isFirst) {
+    acc = cur;
+    accIndex = curIndex;
+    return;
+  }
+  switch (redOp) {
+  case RedOp::ARGMIN:
+    accIndex =
+        select(icmp_slt(acc, cur), accIndex,
+               select(icmp_sgt(acc, cur), curIndex, smin(accIndex, curIndex)));
+    acc = smin(acc, cur);
+    break;
+  case RedOp::ARGMAX:
+    accIndex =
+        select(icmp_sgt(acc, cur), accIndex,
+               select(icmp_slt(acc, cur), curIndex, smin(accIndex, curIndex)));
+    acc = smax(acc, cur);
+    break;
+  case RedOp::ARGUMIN:
+    accIndex =
+        select(icmp_ult(acc, cur), accIndex,
+               select(icmp_ugt(acc, cur), curIndex, smin(accIndex, curIndex)));
+    acc = umin(acc, cur);
+    break;
+  case RedOp::ARGUMAX:
+    accIndex =
+        select(icmp_ugt(acc, cur), accIndex,
+               select(icmp_ult(acc, cur), curIndex, smin(accIndex, curIndex)));
+    acc = umax(acc, cur);
+    break;
+  case RedOp::ARGFMIN:
+    accIndex =
+        select(fcmp_olt(acc, cur), accIndex,
+               select(fcmp_ogt(acc, cur), curIndex, smin(accIndex, curIndex)));
+    acc = fmin(acc, cur);
+    break;
+  case RedOp::ARGFMAX:
+    accIndex =
+        select(fcmp_ogt(acc, cur), accIndex,
+               select(fcmp_olt(acc, cur), curIndex, smin(accIndex, curIndex)));
+    acc = fmax(acc, cur);
+    break;
+  case RedOp::ADD:
+  case RedOp::FADD:
+  case RedOp::MIN:
+  case RedOp::MAX:
+  case RedOp::UMIN:
+  case RedOp::UMAX:
+  case RedOp::FMIN:
+  case RedOp::FMAX:
+  case RedOp::XOR:
+    llvm::report_fatal_error(
+        "This accumulate implementation is only for argmin / argmax");
   default:
     llvm::report_fatal_error("Unsupported reduce op");
   }
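The nested selects above encode a deterministic tie-break: if the
incoming value is strictly worse, the accumulator index is kept; if it
is strictly better, the incoming index wins; on an exact tie the
smaller index is chosen (the innermost smin over the two indices). In
Python terms, the ARGMIN update is equivalent to this sketch:

    def argmin_step(acc, acc_idx, cur, cur_idx):
        # select(acc < cur, acc_idx,
        #        select(acc > cur, cur_idx, smin(acc_idx, cur_idx)))
        if acc < cur:
            new_idx = acc_idx                 # accumulator stays minimal
        elif acc > cur:
            new_idx = cur_idx                 # incoming value is smaller
        else:
            new_idx = min(acc_idx, cur_idx)   # tie: lower index wins
        return min(acc, cur), new_idx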
@@ -1433,6 +1505,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
     ConversionPatternRewriter &rewriter) const {
   Location loc = op->getLoc();
   unsigned axis = op.axis();
+  bool withIndex = triton::ReduceOp::withIndex(op.redOp());
 
   auto srcTy = op.operand().getType().cast<RankedTensorType>();
   auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
@@ -1440,11 +1513,17 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
   auto srcShape = srcTy.getShape();
 
   auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
+  auto llvmIndexTy = getTypeConverter()->getIndexType();
   auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
+  auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
   Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
   smemBase = bitcast(smemBase, elemPtrTy);
 
-  auto smemShape = getScratchConfigForReduce(op);
+  ReduceOpHelper helper(op);
+  auto smemShape = helper.getScratchConfigBasic();
+  unsigned elems = product<unsigned>(smemShape);
+  Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(elems));
+  indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
 
   unsigned srcElems = getElemsPerThread(srcTy);
   auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
@@ -1454,6 +1533,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
       emitOffsetForBlockedLayout(srcLayout, srcShape);
 
   std::map<SmallVector<unsigned>, Value> accs;
+  std::map<SmallVector<unsigned>, Value> accIndices;
   std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
 
   // reduce within threads
@@ -1461,7 +1541,13 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
     SmallVector<unsigned> key = offset[i];
     key[axis] = 0;
     bool isFirst = accs.find(key) == accs.end();
-    accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
+    if (!withIndex) {
+      accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
+    } else {
+      Value curIndex = srcIndices[i][axis];
+      accumulateWithIndex(rewriter, loc, op.redOp(), accs[key], accIndices[key],
+                          srcValues[i], curIndex, isFirst);
+    }
     if (isFirst)
       indices[key] = srcIndices[i];
   }
@@ -1477,12 +1563,18 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
   for (auto it : accs) {
     const SmallVector<unsigned> &key = it.first;
     Value acc = it.second;
+    Value accIndex;
+    if (withIndex)
+      accIndex = accIndices[key];
    SmallVector<Value> writeIdx = indices[key];
 
     writeIdx[axis] = udiv(writeIdx[axis], sizePerThread);
     Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd);
     Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
+    Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
     store(acc, writePtr);
+    if (withIndex)
+      store(accIndex, indexWritePtr);
 
     SmallVector<Value> readIdx(writeIdx.size(), ints[0]);
     for (int N = smemShape[axis] / 2; N > 0; N >>= 1) {
@@ -1493,11 +1585,24 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
                   ints[0]);
       Value readPtr = gep(elemPtrTy, writePtr, readOffset);
       barrier();
-      accumulate(rewriter, loc, op.redOp(), acc, load(readPtr), false);
-      store(acc, writePtr);
+      if (!withIndex) {
+        Value cur = load(readPtr);
+        accumulate(rewriter, loc, op.redOp(), acc, cur, false);
+        store(acc, writePtr);
+      } else {
+        Value cur = load(readPtr);
+        Value indexReadPtr = gep(indexPtrTy, indexWritePtr, readOffset);
+        Value curIndex = load(indexReadPtr);
+        accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, cur,
+                            curIndex, false);
+        store(acc, writePtr);
+        store(accIndex, indexWritePtr);
+      }
     }
   }
 
+  barrier();
+
   // set output values
   if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
     // nd-tensor where n >= 1
@@ -1508,25 +1613,25 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
     auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
     assert(resultIndices.size() == resultElems);
 
-    barrier();
     SmallVector<Value> resultVals(resultElems);
     for (unsigned i = 0; i < resultElems; ++i) {
       SmallVector<Value> readIdx = resultIndices[i];
       readIdx.insert(readIdx.begin() + axis, ints[0]);
       Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd);
       Value readPtr = gep(elemPtrTy, smemBase, readOffset);
-      resultVals[i] = load(readPtr);
+      Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
+      resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
     }
 
-    SmallVector<Type> resultTypes(resultElems, llvmElemTy);
+    SmallVector<Type> resultTypes(resultElems,
+                                  withIndex ? llvmIndexTy : llvmElemTy);
     Type structTy =
         LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
     Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
     rewriter.replaceOp(op, ret);
   } else {
     // 0d-tensor -> scalar
-    barrier();
-    Value resultVal = load(smemBase);
+    Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase);
     rewriter.replaceOp(op, resultVal);
   }
 
@@ -1538,25 +1643,35 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
     ConversionPatternRewriter &rewriter) const {
   Location loc = op->getLoc();
   unsigned axis = adaptor.axis();
+  bool withIndex = triton::ReduceOp::withIndex(op.redOp());
 
   auto srcTy = op.operand().getType().cast<RankedTensorType>();
   auto srcLayout = srcTy.getEncoding();
   auto srcShape = srcTy.getShape();
   auto srcRank = srcTy.getRank();
+  auto order = getOrder(srcLayout);
 
   auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
   auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout);
 
   auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
+  auto llvmIndexTy = getTypeConverter()->getIndexType();
   auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
+  auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
   Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
   smemBase = bitcast(smemBase, elemPtrTy);
 
   ReduceOpHelper helper(op);
+  auto smemShapes = helper.getScratchConfigsFast();
+  unsigned elems = product<unsigned>(smemShapes[0]);
+  unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
+  maxElems = std::max(maxElems, product<unsigned>(smemShapes[2]));
+  Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems));
+  indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
+
   unsigned sizeIntraWarps = helper.getIntraWarpSize();
   unsigned sizeInterWarps = helper.getInterWarpSize();
 
-  auto order = getOrder(srcLayout);
   unsigned srcElems = getElemsPerThread(srcTy);
   auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
   auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
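For orientation, a sketch of the shared-memory carve-up implied above
(illustrative; offsets are in value-element units, as with the gep on
elemPtrTy):

    import math

    # values : [smemBase, smemBase + maxElems)
    # indices: [indexSmemBase, indexSmemBase + maxElems), one int32 each
    # indexSmemBase sits maxElems value elements past smemBase, and
    # maxElems covers all three scratch blocks so every phase fits.
    def index_base_offset_elems(smem_shapes):
        return max(math.prod(s) for s in smem_shapes)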
@@ -1565,16 +1680,21 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
       emitOffsetForLayout(srcLayout, srcShape);
 
   std::map<SmallVector<unsigned>, Value> accs;
+  std::map<SmallVector<unsigned>, Value> accIndices;
   std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
 
-  auto smemShape = getScratchConfigForReduce(op);
-
   // reduce within threads
   for (unsigned i = 0; i < srcElems; ++i) {
     SmallVector<unsigned> key = offset[i];
     key[axis] = 0;
     bool isFirst = accs.find(key) == accs.end();
-    accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
+    if (!withIndex) {
+      accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
+    } else {
+      Value curIndex = srcIndices[i][axis];
+      accumulateWithIndex(rewriter, loc, op.redOp(), accs[key], accIndices[key],
+                          srcValues[i], curIndex, isFirst);
+    }
     if (isFirst)
       indices[key] = srcIndices[i];
   }
@@ -1599,18 +1719,32 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   for (auto it : accs) {
     const SmallVector<unsigned> &key = it.first;
     Value acc = it.second;
+    Value accIndex;
+    if (withIndex)
+      accIndex = accIndices[key];
+
     // reduce within warps
     for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
       Value shfl = shflSync(rewriter, loc, acc, N);
-      accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
+      if (!withIndex) {
+        accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
+      } else {
+        Value shflIndex = shflSync(rewriter, loc, accIndex, N);
+        accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
+                            shflIndex, false);
+      }
     }
 
     SmallVector<Value> writeIdx = indices[key];
     writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
-    Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, order);
+    Value writeOffset =
+        linearize(rewriter, loc, writeIdx, smemShapes[0], order);
     Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
     storeShared(rewriter, loc, writePtr, acc, laneZero);
+    if (withIndex) {
+      Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
+      storeShared(rewriter, loc, indexWritePtr, accIndex, laneZero);
+    }
   }
 
   barrier();
@@ -1622,7 +1756,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   //
   //   each thread needs to process:
   //     elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
-  unsigned elems = product<unsigned>(smemShape);
   unsigned numThreads =
       product<unsigned>(triton::gpu::getWarpsPerCTA(srcLayout)) * 32;
   unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
@@ -1630,10 +1763,21 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   for (unsigned round = 0; round < elemsPerThread; ++round) {
     Value readPtr = gep(elemPtrTy, smemBase, readOffset);
     Value acc = load(readPtr);
+    Value accIndex;
+    if (withIndex) {
+      Value readIndexPtr = gep(indexPtrTy, indexSmemBase, readOffset);
+      accIndex = load(readIndexPtr);
+    }
+
     for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
       Value shfl = shflSync(rewriter, loc, acc, N);
-      accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
+      if (!withIndex) {
+        accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
+      } else {
+        Value shflIndex = shflSync(rewriter, loc, accIndex, N);
+        accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
+                            shflIndex, false);
+      }
     }
 
     Value writeOffset = udiv(readOffset, i32_val(sizeInterWarps));
@@ -1642,8 +1786,12 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
     Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
     Value laneIdModSizeInterWarpsIsZero =
         icmp_eq(laneIdModSizeInterWarps, zero);
-    storeShared(rewriter, loc, writePtr, acc,
-                and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero));
+    Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero);
+    storeShared(rewriter, loc, writePtr, acc, pred);
+    if (withIndex) {
+      Value writeIndexPtr = gep(indexPtrTy, indexSmemBase, writeOffset);
+      storeShared(rewriter, loc, writeIndexPtr, accIndex, pred);
+    }
 
     if (round != elemsPerThread - 1) {
       readOffset = add(readOffset, i32_val(numThreads));
@@ -1671,25 +1819,24 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
     assert(resultIndices.size() == resultElems);
 
     SmallVector<Value> resultVals(resultElems);
-    SmallVector<unsigned> resultShape;
-    std::copy(resultTy.getShape().begin(), resultTy.getShape().end(),
-              std::back_inserter(resultShape));
     for (size_t i = 0; i < resultElems; ++i) {
       SmallVector<Value> readIdx = resultIndices[i];
       Value readOffset =
-          linearize(rewriter, loc, readIdx, resultShape, resultOrd);
+          linearize(rewriter, loc, readIdx, smemShapes[2], resultOrd);
       Value readPtr = gep(elemPtrTy, smemBase, readOffset);
-      resultVals[i] = load(readPtr);
+      Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
+      resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
     }
 
-    SmallVector<Type> resultTypes(resultElems, llvmElemTy);
+    SmallVector<Type> resultTypes(resultElems,
+                                  withIndex ? llvmIndexTy : llvmElemTy);
     Type structTy =
         LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
     Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
     rewriter.replaceOp(op, ret);
   } else {
     // 0d-tensor -> scalar
-    Value resultVal = load(smemBase);
+    Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase);
     rewriter.replaceOp(op, resultVal);
   }
 
@@ -60,12 +60,32 @@
   rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
 #define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
 #define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
+#define fcmp_ogt(lhs, rhs)                                                     \
+  rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(),                     \
+                                LLVM::FCmpPredicate::ogt, lhs, rhs)
+#define fcmp_olt(lhs, rhs)                                                     \
+  rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(),                     \
+                                LLVM::FCmpPredicate::olt, lhs, rhs)
 #define icmp_eq(...)                                                           \
   rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
 #define icmp_ne(...)                                                           \
   rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__)
 #define icmp_slt(...)                                                          \
   rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__)
+#define icmp_sle(...)                                                          \
+  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sle, __VA_ARGS__)
+#define icmp_sgt(...)                                                          \
+  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, __VA_ARGS__)
+#define icmp_sge(...)                                                          \
+  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sge, __VA_ARGS__)
+#define icmp_ult(...)                                                          \
+  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ult, __VA_ARGS__)
+#define icmp_ule(...)                                                          \
+  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ule, __VA_ARGS__)
+#define icmp_ugt(...)                                                          \
+  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ugt, __VA_ARGS__)
+#define icmp_uge(...)                                                          \
+  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::uge, __VA_ARGS__)
 #define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
 #define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
 #define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
@@ -240,12 +240,16 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
   Value arg = operands[0];
   auto argTy = arg.getType().cast<RankedTensorType>();
   auto argEltTy = argTy.getElementType();
+  auto i32Ty = IntegerType::get(argEltTy.getContext(), 32);
+  auto redOp = attributes.get("redOp").cast<mlir::triton::RedOpAttr>().getValue();
+  bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
+  auto retEltTy = withIndex ? i32Ty : argEltTy;
   auto retShape = argTy.getShape().vec();
   int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
   retShape.erase(retShape.begin() + axis);
   if (retShape.empty()) {
     // 0d-tensor -> scalar
-    inferredReturnTypes.push_back(argEltTy);
+    inferredReturnTypes.push_back(retEltTy);
   } else {
     // nd-tensor where n >= 1
     // infer encoding
@@ -264,11 +268,20 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
     }
     // create type
     inferredReturnTypes.push_back(
-        RankedTensorType::get(retShape, argEltTy, retEncoding));
+        RankedTensorType::get(retShape, retEltTy, retEncoding));
   }
   return mlir::success();
 }
 
+bool mlir::triton::ReduceOp::withIndex(mlir::triton::RedOp redOp) {
+  return redOp == mlir::triton::RedOp::ARGMIN ||
+         redOp == mlir::triton::RedOp::ARGMAX ||
+         redOp == mlir::triton::RedOp::ARGUMIN ||
+         redOp == mlir::triton::RedOp::ARGUMAX ||
+         redOp == mlir::triton::RedOp::ARGFMIN ||
+         redOp == mlir::triton::RedOp::ARGFMAX;
+}
+
 //-- SplatOp --
 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
   auto constOperand = src().getDefiningOp<arith::ConstantOp>();
@@ -1195,10 +1195,11 @@ void init_triton_ir(py::module &&m) {
                 operand.getType().dyn_cast<mlir::RankedTensorType>();
             std::vector<int64_t> shape = inputTensorType.getShape();
             shape.erase(shape.begin() + axis);
-            mlir::Type resType = inputTensorType.getElementType();
+            bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
+            mlir::Type resType = withIndex ? self.getI32Type()
+                                           : inputTensorType.getElementType();
             if (!shape.empty()) {
-              resType = mlir::RankedTensorType::get(
-                  shape, inputTensorType.getElementType());
+              resType = mlir::RankedTensorType::get(shape, resType);
             }
             return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
                                                        operand, axis);
@@ -1,4 +1,5 @@
 import pytest
+import numpy as np
 import torch
 from torch.testing import assert_close
 
@@ -13,7 +14,9 @@ dtypes_with_bfloat16 = int_dtypes + uint_dtypes + float_dtypes
 dtype_mapping = {dtype_str: torch.__dict__[dtype_str] for dtype_str in dtypes}
 
 
-def get_reduced_dtype(dtype):
+def get_reduced_dtype(op, dtype):
+    if op in ['argmin', 'argmax']:
+        return torch.int32
     if dtype in [torch.int8, torch.int16, torch.uint8]:
         return torch.int32
     if dtype in [torch.bfloat16]:
@@ -48,7 +51,7 @@ def reduce2d_kernel(x_ptr, z_ptr, axis: tl.constexpr, block_m: tl.constexpr, blo
 
 reduce1d_configs = [
     (op, dtype, shape)
-    for op in ['sum', 'min', 'max']
+    for op in ['sum', 'min', 'max', 'argmin', 'argmax', 'xor_sum']
     for dtype in dtypes
     for shape in [4, 8, 16, 32, 64, 128, 512, 1024]
 ]
@@ -56,8 +59,11 @@ reduce1d_configs = [
 
 @pytest.mark.parametrize('op, dtype, shape', reduce1d_configs)
 def test_reduce1d(op, dtype, shape):
+    if op == 'xor_sum' and dtype in float_dtypes:
+        return
+
     dtype = dtype_mapping[dtype]
-    reduced_dtype = get_reduced_dtype(dtype)
+    reduced_dtype = get_reduced_dtype(op, dtype)
 
     if dtype.is_floating_point:
         x = torch.randn((shape,), device='cuda', dtype=dtype)
@@ -79,8 +85,17 @@ def test_reduce1d(op, dtype, shape):
         golden_z = torch.sum(x, dtype=reduced_dtype)
     elif op == 'min':
         golden_z = torch.min(x).to(reduced_dtype)
-    else:
+    elif op == 'max':
         golden_z = torch.max(x).to(reduced_dtype)
+    elif op == 'argmin':
+        golden_z = torch.argmin(x).to(reduced_dtype)
+    elif op == 'argmax':
+        golden_z = torch.argmax(x).to(reduced_dtype)
+    elif op == 'xor_sum':
+        sum_npy = np.bitwise_xor.reduce(x.cpu().numpy())
+        golden_z = torch.tensor(sum_npy, dtype=reduced_dtype).cuda()
+    else:
+        raise RuntimeError(f'Unknown reduce op {op}')
 
     if dtype.is_floating_point and op == 'sum':
         if shape >= 256:
@@ -95,7 +110,7 @@ def test_reduce1d(op, dtype, shape):
 
 reduce2d_configs = [
     (op, dtype, shape, axis)
-    for op in ['sum', 'min', 'max']
+    for op in ['sum', 'min', 'max', 'argmin', 'argmax', 'xor_sum']
     for dtype in dtypes
     for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32), (4, 128), (32, 64)]
     for axis in [0, 1]
@@ -104,8 +119,11 @@ reduce2d_configs = [
 
 @pytest.mark.parametrize('op, dtype, shape, axis', reduce2d_configs)
 def test_reduce2d(op, dtype, shape, axis):
+    if op == 'xor_sum' and dtype in float_dtypes:
+        return
+
     dtype = dtype_mapping[dtype]
-    reduced_dtype = get_reduced_dtype(dtype)
+    reduced_dtype = get_reduced_dtype(op, dtype)
     reduced_shape = (shape[1 - axis],)
 
     if dtype.is_floating_point:
@@ -123,8 +141,18 @@ def test_reduce2d(op, dtype, shape, axis):
         golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=reduced_dtype)
     elif op == 'min':
         golden_z = torch.min(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
-    else:
+    elif op == 'max':
         golden_z = torch.max(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
+    elif op == 'argmin':
+        golden_z = torch.argmin(x, dim=axis, keepdim=False).to(reduced_dtype)
+    elif op == 'argmax':
+        golden_z = torch.argmax(x, dim=axis, keepdim=False).to(reduced_dtype)
+    elif op == 'xor_sum':
+        sum_npy = np.bitwise_xor.reduce(x.cpu().numpy(), axis=axis, keepdims=False)
+        golden_z = torch.tensor(sum_npy, dtype=reduced_dtype).cuda()
+    else:
+        raise RuntimeError(f'Unknown reduce op {op}')
 
     if dtype.is_floating_point and op == 'sum':
         if shape[axis] >= 256:
             assert_close(z, golden_z, rtol=0.05, atol=0.1)
@@ -1041,6 +1041,13 @@ def max(input, axis, _builder=None):
     return semantic.max(input, axis, _builder)
 
 
+@builtin
+@_add_reduction_docstr("maximum index")
+def argmax(input, axis, _builder=None):
+    axis = _constexpr_to_value(axis)
+    return semantic.argmax(input, axis, _builder)
+
+
 @builtin
 @_add_reduction_docstr("minimum")
 def min(input, axis, _builder=None):
@@ -1048,6 +1055,13 @@ def min(input, axis, _builder=None):
     return semantic.min(input, axis, _builder)
 
 
+@builtin
+@_add_reduction_docstr("minimum index")
+def argmin(input, axis, _builder=None):
+    axis = _constexpr_to_value(axis)
+    return semantic.argmin(input, axis, _builder)
+
+
 @builtin
 @_add_reduction_docstr("sum")
 def sum(input, axis, _builder=None):
@@ -1061,10 +1061,18 @@ def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
     return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN)
 
 
+def argmin(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
+    return reduce_impl(input, axis, builder, "argmin", ir.REDUCE_OP.ARGFMIN, ir.REDUCE_OP.ARGMIN)
+
+
 def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
     return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX)
 
 
+def argmax(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
+    return reduce_impl(input, axis, builder, "argmax", ir.REDUCE_OP.ARGFMAX, ir.REDUCE_OP.ARGMAX)
+
+
 def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
     return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD)
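As with min / max, float inputs dispatch to the ARGF* ops and integer
inputs to the ARG(U)* ops via reduce_impl. A 2D sketch in the spirit of
the new reduce2d tests (the kernel below is illustrative, not part of
this diff):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def argmin2d_kernel(x_ptr, z_ptr, BLOCK_M: tl.constexpr,
                        BLOCK_N: tl.constexpr):
        rm = tl.arange(0, BLOCK_M)
        rn = tl.arange(0, BLOCK_N)
        x = tl.load(x_ptr + rm[:, None] * BLOCK_N + rn[None, :])
        # reduce along axis 1: one int32 row-minimum index per row
        tl.store(z_ptr + rm, tl.argmin(x, axis=1))

    x = torch.randn((32, 64), device='cuda')
    z = torch.empty((32,), dtype=torch.int32, device='cuda')
    argmin2d_kernel[(1,)](x, z, BLOCK_M=32, BLOCK_N=64)
    assert torch.equal(z.long().cpu(), torch.argmin(x, dim=1).cpu())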