[Triton-MLIR][Backend] Add ReduceOpConversion into TritonGPUToLLVM conversion (#774)
What is done in this PR: - [x] Add `ConvertLayout`, `getSizePerThread` and `getShapePerCTA` implementation for `SliceEncodingAttr` - [x] Split `emitIndices` into two phases: `emitBaseIndexForBlockedLayout` and `emitOffsetForBlockedLayout` - [x] Add `ReduceOpConversion::matchAndRewriteBasic` implementation - [x] Add `ReduceOpConversion::matchAndRewriteFast` implementation with ptx instruction `shfl.sync` - [x] Add support for scalar value in `StoreOpConversion` - [x] Add Reduce1d and Reduce2d unit tests and pass all unit tests Co-authored-by: Qingyi Liu <liuqingyi1993@gmail.com>
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <atomic>
|
||||
#include <limits>
|
||||
@@ -19,6 +20,8 @@ SmallVector<unsigned>
|
||||
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
unsigned &outVec);
|
||||
|
||||
SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
|
||||
|
@@ -250,6 +250,12 @@ struct PTXIOInstr : public PTXInstrBase<PTXIOInstr> {
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Add ".shared" suffix to instruction
|
||||
PTXIOInstr &shared(bool predicate = true) {
|
||||
o("shared", predicate);
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Add ".v" suffix to instruction
|
||||
PTXIOInstr &v(int vecWidth, bool predicate = true) {
|
||||
if (vecWidth > 1) {
|
||||
|
@@ -324,7 +324,9 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
|
||||
"Attribute":$parent
|
||||
);
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration;
|
||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
SmallVector<int64_t> paddedShape(ArrayRef<int64_t> shape) const;
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user