[Triton-MLIR][BACKEND] Fix the membar pass to add missing barriers caused by scf.for (#933)

1. Add the missing barriers and revert the previous temporary workaround.
2. Extract the `run` method from the membar analysis, since the analysis
should have two phases: construction, which doesn't modify any IR, and
modification, which inserts barrier ops. Hopefully this makes the use of
membar clearer; a short sketch of the resulting interface follows the file
summary below.
Author: Keren Zhou
Date: 2022-12-01 11:54:18 -08:00
Committed by: GitHub
Parent: 9def1bcebf
Commit: c280ebda1b
8 changed files with 170 additions and 102 deletions
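
For reference, a minimal sketch of the interface the refactor implies, based only on the call sites in this diff (the real declaration lives in Triton's analysis headers; the private member shown is an assumption):

    // Sketch only, not the verbatim Triton header.
    class Allocation; // Triton's shared-memory allocation analysis

    class MembarAnalysis {
    public:
      // Phase 1 (construction): capture shared-memory allocation info;
      // no IR is modified in this phase.
      explicit MembarAnalysis(Allocation *allocation)
          : allocation(allocation) {}

      // Phase 2 (modification): walk the module, including scf.for bodies,
      // and insert barrier ops where shared-memory hazards require them.
      void run();

    private:
      Allocation *allocation; // assumed member; stored for use by run()
    };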

@@ -644,7 +644,6 @@ public:
     return multiDimIdx;
   }
-
   struct SmallVectorKeyInfo {
     static unsigned getHashValue(const SmallVector<unsigned> &key) {
       return llvm::hash_combine_range(key.begin(), key.end());
@@ -672,7 +671,7 @@ public:
         emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape));
     unsigned numIndices = parentIndices.size();
     SmallVector<SmallVector<Value>> resultIndices;
-    for (unsigned i = 0; i < numIndices; ++i){
+    for (unsigned i = 0; i < numIndices; ++i) {
       SmallVector<Value> indices = parentIndices[i];
       indices.erase(indices.begin() + dim);
       resultIndices.push_back(indices);
@@ -1203,14 +1202,14 @@ struct BroadcastOpConversion
     auto resultOffsets = emitOffsetForLayout(resultLayout, resultShape);
     SmallVector<Value> srcVals = getElementsFromStruct(loc, src, rewriter);
     DenseMap<SmallVector<unsigned>, Value, SmallVectorKeyInfo> srcValues;
-    for(size_t i = 0; i < srcOffsets.size(); i++){
+    for (size_t i = 0; i < srcOffsets.size(); i++) {
       srcValues[srcOffsets[i]] = srcVals[i];
     }
     SmallVector<Value> resultVals;
-    for(size_t i = 0; i < resultOffsets.size(); i++) {
+    for (size_t i = 0; i < resultOffsets.size(); i++) {
       auto offset = resultOffsets[i];
-      for(size_t j = 0; j < srcShape.size(); j++)
-        if(srcShape[j]==1)
+      for (size_t j = 0; j < srcShape.size(); j++)
+        if (srcShape[j] == 1)
           offset[j] = 0;
       resultVals.push_back(srcValues.lookup(offset));
     }
@@ -1940,8 +1939,8 @@ struct MakeRangeOpConversion
     unsigned elems = idxs.size();
     SmallVector<Value> retVals(elems);
     // TODO: slice layout has more elements than expected.
-    // Unexpected behavior for make range, but genereally ok when followed by expand dims + broadcast.
-    // very weird behavior otherwise potentially.
+    // Unexpected behavior for make range, but genereally ok when followed by
+    // expand dims + broadcast. very weird behavior otherwise potentially.
     for (const auto multiDim : llvm::enumerate(idxs)) {
       assert(multiDim.value().size() == 1);
       retVals[multiDim.index()] = add(multiDim.value()[0], start);
@@ -2647,13 +2646,13 @@ public:
     }
     // dot_op<opIdx=0, parent=#mma> = #mma
     // when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
-    if(srcLayout.isa<MmaEncodingAttr>() &&
+    if (srcLayout.isa<MmaEncodingAttr>() &&
        dstLayout.isa<DotOperandEncodingAttr>()) {
       auto srcMmaLayout = srcLayout.cast<MmaEncodingAttr>();
       auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
-      if(srcMmaLayout.getWarpsPerCTA()[1] == 1 &&
-         dstDotLayout.getOpIdx() == 0 &&
-         dstDotLayout.getParent() == srcMmaLayout) {
+      if (srcMmaLayout.getWarpsPerCTA()[1] == 1 &&
+          dstDotLayout.getOpIdx() == 0 &&
+          dstDotLayout.getParent() == srcMmaLayout) {
@@ -2662,35 +2661,37 @@ public:
             this->getTypeConverter()->convertType(srcTy.getElementType());
        // for the destination type, we need to pack values together
        // so they can be consumed by tensor core operations
-       unsigned vecSize = std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
+       unsigned vecSize =
+           std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
        Type vecTy = vec_ty(elemTy, vecSize);
-       SmallVector<Type> types(elems/vecSize, vecTy);
+       SmallVector<Type> types(elems / vecSize, vecTy);
        SmallVector<Value> vecVals;
-       for(unsigned i = 0; i < elems; i += vecSize) {
+       for (unsigned i = 0; i < elems; i += vecSize) {
          Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
-         for(unsigned j = 0; j < vecSize; j++)
-           packed = insert_element(vecTy, packed, vals[i+j], i32_val(j));
+         for (unsigned j = 0; j < vecSize; j++)
+           packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
          vecVals.push_back(packed);
        }
        // This needs to be ordered the same way that
        // ldmatrix.x4 would order it
        // TODO: this needs to be refactor so we don't
        // implicitly depends on how emitOffsetsForMMAV2
        // is implemented
        SmallVector<Value> reorderedVals;
-       for(unsigned i = 0; i < vecVals.size(); i += 4) {
+       for (unsigned i = 0; i < vecVals.size(); i += 4) {
          reorderedVals.push_back(vecVals[i]);
-         reorderedVals.push_back(vecVals[i+2]);
-         reorderedVals.push_back(vecVals[i+1]);
-         reorderedVals.push_back(vecVals[i+3]);
+         reorderedVals.push_back(vecVals[i + 2]);
+         reorderedVals.push_back(vecVals[i + 1]);
+         reorderedVals.push_back(vecVals[i + 3]);
        }
        // return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
-       Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
-       Value view = getStructFromElements(loc, reorderedVals, rewriter, structTy);
+       Type structTy =
+           LLVM::LLVMStructType::getLiteral(this->getContext(), types);
+       Value view =
+           getStructFromElements(loc, reorderedVals, rewriter, structTy);
        rewriter.replaceOp(op, view);
        return success();
      }
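
For intuition on the reordering loop in the hunk above: within each group of four packed vectors it swaps the middle two, i.e. [a, b, c, d] becomes [a, c, b, d], which (per the in-code comment) matches how ldmatrix.x4 orders its results. A tiny standalone illustration of just the permutation:

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int> vecVals = {0, 1, 2, 3, 4, 5, 6, 7};
      std::vector<int> reorderedVals;
      // Same loop shape as the diff: permute each group of four.
      for (size_t i = 0; i < vecVals.size(); i += 4) {
        reorderedVals.push_back(vecVals[i]);
        reorderedVals.push_back(vecVals[i + 2]);
        reorderedVals.push_back(vecVals[i + 1]);
        reorderedVals.push_back(vecVals[i + 3]);
      }
      for (int v : reorderedVals)
        std::printf("%d ", v); // prints: 0 2 1 3 4 6 5 7
      return 0;
    }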
@@ -3138,10 +3139,6 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
     ConversionPatternRewriter &rewriter) const {
   auto loc = op.getLoc();
-  // TODO[Keren]: A temporary workaround for an issue from membar pass.
-  // https://triton-lang.slack.com/archives/C042VBSQWNS/p1669796615860699?thread_ts=1669779203.526739&cid=C042VBSQWNS
-  barrier();
   Value src = op.src();
   Value dst = op.result();
   auto srcTy = src.getType().cast<RankedTensorType>();
@@ -4698,7 +4695,8 @@ public:
     decomposeInsertSliceAsyncOp(mod);
     Allocation allocation(mod);
-    MembarAnalysis membar(&allocation);
+    MembarAnalysis membarPass(&allocation);
+    membarPass.run();
     RewritePatternSet scf_patterns(context);
     mlir::populateLoopToStdConversionPatterns(scf_patterns);
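
Taken together, the last hunk leaves the lowering pass with the following call sequence (a condensed sketch; `mod` and `context` are set up earlier in the pass as in the full file):

    decomposeInsertSliceAsyncOp(mod);
    Allocation allocation(mod);              // shared-memory offsets
    MembarAnalysis membarPass(&allocation);  // phase 1: no IR changes
    membarPass.run();                        // phase 2: barriers inserted,
                                             // including those needed
                                             // around scf.for
    RewritePatternSet scf_patterns(context);
    mlir::populateLoopToStdConversionPatterns(scf_patterns);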