[Triton-MLIR][BACKEND] Refactor decompose insert_slice_async (#929)
1. Improve pipline's comment 2. Decompose insert_slice_async when load vector size is not supported 3. Add a test that could fail our gemm code Copy my comments here: There's a knob that may cause performance regression when decomposition has been performed. We should remove this knob once we have thorough analysis on async wait. Currently, we decompose `insert_slice_async` into `load` and `insert_slice` without knowing which `async_wait` is responsible for the `insert_slice_async`. To guarantee correctness, we blindly set the `async_wait` to wait for all async ops if any `insert_slice_async` has been decomposed. There are two options to improve this: 1. We can perform a dataflow analysis to find the `async_wait` that is responsible for the `insert_slice_async` in the backend. 4. We can modify the pipeline to perform the decomposition before the `async_wait` is inserted. However, it is also risky because we don't know the correct vectorized shape yet in the pipeline pass. Making the pipeline pass aware of the vectorization could introduce additional dependencies on the AxisInfoAnalysis and the Coalesce analysis.
This commit is contained in:
@@ -131,6 +131,12 @@ public:
|
||||
ChangeResult
|
||||
visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
|
||||
|
||||
unsigned getPtrVectorSize(Value ptr);
|
||||
|
||||
unsigned getPtrAlignment(Value ptr);
|
||||
|
||||
unsigned getMaskAlignment(Value mask);
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
@@ -25,11 +25,11 @@ namespace gpu {
|
||||
|
||||
unsigned getElemsPerThread(Type type);
|
||||
|
||||
SmallVector<unsigned> getThreadsPerWarp(Attribute layout);
|
||||
SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout);
|
||||
|
||||
SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
|
||||
SmallVector<unsigned> getWarpsPerCTA(const Attribute &layout);
|
||||
|
||||
SmallVector<unsigned> getSizePerThread(Attribute layout);
|
||||
SmallVector<unsigned> getSizePerThread(const Attribute &layout);
|
||||
|
||||
SmallVector<unsigned> getContigPerThread(Attribute layout);
|
||||
|
||||
|
@@ -151,6 +151,10 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
||||
// attr-dict `:` type($src) `->` type($dst)
|
||||
//}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability);
|
||||
}];
|
||||
|
||||
// The custom parser could be replaced with oilist in LLVM-16
|
||||
let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];
|
||||
|
||||
|
@@ -276,4 +276,46 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
return result;
|
||||
}
|
||||
|
||||
unsigned AxisInfoAnalysis::getPtrVectorSize(Value ptr) {
|
||||
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorTy)
|
||||
return 1;
|
||||
auto layout = tensorTy.getEncoding();
|
||||
auto shape = tensorTy.getShape();
|
||||
|
||||
// Here order should be ordered by contiguous first, so the first element
|
||||
// should have the largest contiguous.
|
||||
auto order = triton::gpu::getOrder(layout);
|
||||
unsigned align = getPtrAlignment(ptr);
|
||||
|
||||
unsigned contigPerThread = triton::gpu::getSizePerThread(layout)[order[0]];
|
||||
unsigned vec = std::min(align, contigPerThread);
|
||||
vec = std::min<unsigned>(shape[order[0]], vec);
|
||||
|
||||
return vec;
|
||||
}
|
||||
|
||||
unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
|
||||
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorTy)
|
||||
return 1;
|
||||
auto axisInfo = lookupLatticeElement(ptr)->getValue();
|
||||
auto layout = tensorTy.getEncoding();
|
||||
auto order = triton::gpu::getOrder(layout);
|
||||
unsigned maxMultiple = axisInfo.getDivisibility(order[0]);
|
||||
unsigned maxContig = axisInfo.getContiguity(order[0]);
|
||||
unsigned alignment = std::min(maxMultiple, maxContig);
|
||||
return alignment;
|
||||
}
|
||||
|
||||
unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) {
|
||||
auto tensorTy = mask.getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorTy)
|
||||
return 1;
|
||||
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
|
||||
auto maskAxis = lookupLatticeElement(mask)->getValue();
|
||||
auto alignment = std::max<unsigned>(maskAxis.getConstancy(maskOrder[0]), 1);
|
||||
return alignment;
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
@@ -822,7 +822,7 @@ struct ArithConstantSplatOpConversion
|
||||
// Contains some helper functions for both Load and Store conversions.
|
||||
struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
|
||||
explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
|
||||
: AxisAnalysisPass(axisAnalysisPass) {}
|
||||
: axisAnalysisPass(axisAnalysisPass) {}
|
||||
|
||||
// Get corresponding LLVM element values of \param value.
|
||||
static SmallVector<Value> getLLVMElems(Value value, Value llValue,
|
||||
@@ -838,51 +838,15 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
|
||||
}
|
||||
|
||||
unsigned getVectorSize(Value ptr) const {
|
||||
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorTy)
|
||||
return 1;
|
||||
auto layout = tensorTy.getEncoding();
|
||||
auto shape = tensorTy.getShape();
|
||||
|
||||
auto axisInfo = getAxisInfo(ptr);
|
||||
// Here order should be ordered by contiguous first, so the first element
|
||||
// should have the largest contiguous.
|
||||
auto order = getOrder(layout);
|
||||
unsigned align = getAlignment(ptr, layout);
|
||||
|
||||
unsigned contigPerThread = getSizePerThread(layout)[order[0]];
|
||||
unsigned vec = std::min(align, contigPerThread);
|
||||
vec = std::min<unsigned>(shape[order[0]], vec);
|
||||
|
||||
return vec;
|
||||
}
|
||||
|
||||
unsigned getAlignment(Value val, const Attribute &layout) const {
|
||||
auto axisInfo = getAxisInfo(val);
|
||||
auto order = getOrder(layout);
|
||||
unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
|
||||
unsigned maxContig = axisInfo->getContiguity(order[0]);
|
||||
unsigned alignment = std::min(maxMultiple, maxContig);
|
||||
return alignment;
|
||||
return axisAnalysisPass.getPtrVectorSize(ptr);
|
||||
}
|
||||
|
||||
unsigned getMaskAlignment(Value mask) const {
|
||||
auto tensorTy = mask.getType().cast<RankedTensorType>();
|
||||
auto maskOrder = getOrder(tensorTy.getEncoding());
|
||||
auto maskAxis = getAxisInfo(mask);
|
||||
return std::max<int>(maskAxis->getConstancy(maskOrder[0]), 1);
|
||||
}
|
||||
|
||||
llvm::Optional<AxisInfo> getAxisInfo(Value val) const {
|
||||
if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) {
|
||||
return it->getValue();
|
||||
}
|
||||
|
||||
return llvm::Optional<AxisInfo>{};
|
||||
return axisAnalysisPass.getMaskAlignment(mask);
|
||||
}
|
||||
|
||||
protected:
|
||||
AxisInfoAnalysis &AxisAnalysisPass;
|
||||
AxisInfoAnalysis &axisAnalysisPass;
|
||||
};
|
||||
|
||||
struct LoadOpConversion
|
||||
@@ -4604,30 +4568,68 @@ private:
|
||||
});
|
||||
}
|
||||
|
||||
void decomposeInsertSliceAsyncOp(ModuleOp mod,
|
||||
TritonGPUToLLVMTypeConverter &converter) {
|
||||
// cp.async is supported in Ampere and later
|
||||
if (computeCapability >= 80)
|
||||
return;
|
||||
|
||||
void decomposeInsertSliceAsyncOp(ModuleOp mod) {
|
||||
AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
|
||||
axisInfoAnalysis.run(mod);
|
||||
// TODO(Keren): This is a hacky knob that may cause performance regression
|
||||
// when decomposition has been performed. We should remove this knob once we
|
||||
// have thorough analysis on async wait. Currently, we decompose
|
||||
// `insert_slice_async` into `load` and `insert_slice` without knowing which
|
||||
// `async_wait` is responsible for the `insert_slice_async`. To guarantee
|
||||
// correctness, we blindly set the `async_wait` to wait for all async ops.
|
||||
//
|
||||
// There are two options to improve this:
|
||||
// 1. We can perform a dataflow analysis to find the `async_wait` that is
|
||||
// responsible for the `insert_slice_async` in the backend.
|
||||
// 2. We can modify the pipeline to perform the decomposition before the
|
||||
// `async_wait` is inserted. However, it is also risky because we don't know
|
||||
// the correct vectorized shape yet in the pipeline pass. Making the
|
||||
// pipeline pass aware of the vectorization could introduce additional
|
||||
// dependencies on the AxisInfoAnalysis and the Coalesce analysis.
|
||||
bool decomposed = false;
|
||||
// insert_slice_async %src, %dst, %idx, %mask, %other
|
||||
// =>
|
||||
// %tmp = load %src, %mask, %other
|
||||
// %res = insert_slice %tmp into %dst[%idx]
|
||||
mod.walk([&](triton::gpu::InsertSliceAsyncOp insertSliceAsyncOp) -> void {
|
||||
OpBuilder builder(insertSliceAsyncOp);
|
||||
// load
|
||||
auto srcTy = insertSliceAsyncOp.src().getType().cast<RankedTensorType>();
|
||||
auto dstTy = insertSliceAsyncOp.getType().cast<RankedTensorType>();
|
||||
|
||||
// Get the vectorized load size
|
||||
auto src = insertSliceAsyncOp.src();
|
||||
auto dst = insertSliceAsyncOp.dst();
|
||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||
auto srcBlocked =
|
||||
srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
|
||||
auto elemTy = converter.convertType(dstTy.getElementType());
|
||||
auto tmpTy = RankedTensorType::get(srcTy.getShape(), elemTy, srcBlocked);
|
||||
auto resSharedLayout =
|
||||
dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
|
||||
auto resElemTy = dstTy.getElementType();
|
||||
unsigned inVec = axisInfoAnalysis.getPtrVectorSize(src);
|
||||
unsigned outVec = resSharedLayout.getVec();
|
||||
unsigned minVec = std::min(outVec, inVec);
|
||||
auto maxBitWidth =
|
||||
std::max<unsigned>(128, resElemTy.getIntOrFloatBitWidth());
|
||||
auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec;
|
||||
auto bitWidth = std::min<unsigned>(maxBitWidth, vecBitWidth);
|
||||
auto byteWidth = bitWidth / 8;
|
||||
|
||||
// If the load byte width is not eligible or the current compute
|
||||
// capability does not support async copy, then we do decompose
|
||||
if (triton::gpu::InsertSliceAsyncOp::getEligibleLoadByteWidth(
|
||||
computeCapability)
|
||||
.contains(byteWidth) &&
|
||||
computeCapability >= 80)
|
||||
return;
|
||||
|
||||
// load
|
||||
auto tmpTy =
|
||||
RankedTensorType::get(srcTy.getShape(), resElemTy, srcBlocked);
|
||||
auto loadOp = builder.create<triton::LoadOp>(
|
||||
insertSliceAsyncOp.getLoc(), tmpTy, insertSliceAsyncOp.src(),
|
||||
insertSliceAsyncOp.mask(), insertSliceAsyncOp.other(),
|
||||
insertSliceAsyncOp.cache(), insertSliceAsyncOp.evict(),
|
||||
insertSliceAsyncOp.isVolatile());
|
||||
|
||||
// insert_slice
|
||||
auto axis = insertSliceAsyncOp.axis();
|
||||
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
|
||||
@@ -4645,10 +4647,20 @@ private:
|
||||
// Replace
|
||||
insertSliceAsyncOp.replaceAllUsesWith(insertSliceOp.getResult());
|
||||
insertSliceAsyncOp.erase();
|
||||
decomposed = true;
|
||||
});
|
||||
|
||||
// async wait is supported in Ampere and later
|
||||
mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
|
||||
asyncWaitOp.erase();
|
||||
if (computeCapability < 80) {
|
||||
asyncWaitOp.erase();
|
||||
} else if (decomposed) {
|
||||
OpBuilder builder(asyncWaitOp);
|
||||
// Wait for all previous async ops
|
||||
auto newAsyncWaitOp = builder.create<triton::gpu::AsyncWaitOp>(
|
||||
asyncWaitOp.getLoc(), builder.getI64IntegerAttr(0));
|
||||
asyncWaitOp.erase();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -4671,7 +4683,7 @@ public:
|
||||
|
||||
// step 1: Decompose unoptimized layout conversions to use shared memory
|
||||
// step 2: Decompose insert_slice_async to use load + insert_slice for
|
||||
// pre-Ampere architectures
|
||||
// pre-Ampere architectures or unsupported vectorized load sizes
|
||||
// step 3: Allocate shared memories and insert barriers
|
||||
// step 4: Convert SCF to CFG
|
||||
// step 5: Convert FuncOp to LLVMFuncOp via partial conversion
|
||||
@@ -4681,10 +4693,9 @@ public:
|
||||
// separation between 1/4 is that, step 3 is out of the scope of Dialect
|
||||
// Conversion, thus we need to make sure the smem is not revised during the
|
||||
// conversion of step 4.
|
||||
|
||||
decomposeBlockedToDotOperand(mod);
|
||||
|
||||
decomposeInsertSliceAsyncOp(mod, typeConverter);
|
||||
decomposeInsertSliceAsyncOp(mod);
|
||||
|
||||
Allocation allocation(mod);
|
||||
MembarAnalysis membar(&allocation);
|
||||
|
@@ -286,8 +286,8 @@ struct TritonAtomicCASPattern
|
||||
matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<triton::AtomicCASOp>(
|
||||
op, typeConverter->convertType(op.getType()),
|
||||
adaptor.ptr(), adaptor.cmp(), adaptor.val());
|
||||
op, typeConverter->convertType(op.getType()), adaptor.ptr(),
|
||||
adaptor.cmp(), adaptor.val());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@@ -241,7 +241,8 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
|
||||
auto argTy = arg.getType().cast<RankedTensorType>();
|
||||
auto argEltTy = argTy.getElementType();
|
||||
auto i32Ty = IntegerType::get(argEltTy.getContext(), 32);
|
||||
auto redOp = attributes.get("redOp").cast<mlir::triton::RedOpAttr>().getValue();
|
||||
auto redOp =
|
||||
attributes.get("redOp").cast<mlir::triton::RedOpAttr>().getValue();
|
||||
bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
|
||||
auto retEltTy = withIndex ? i32Ty : argEltTy;
|
||||
auto retShape = argTy.getShape().vec();
|
||||
|
@@ -71,7 +71,7 @@ unsigned getElemsPerThread(Type type) {
|
||||
return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape());
|
||||
}
|
||||
|
||||
SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
|
||||
SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout) {
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
return SmallVector<unsigned>(blockedLayout.getThreadsPerWarp().begin(),
|
||||
blockedLayout.getThreadsPerWarp().end());
|
||||
@@ -86,7 +86,7 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
|
||||
return {};
|
||||
}
|
||||
|
||||
SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
|
||||
SmallVector<unsigned> getWarpsPerCTA(const Attribute &layout) {
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
return SmallVector<unsigned>(blockedLayout.getWarpsPerCTA().begin(),
|
||||
blockedLayout.getWarpsPerCTA().end());
|
||||
@@ -99,7 +99,7 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
|
||||
return {};
|
||||
}
|
||||
|
||||
SmallVector<unsigned> getSizePerThread(Attribute layout) {
|
||||
SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
|
||||
blockedLayout.getSizePerThread().end());
|
||||
@@ -659,6 +659,15 @@ void printInsertSliceAsyncOp(OpAsmPrinter &printer,
|
||||
printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
|
||||
}
|
||||
|
||||
DenseSet<unsigned>
|
||||
InsertSliceAsyncOp::getEligibleLoadByteWidth(int computeCapability) {
|
||||
DenseSet<unsigned> validLoadBytes;
|
||||
if (computeCapability >= 80) {
|
||||
validLoadBytes = {4, 8, 16};
|
||||
}
|
||||
return validLoadBytes;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ASM Interface (i.e.: alias)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@@ -25,18 +25,20 @@ static Type getI1SameShape(Value v) {
|
||||
tensorType.getEncoding());
|
||||
}
|
||||
|
||||
#define int_attr(num) builder.getI64IntegerAttr(num)
|
||||
|
||||
namespace {
|
||||
|
||||
class LoopPipeliner {
|
||||
/// cache forOp we are working on
|
||||
/// Cache forOp we are working on
|
||||
scf::ForOp forOp;
|
||||
|
||||
/// cache YieldOp for this forOp
|
||||
/// Cache YieldOp for this forOp
|
||||
scf::YieldOp yieldOp;
|
||||
|
||||
/// loads to be pipelined
|
||||
/// Loads to be pipelined
|
||||
SetVector<Value> loads;
|
||||
/// the value that each load will be mapped to (after layout conversion)
|
||||
/// The value that each load will be mapped to (after layout conversion)
|
||||
DenseMap<Value, Value> loadsMapping;
|
||||
/// load => buffer
|
||||
DenseMap<Value, Value> loadsBuffer;
|
||||
@@ -51,7 +53,7 @@ class LoopPipeliner {
|
||||
///
|
||||
Value loopIterIdx;
|
||||
|
||||
/// comments on numStages:
|
||||
/// Comments on numStages:
|
||||
/// [0, numStages-1) are in the prologue
|
||||
/// numStages-1 is appended after the loop body
|
||||
int numStages;
|
||||
@@ -61,6 +63,7 @@ class LoopPipeliner {
|
||||
|
||||
/// Block arguments that loads depend on
|
||||
DenseSet<BlockArgument> depArgs;
|
||||
|
||||
/// Operations (inside the loop body) that loads depend on
|
||||
DenseSet<Operation *> depOps;
|
||||
|
||||
@@ -71,7 +74,7 @@ class LoopPipeliner {
|
||||
|
||||
Value lookupOrDefault(Value origin, int stage);
|
||||
|
||||
/// returns a empty buffer of size <numStages, ...>
|
||||
/// Returns a empty buffer of size <numStages, ...>
|
||||
ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder);
|
||||
|
||||
public:
|
||||
@@ -84,7 +87,7 @@ public:
|
||||
/// Collect loads to pipeline. Return success if we can pipeline this loop
|
||||
LogicalResult initialize();
|
||||
|
||||
/// emit pipelined loads (before loop body)
|
||||
/// Emit pipelined loads (before loop body)
|
||||
void emitPrologue();
|
||||
|
||||
/// emit pipelined loads (after loop body)
|
||||
@@ -134,7 +137,7 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
|
||||
|
||||
ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
|
||||
OpBuilder &builder) {
|
||||
// allocate a buffer for each pipelined tensor
|
||||
// Allocate a buffer for each pipelined tensor
|
||||
// shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16>
|
||||
Value convertLayout = loadsMapping[op->getResult(0)];
|
||||
if (auto tensorType = convertLayout.getType().dyn_cast<RankedTensorType>()) {
|
||||
@@ -215,9 +218,9 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
loads.insert(loadOp);
|
||||
}
|
||||
|
||||
// we have some loads to pipeline
|
||||
// We have some loads to pipeline
|
||||
if (!loads.empty()) {
|
||||
// update depArgs & depOps
|
||||
// Update depArgs & depOps
|
||||
for (Value loadOp : loads) {
|
||||
for (Value dep : loadDeps[loadOp]) {
|
||||
// TODO: we should record the stage that the value is depended on
|
||||
@@ -244,23 +247,20 @@ void LoopPipeliner::emitPrologue() {
|
||||
setValueMapping(arg, operand.get(), 0);
|
||||
}
|
||||
|
||||
// helper to construct int attribute
|
||||
auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); };
|
||||
|
||||
// prologue from [0, numStage-1)
|
||||
Value iv = forOp.getLowerBound();
|
||||
pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
|
||||
for (int stage = 0; stage < numStages - 1; ++stage) {
|
||||
// special handling for induction variable as the increment is implicit
|
||||
// Special handling for induction variable as the increment is implicit
|
||||
if (stage != 0)
|
||||
iv = builder.create<arith::AddIOp>(iv.getLoc(), iv, forOp.getStep());
|
||||
setValueMapping(forOp.getInductionVar(), iv, stage);
|
||||
|
||||
// special handling for loop condition as there is no condition in ForOp
|
||||
// Special handling for loop condition as there is no condition in ForOp
|
||||
Value loopCond = builder.create<arith::CmpIOp>(
|
||||
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
|
||||
|
||||
// rematerialize peeled values
|
||||
// Rematerialize peeled values
|
||||
SmallVector<Operation *> orderedDeps;
|
||||
for (Operation &op : forOp.getLoopBody().front()) {
|
||||
if (depOps.contains(&op))
|
||||
@@ -314,7 +314,7 @@ void LoopPipeliner::emitPrologue() {
|
||||
}
|
||||
}
|
||||
|
||||
// update mapping of results
|
||||
// Update mapping of results
|
||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
|
||||
Value originalResult = op->getResult(dstIdx);
|
||||
// copy_async will update the value of its only use
|
||||
@@ -350,13 +350,14 @@ void LoopPipeliner::emitPrologue() {
|
||||
loadsBufferType[loadOp].getEncoding());
|
||||
Value extractSlice = builder.create<tensor::ExtractSliceOp>(
|
||||
loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
|
||||
SmallVector<OpFoldResult>{intAttr(0), intAttr(0), intAttr(0)},
|
||||
SmallVector<OpFoldResult>{intAttr(1), intAttr(sliceType.getShape()[0]),
|
||||
intAttr(sliceType.getShape()[1])},
|
||||
SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
|
||||
SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
|
||||
SmallVector<OpFoldResult>{int_attr(1),
|
||||
int_attr(sliceType.getShape()[0]),
|
||||
int_attr(sliceType.getShape()[1])},
|
||||
SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
|
||||
loadsExtract[loadOp] = extractSlice;
|
||||
}
|
||||
// bump up loopIterIdx, this is used for getting the correct slice for the
|
||||
// Bump up loopIterIdx, this is used for getting the correct slice for the
|
||||
// *next* iteration
|
||||
loopIterIdx = builder.create<arith::AddIOp>(
|
||||
loopIterIdx.getLoc(), loopIterIdx,
|
||||
@@ -365,9 +366,6 @@ void LoopPipeliner::emitPrologue() {
|
||||
|
||||
void LoopPipeliner::emitEpilogue() {
|
||||
// If there's any outstanding async copies, we need to wait for them.
|
||||
// TODO(Keren): We may want to completely avoid the async copies in the last
|
||||
// few iterations by setting is_masked attribute to true. We don't want to use
|
||||
// the mask operand because it's a tensor but not a scalar.
|
||||
OpBuilder builder(forOp);
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
builder.setInsertionPointAfter(forOp);
|
||||
@@ -376,9 +374,8 @@ void LoopPipeliner::emitEpilogue() {
|
||||
|
||||
scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
OpBuilder builder(forOp);
|
||||
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
|
||||
|
||||
// order of new args:
|
||||
// Order of new args:
|
||||
// (original args),
|
||||
// (insertSliceAsync buffer at stage numStages - 1) for each load
|
||||
// (extracted tensor) for each load
|
||||
@@ -465,7 +462,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
|
||||
++argIdx;
|
||||
}
|
||||
// special handling for iv & loop condition
|
||||
// Special handling for iv & loop condition
|
||||
Value nextIV = builder.create<arith::AddIOp>(
|
||||
newForOp.getInductionVar().getLoc(),
|
||||
newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
|
||||
@@ -473,7 +470,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
|
||||
nextIV, newForOp.getUpperBound());
|
||||
|
||||
// slice index
|
||||
// Slice index
|
||||
SmallVector<Value> nextBuffers;
|
||||
SmallVector<Value> extractSlices;
|
||||
|
||||
@@ -490,7 +487,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
|
||||
for (Operation *op : orderedDeps) {
|
||||
Operation *nextOp = nullptr;
|
||||
// update loading mask
|
||||
// Update loading mask
|
||||
if (loads.contains(op->getResult(0))) {
|
||||
auto loadOp = llvm::cast<triton::LoadOp>(op);
|
||||
Value mask = loadOp.mask();
|
||||
@@ -500,7 +497,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
mask.getLoc(), mask.getType(), nextLoopCond);
|
||||
newMask = builder.create<arith::AndIOp>(
|
||||
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
|
||||
// if mask is defined outside the loop, don't update the map more than
|
||||
// If mask is defined outside the loop, don't update the map more than
|
||||
// once
|
||||
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
|
||||
nextMapping.map(mask, newMask);
|
||||
@@ -522,18 +519,19 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
loadsBufferType[loadOp].getEncoding());
|
||||
nextOp = builder.create<tensor::ExtractSliceOp>(
|
||||
op->getLoc(), sliceType, insertAsyncOp,
|
||||
SmallVector<OpFoldResult>{extractSliceIndex, intAttr(0), intAttr(0)},
|
||||
SmallVector<OpFoldResult>{intAttr(1),
|
||||
intAttr(sliceType.getShape()[0]),
|
||||
intAttr(sliceType.getShape()[1])},
|
||||
SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
|
||||
SmallVector<OpFoldResult>{extractSliceIndex, int_attr(0),
|
||||
int_attr(0)},
|
||||
SmallVector<OpFoldResult>{int_attr(1),
|
||||
int_attr(sliceType.getShape()[0]),
|
||||
int_attr(sliceType.getShape()[1])},
|
||||
SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
|
||||
extractSlices.push_back(nextOp->getResult(0));
|
||||
} else
|
||||
nextOp = builder.clone(*op, nextMapping);
|
||||
// update mapping of results
|
||||
// Update mapping of results
|
||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
|
||||
nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
|
||||
// if this is a loop-carried value, update the mapping for yield
|
||||
// If this is a loop-carried value, update the mapping for yield
|
||||
auto originYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
for (OpOperand &operand : originYield->getOpOperands()) {
|
||||
if (operand.get() == op->getResult(dstIdx)) {
|
||||
@@ -583,7 +581,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
it->getDefiningOp()->moveAfter(asyncWait);
|
||||
}
|
||||
|
||||
// bump iteration count
|
||||
// Bump iteration count
|
||||
pipelineIterIdx = builder.create<arith::AddIOp>(
|
||||
nextIV.getLoc(), pipelineIterIdx,
|
||||
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
|
||||
|
@@ -341,7 +341,7 @@ void init_triton_ir(py::module &&m) {
|
||||
return funcs[0];
|
||||
});
|
||||
|
||||
m.def("make_attr",
|
||||
m.def("make_attr",
|
||||
[](const std::vector<int> &values, mlir::MLIRContext &context) {
|
||||
return mlir::DenseIntElementsAttr::get(
|
||||
mlir::RankedTensorType::get(
|
||||
@@ -1113,7 +1113,8 @@ void init_triton_ir(py::module &&m) {
|
||||
mlir::Value &val) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
mlir::Type dstType;
|
||||
if (auto srcTensorType = ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
|
||||
if (auto srcTensorType =
|
||||
ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
|
||||
mlir::Type dstElemType = srcTensorType.getElementType()
|
||||
.cast<mlir::triton::PointerType>()
|
||||
.getPointeeType();
|
||||
|
@@ -172,8 +172,9 @@ def get_proper_err(a, b, golden):
|
||||
[128, 64, 128, 4, 128, 64, 128, False, False],
|
||||
[16, 16, 16, 16, 16, 16, 16, False, False], # wpt overflow issue
|
||||
# K-Forloop
|
||||
[32, 32, 64, 4, 32, 32, 32, False, False], # Single shared encoding
|
||||
[16, 16, 128, 4, 16, 16, 16, False, False], # Single shared encoding and small k
|
||||
#[16, 16, 64, 4, 8, 8, 8, False, False], # Wrap threads
|
||||
[32, 32, 64, 4, 32, 32, 32, False, False], # Single shared encoding
|
||||
[16, 16, 128, 4, 16, 16, 16, False, False], # Single shared encoding and small k
|
||||
[64, 32, 128, 4, 64, 32, 64, False, False],
|
||||
[128, 16, 128, 4, 128, 16, 32, False, False],
|
||||
[32, 16, 128, 4, 32, 16, 32, False, False],
|
||||
|
@@ -387,6 +387,45 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// -----
|
||||
|
||||
#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
|
||||
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
|
||||
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||
#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
|
||||
#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
|
||||
#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||
#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_insert_slice_async_fallback
|
||||
func @basic_insert_slice_async_fallback(%arg0: !tt.ptr<f16> {tt.divisibility = 1 : i32}) {
|
||||
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
|
||||
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
|
||||
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
|
||||
%off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #slice3d0>) -> tensor<1x64xi32, #block3>
|
||||
%broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x64xi32, #block2>
|
||||
%cst_scalar = arith.constant 64 : i32
|
||||
%cst = tt.splat %cst_scalar : (i32) -> tensor<16x64xi32, #block2>
|
||||
%broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x64xi32, #block2>
|
||||
%broadcast_off1_ = tt.broadcast %off1 : (tensor<1x64xi32, #block3>) -> tensor<16x64xi32, #block3>
|
||||
%broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : (tensor<16x64xi32, #block2>) -> tensor<16x64xi32, #AL>
|
||||
%broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : (tensor<16x64xi32, #block3>) -> tensor<16x64xi32, #AL>
|
||||
%off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL>
|
||||
%a_init = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<16x64x!tt.ptr<f16>, #AL>
|
||||
%a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr<f16>, #AL>
|
||||
%tensor = triton_gpu.alloc_tensor : tensor<2x16x64xf16, #A>
|
||||
%index = arith.constant 1 : i32
|
||||
|
||||
// CHECK: llvm.load
|
||||
// CHECK-SAME: !llvm.ptr<vector<8xi32>, 3>
|
||||
// CHECK: llvm.load
|
||||
// CHECK-SAME: !llvm.ptr<vector<8xi32>, 3>
|
||||
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f16>, #AL> -> tensor<2x16x64xf16, #A>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0]}>
|
||||
#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
|
||||
#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||
|
Reference in New Issue
Block a user