[Triton-MLIR][BACKEND] Refactor decompose insert_slice_async (#929)
1. Improve the pipeline pass's comments.
2. Decompose `insert_slice_async` when the load vector size is not supported.
3. Add a test that could fail our GEMM code.

Copying my review comments here: there's a knob that may cause a performance regression when decomposition has been performed. We should remove this knob once we have a thorough analysis of async wait. Currently, we decompose `insert_slice_async` into `load` and `insert_slice` without knowing which `async_wait` is responsible for the `insert_slice_async`. To guarantee correctness, we blindly set the `async_wait` to wait for all async ops if any `insert_slice_async` has been decomposed. There are two options to improve this:

1. Perform a dataflow analysis in the backend to find the `async_wait` that is responsible for each `insert_slice_async`.
2. Modify the pipeline to perform the decomposition before the `async_wait` is inserted. However, this is also risky because the correct vectorized shape is not yet known in the pipeline pass, and making the pipeline pass aware of the vectorization could introduce additional dependencies on the AxisInfoAnalysis and the Coalesce analysis.
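To make the rewrite concrete, here is a schematic before/after in the pseudo-IR notation the pass comment itself uses (`%src`, `%dst`, `%idx`, `%mask`, `%other` are illustrative operand names, and the `num` values are hypothetical):

```
// Before: asynchronous copy into a slice of shared memory; some later
// async_wait completes a group of pending copies.
%res = insert_slice_async %src, %dst, %idx, %mask, %other
...
async_wait {num = 4}

// After decomposition (unsupported load vector size, or pre-Ampere target):
%tmp = load %src, %mask, %other
%res = insert_slice %tmp into %dst[%idx]
...
async_wait {num = 0}  // conservatively wait for ALL outstanding async ops
```

Setting `num = 0` matches the semantics of PTX `cp.async.wait_group 0`, which blocks until no async copy groups remain pending; this is the conservative knob described above.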
@@ -822,7 +822,7 @@ struct ArithConstantSplatOpConversion
 // Contains some helper functions for both Load and Store conversions.
 struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
   explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
-      : AxisAnalysisPass(axisAnalysisPass) {}
+      : axisAnalysisPass(axisAnalysisPass) {}
 
   // Get corresponding LLVM element values of \param value.
   static SmallVector<Value> getLLVMElems(Value value, Value llValue,
@@ -838,51 +838,15 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
   }
 
   unsigned getVectorSize(Value ptr) const {
-    auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
-    if (!tensorTy)
-      return 1;
-    auto layout = tensorTy.getEncoding();
-    auto shape = tensorTy.getShape();
-
-    auto axisInfo = getAxisInfo(ptr);
-    // Here order should be ordered by contiguous first, so the first element
-    // should have the largest contiguous.
-    auto order = getOrder(layout);
-    unsigned align = getAlignment(ptr, layout);
-
-    unsigned contigPerThread = getSizePerThread(layout)[order[0]];
-    unsigned vec = std::min(align, contigPerThread);
-    vec = std::min<unsigned>(shape[order[0]], vec);
-
-    return vec;
-  }
-
-  unsigned getAlignment(Value val, const Attribute &layout) const {
-    auto axisInfo = getAxisInfo(val);
-    auto order = getOrder(layout);
-    unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
-    unsigned maxContig = axisInfo->getContiguity(order[0]);
-    unsigned alignment = std::min(maxMultiple, maxContig);
-    return alignment;
+    return axisAnalysisPass.getPtrVectorSize(ptr);
   }
 
   unsigned getMaskAlignment(Value mask) const {
-    auto tensorTy = mask.getType().cast<RankedTensorType>();
-    auto maskOrder = getOrder(tensorTy.getEncoding());
-    auto maskAxis = getAxisInfo(mask);
-    return std::max<int>(maskAxis->getConstancy(maskOrder[0]), 1);
-  }
-
-  llvm::Optional<AxisInfo> getAxisInfo(Value val) const {
-    if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) {
-      return it->getValue();
-    }
-
-    return llvm::Optional<AxisInfo>{};
+    return axisAnalysisPass.getMaskAlignment(mask);
   }
 
 protected:
-  AxisInfoAnalysis &AxisAnalysisPass;
+  AxisInfoAnalysis &axisAnalysisPass;
 };
 
 struct LoadOpConversion
@@ -4604,30 +4568,68 @@ private:
     });
   }
 
-  void decomposeInsertSliceAsyncOp(ModuleOp mod,
-                                   TritonGPUToLLVMTypeConverter &converter) {
-    // cp.async is supported in Ampere and later
-    if (computeCapability >= 80)
-      return;
+  void decomposeInsertSliceAsyncOp(ModuleOp mod) {
+    AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
+    axisInfoAnalysis.run(mod);
+    // TODO(Keren): This is a hacky knob that may cause performance regression
+    // when decomposition has been performed. We should remove this knob once we
+    // have thorough analysis on async wait. Currently, we decompose
+    // `insert_slice_async` into `load` and `insert_slice` without knowing which
+    // `async_wait` is responsible for the `insert_slice_async`. To guarantee
+    // correctness, we blindly set the `async_wait` to wait for all async ops.
+    //
+    // There are two options to improve this:
+    // 1. We can perform a dataflow analysis to find the `async_wait` that is
+    // responsible for the `insert_slice_async` in the backend.
+    // 2. We can modify the pipeline to perform the decomposition before the
+    // `async_wait` is inserted. However, it is also risky because we don't know
+    // the correct vectorized shape yet in the pipeline pass. Making the
+    // pipeline pass aware of the vectorization could introduce additional
+    // dependencies on the AxisInfoAnalysis and the Coalesce analysis.
+    bool decomposed = false;
     // insert_slice_async %src, %dst, %idx, %mask, %other
     // =>
     // %tmp = load %src, %mask, %other
     // %res = insert_slice %tmp into %dst[%idx]
     mod.walk([&](triton::gpu::InsertSliceAsyncOp insertSliceAsyncOp) -> void {
       OpBuilder builder(insertSliceAsyncOp);
-      // load
-      auto srcTy = insertSliceAsyncOp.src().getType().cast<RankedTensorType>();
-      auto dstTy = insertSliceAsyncOp.getType().cast<RankedTensorType>();
+
+      // Get the vectorized load size
+      auto src = insertSliceAsyncOp.src();
+      auto dst = insertSliceAsyncOp.dst();
+      auto srcTy = src.getType().cast<RankedTensorType>();
+      auto dstTy = dst.getType().cast<RankedTensorType>();
       auto srcBlocked =
           srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
-      auto elemTy = converter.convertType(dstTy.getElementType());
-      auto tmpTy = RankedTensorType::get(srcTy.getShape(), elemTy, srcBlocked);
+      auto resSharedLayout =
+          dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
+      auto resElemTy = dstTy.getElementType();
+      unsigned inVec = axisInfoAnalysis.getPtrVectorSize(src);
+      unsigned outVec = resSharedLayout.getVec();
+      unsigned minVec = std::min(outVec, inVec);
+      auto maxBitWidth =
+          std::max<unsigned>(128, resElemTy.getIntOrFloatBitWidth());
+      auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec;
+      auto bitWidth = std::min<unsigned>(maxBitWidth, vecBitWidth);
+      auto byteWidth = bitWidth / 8;
+
+      // If the load byte width is not eligible or the current compute
+      // capability does not support async copy, then we do decompose
+      if (triton::gpu::InsertSliceAsyncOp::getEligibleLoadByteWidth(
+              computeCapability)
+              .contains(byteWidth) &&
+          computeCapability >= 80)
+        return;
+
+      // load
+      auto tmpTy =
+          RankedTensorType::get(srcTy.getShape(), resElemTy, srcBlocked);
       auto loadOp = builder.create<triton::LoadOp>(
           insertSliceAsyncOp.getLoc(), tmpTy, insertSliceAsyncOp.src(),
           insertSliceAsyncOp.mask(), insertSliceAsyncOp.other(),
           insertSliceAsyncOp.cache(), insertSliceAsyncOp.evict(),
           insertSliceAsyncOp.isVolatile());
 
       // insert_slice
       auto axis = insertSliceAsyncOp.axis();
       auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
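As a worked example of the byte-width arithmetic in the hunk above, here is a minimal standalone C++ sketch; the values are hypothetical, and it assumes the eligible `cp.async` load widths reported by `getEligibleLoadByteWidth` on sm_80 are 4, 8, and 16 bytes:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  unsigned elemBits = 16;      // f16 destination element type
  unsigned inVec = 1;          // pointer vector size from AxisInfoAnalysis
  unsigned outVec = 8;         // vector width of the shared-memory layout
  unsigned minVec = std::min(outVec, inVec);                    // 1
  unsigned maxBitWidth = std::max(128u, elemBits);              // 128
  unsigned vecBitWidth = elemBits * minVec;                     // 16
  unsigned byteWidth = std::min(maxBitWidth, vecBitWidth) / 8;  // 2
  // 2 bytes is not an eligible cp.async width (4/8/16), so this
  // insert_slice_async would be decomposed into load + insert_slice.
  // With inVec = 8, byteWidth would be 16 and the op would be kept.
  std::printf("byteWidth = %u\n", byteWidth);
  return 0;
}
```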
@@ -4645,10 +4647,20 @@ private:
       // Replace
       insertSliceAsyncOp.replaceAllUsesWith(insertSliceOp.getResult());
       insertSliceAsyncOp.erase();
+      decomposed = true;
     });
 
     // async wait is supported in Ampere and later
     mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
-      asyncWaitOp.erase();
+      if (computeCapability < 80) {
+        asyncWaitOp.erase();
+      } else if (decomposed) {
+        OpBuilder builder(asyncWaitOp);
+        // Wait for all previous async ops
+        auto newAsyncWaitOp = builder.create<triton::gpu::AsyncWaitOp>(
+            asyncWaitOp.getLoc(), builder.getI64IntegerAttr(0));
+        asyncWaitOp.erase();
+      }
     });
   }
@@ -4671,7 +4683,7 @@ public:
 
   // step 1: Decompose unoptimized layout conversions to use shared memory
   // step 2: Decompose insert_slice_async to use load + insert_slice for
-  //         pre-Ampere architectures
+  //         pre-Ampere architectures or unsupported vectorized load sizes
   // step 3: Allocate shared memories and insert barriers
   // step 4: Convert SCF to CFG
   // step 5: Convert FuncOp to LLVMFuncOp via partial conversion
@@ -4681,10 +4693,9 @@ public:
   // separation between 1/4 is that, step 3 is out of the scope of Dialect
   // Conversion, thus we need to make sure the smem is not revised during the
   // conversion of step 4.
 
   decomposeBlockedToDotOperand(mod);
-
-  decomposeInsertSliceAsyncOp(mod, typeConverter);
+  decomposeInsertSliceAsyncOp(mod);
 
   Allocation allocation(mod);
   MembarAnalysis membar(&allocation);
@@ -286,8 +286,8 @@ struct TritonAtomicCASPattern
   matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<triton::AtomicCASOp>(
-        op, typeConverter->convertType(op.getType()),
-        adaptor.ptr(), adaptor.cmp(), adaptor.val());
+        op, typeConverter->convertType(op.getType()), adaptor.ptr(),
+        adaptor.cmp(), adaptor.val());
     return success();
   }
 };