[Triton-MLIR][BACKEND] Fix the membar pass to add missing barriers caused by scf.for (#933)
1. Add the missing barriers and revert the previous temporary workaround.
2. Extract the `run` method from the membar analysis: the analysis now has two explicit phases, construction, which does not modify any IR, and modification, which inserts the barrier ops. This should make the use of membar clearer.
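With this split, callers construct the analysis first and only mutate the IR when they call `run()`. A minimal sketch of the intended usage follows; the constructor and `run()` calls mirror this diff, while the header paths, namespace qualification, and the enclosing helper function are illustrative assumptions, not part of the commit.

// Sketch only: include paths and namespaces below are assumed, not taken from this commit.
#include "triton/Analysis/Allocation.h" // assumed header location
#include "triton/Analysis/Membar.h"     // assumed header location

#include "mlir/IR/Operation.h"

// Phase 1 (construction): no IR is touched.
// Phase 2 (modification): run() inserts gpu.barrier ops where needed.
static void insertBarriers(mlir::Operation *root) {
  mlir::Allocation allocation(root);            // shared-memory allocation info
  mlir::MembarAnalysis membarPass(&allocation); // construction only
  membarPass.run();                             // barriers are added here
}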
@@ -29,7 +29,11 @@ public:
   /// The following circumstances are not considered yet:
   /// - Double buffers
   /// - N buffers
-  MembarAnalysis(Allocation *allocation) : allocation(allocation) { run(); }
+  MembarAnalysis(Allocation *allocation) : allocation(allocation) {}
 
+  /// Runs the membar analysis to the given operation, inserts a barrier if
+  /// necessary.
+  void run();
+
 private:
   struct RegionInfo {
@@ -82,10 +86,6 @@ private:
     }
   };
 
-  /// Runs the membar analysis to the given operation, inserts a barrier if
-  /// necessary.
-  void run();
-
   /// Applies the barrier analysis based on the SCF dialect, in which each
   /// region has a single basic block only.
   /// Example:
@@ -24,21 +24,43 @@ void MembarAnalysis::dfsOperation(Operation *operation,
     // scf.if only: two regions
     // scf.for: one region
     RegionInfo curRegionInfo;
-    for (auto &region : operation->getRegions()) {
-      // Copy the parent info as the current info.
-      RegionInfo regionInfo = *parentRegionInfo;
-      for (auto &block : region.getBlocks()) {
-        assert(region.getBlocks().size() == 1 &&
-               "Multiple blocks in a region is not supported");
-        for (auto &op : block.getOperations()) {
-          // Traverse the nested operation.
-          dfsOperation(&op, &regionInfo, builder);
-        }
-      }
-      curRegionInfo.join(regionInfo);
-    }
-    // Set the parent region info as the union of the nested region info.
-    *parentRegionInfo = curRegionInfo;
+    auto traverseRegions = [&]() -> auto{
+      for (auto &region : operation->getRegions()) {
+        // Copy the parent info as the current info.
+        RegionInfo regionInfo = *parentRegionInfo;
+        for (auto &block : region.getBlocks()) {
+          assert(region.getBlocks().size() == 1 &&
+                 "Multiple blocks in a region is not supported");
+          for (auto &op : block.getOperations()) {
+            // Traverse the nested operation.
+            dfsOperation(&op, &regionInfo, builder);
+          }
+        }
+        curRegionInfo.join(regionInfo);
+      }
+      // Set the parent region info as the union of the nested region info.
+      *parentRegionInfo = curRegionInfo;
+    };
+
+    traverseRegions();
+    if (isa<scf::ForOp>(operation)) {
+      // scf.for can have two possible inputs: the init value and the
+      // previous iteration's result. Although we've applied alias analysis,
+      // there could be unsynced memory accesses on reused memories.
+      // For example, consider the following code:
+      // %1 = convert_layout %0: blocked -> shared
+      // ...
+      // gpu.barrier
+      // ...
+      // %5 = convert_layout %4 : shared -> dot
+      // %6 = tt.dot %2, %5
+      // scf.yield
+      //
+      // Though %5 could be released before scf.yield, it may shared the same
+      // memory with %1. So we actually have to insert a barrier before %1 to
+      // make sure the memory is synced.
+      traverseRegions();
+    }
   }
 }
 
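The extra `traverseRegions()` call for `scf.for` is what catches loop-carried hazards: the region state produced by the first walk is fed back into a second walk, so a write at the top of the body is checked against accesses still pending at the bottom of the previous iteration. The toy program below is not Triton code; the conflict rules and data structures are deliberately simplified assumptions, sketched only to show why a single walk over the body is not enough.

// Standalone illustration, not Triton's RegionInfo/transfer implementation.
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct Access {
  bool isWrite;
  std::string buffer; // name of the shared-memory buffer touched
};

// Rough analogue of RegionInfo: accesses not yet covered by a barrier.
struct State {
  std::set<std::string> pendingReads, pendingWrites;
};

// Rough analogue of transfer(): report a barrier when the incoming access
// conflicts with pending ones, then record the access.
void transfer(State &s, const Access &a, int pass, size_t idx) {
  bool conflict = s.pendingWrites.count(a.buffer) ||
                  (a.isWrite && s.pendingReads.count(a.buffer));
  if (conflict) {
    std::cout << "pass " << pass << ": barrier before op " << idx << "\n";
    s.pendingReads.clear();
    s.pendingWrites.clear();
  }
  (a.isWrite ? s.pendingWrites : s.pendingReads).insert(a.buffer);
}

int main() {
  // Loop body: write a shared buffer (blocked -> shared convert), then read
  // memory that aliases it (shared -> dot convert).
  std::vector<Access> body = {{true, "buf"}, {false, "buf"}};

  State s; // state entering the loop
  for (int pass = 1; pass <= 2; ++pass)
    for (size_t i = 0; i < body.size(); ++i)
      transfer(s, body[i], pass, i);
  // Pass 1 only finds the barrier before the read inside one iteration.
  // Pass 2 re-walks the body with the state carried out of pass 1 and also
  // reports a write-after-read conflict before op 0, i.e. the barrier that
  // was missing at the top of the next scf.for iteration.
  return 0;
}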
@@ -49,8 +71,7 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
     // Do not insert barriers before control flow operations and
     // alloc/extract/insert
     // alloc is an allocation op without memory write.
-    // In contrast, arith.constant is an allocation op with memory write.
-    // FIXME(Keren): extract is always alias for now
+    // FIXME(Keren): extract_slice is always alias for now
     return;
   }
 
@@ -644,7 +644,6 @@ public:
     return multiDimIdx;
   }
 
-
   struct SmallVectorKeyInfo {
     static unsigned getHashValue(const SmallVector<unsigned> &key) {
       return llvm::hash_combine_range(key.begin(), key.end());
@@ -672,7 +671,7 @@ public:
         emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape));
     unsigned numIndices = parentIndices.size();
     SmallVector<SmallVector<Value>> resultIndices;
-    for (unsigned i = 0; i < numIndices; ++i){
+    for (unsigned i = 0; i < numIndices; ++i) {
       SmallVector<Value> indices = parentIndices[i];
       indices.erase(indices.begin() + dim);
       resultIndices.push_back(indices);
@@ -1203,14 +1202,14 @@ struct BroadcastOpConversion
     auto resultOffsets = emitOffsetForLayout(resultLayout, resultShape);
     SmallVector<Value> srcVals = getElementsFromStruct(loc, src, rewriter);
     DenseMap<SmallVector<unsigned>, Value, SmallVectorKeyInfo> srcValues;
-    for(size_t i = 0; i < srcOffsets.size(); i++){
+    for (size_t i = 0; i < srcOffsets.size(); i++) {
       srcValues[srcOffsets[i]] = srcVals[i];
     }
     SmallVector<Value> resultVals;
-    for(size_t i = 0; i < resultOffsets.size(); i++) {
+    for (size_t i = 0; i < resultOffsets.size(); i++) {
       auto offset = resultOffsets[i];
-      for(size_t j = 0; j < srcShape.size(); j++)
-        if(srcShape[j]==1)
+      for (size_t j = 0; j < srcShape.size(); j++)
+        if (srcShape[j] == 1)
           offset[j] = 0;
       resultVals.push_back(srcValues.lookup(offset));
     }
@@ -1940,8 +1939,8 @@ struct MakeRangeOpConversion
     unsigned elems = idxs.size();
     SmallVector<Value> retVals(elems);
     // TODO: slice layout has more elements than expected.
-    // Unexpected behavior for make range, but genereally ok when followed by expand dims + broadcast.
-    // very weird behavior otherwise potentially.
+    // Unexpected behavior for make range, but genereally ok when followed by
+    // expand dims + broadcast. very weird behavior otherwise potentially.
     for (const auto multiDim : llvm::enumerate(idxs)) {
       assert(multiDim.value().size() == 1);
       retVals[multiDim.index()] = add(multiDim.value()[0], start);
@@ -2647,13 +2646,13 @@ public:
     }
     // dot_op<opIdx=0, parent=#mma> = #mma
     // when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
-    if(srcLayout.isa<MmaEncodingAttr>() &&
+    if (srcLayout.isa<MmaEncodingAttr>() &&
        dstLayout.isa<DotOperandEncodingAttr>()) {
       auto srcMmaLayout = srcLayout.cast<MmaEncodingAttr>();
       auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
-      if(srcMmaLayout.getWarpsPerCTA()[1] == 1 &&
+      if (srcMmaLayout.getWarpsPerCTA()[1] == 1 &&
          dstDotLayout.getOpIdx() == 0 &&
          dstDotLayout.getParent() == srcMmaLayout) {
         // get source values
         Location loc = op->getLoc();
         auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
@@ -2662,14 +2661,15 @@ public:
             this->getTypeConverter()->convertType(srcTy.getElementType());
         // for the destination type, we need to pack values together
         // so they can be consumed by tensor core operations
-        unsigned vecSize = std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
+        unsigned vecSize =
+            std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
         Type vecTy = vec_ty(elemTy, vecSize);
-        SmallVector<Type> types(elems/vecSize, vecTy);
+        SmallVector<Type> types(elems / vecSize, vecTy);
         SmallVector<Value> vecVals;
-        for(unsigned i = 0; i < elems; i += vecSize) {
+        for (unsigned i = 0; i < elems; i += vecSize) {
           Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
-          for(unsigned j = 0; j < vecSize; j++)
-            packed = insert_element(vecTy, packed, vals[i+j], i32_val(j));
+          for (unsigned j = 0; j < vecSize; j++)
+            packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
           vecVals.push_back(packed);
         }
 
@@ -2679,18 +2679,19 @@ public:
         // implicitly depends on how emitOffsetsForMMAV2
         // is implemented
         SmallVector<Value> reorderedVals;
-        for(unsigned i = 0; i < vecVals.size(); i += 4) {
+        for (unsigned i = 0; i < vecVals.size(); i += 4) {
           reorderedVals.push_back(vecVals[i]);
-          reorderedVals.push_back(vecVals[i+2]);
-          reorderedVals.push_back(vecVals[i+1]);
-          reorderedVals.push_back(vecVals[i+3]);
+          reorderedVals.push_back(vecVals[i + 2]);
+          reorderedVals.push_back(vecVals[i + 1]);
+          reorderedVals.push_back(vecVals[i + 3]);
         }
 
         // return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
 
-        Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
-        Value view = getStructFromElements(loc, reorderedVals, rewriter, structTy);
+        Type structTy =
+            LLVM::LLVMStructType::getLiteral(this->getContext(), types);
+        Value view =
+            getStructFromElements(loc, reorderedVals, rewriter, structTy);
         rewriter.replaceOp(op, view);
         return success();
       }
@@ -3138,10 +3139,6 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
     ConversionPatternRewriter &rewriter) const {
   auto loc = op.getLoc();
 
-  // TODO[Keren]: A temporary workaround for an issue from membar pass.
-  // https://triton-lang.slack.com/archives/C042VBSQWNS/p1669796615860699?thread_ts=1669779203.526739&cid=C042VBSQWNS
-  barrier();
-
   Value src = op.src();
   Value dst = op.result();
   auto srcTy = src.getType().cast<RankedTensorType>();
@@ -4698,7 +4695,8 @@ public:
     decomposeInsertSliceAsyncOp(mod);
 
     Allocation allocation(mod);
-    MembarAnalysis membar(&allocation);
+    MembarAnalysis membarPass(&allocation);
+    membarPass.run();
 
     RewritePatternSet scf_patterns(context);
     mlir::populateLoopToStdConversionPatterns(scf_patterns);
@@ -50,22 +50,25 @@ public:
     auto dstType = convert.getType().cast<RankedTensorType>();
     if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
        dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
-      auto dstDotOperand = dstType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
+      auto dstDotOperand =
+          dstType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
       auto dstParent = dstDotOperand.getParent();
-      if(dstDotOperand.getOpIdx()==1 ||
+      if (dstDotOperand.getOpIdx() == 1 ||
          !dstParent.isa<triton::gpu::MmaEncodingAttr>())
         return mlir::failure();
       auto dstParentMma = dstParent.cast<triton::gpu::MmaEncodingAttr>();
-      if(dstParentMma.getVersion() == 1 ||
+      if (dstParentMma.getVersion() == 1 ||
          dstParentMma.getWarpsPerCTA()[1] > 1)
         return mlir::failure();
-      SetVector<Operation*> bwdSlices;
+      SetVector<Operation *> bwdSlices;
       mlir::getBackwardSlice(convert.getResult(), &bwdSlices);
-      if(llvm::find_if(bwdSlices, [](Operation *op) { return isa<triton::DotOp>(op); }) == bwdSlices.end())
+      if (llvm::find_if(bwdSlices, [](Operation *op) {
+            return isa<triton::DotOp>(op);
+          }) == bwdSlices.end())
         return mlir::failure();
 
-      auto tmpType =
-          RankedTensorType::get(dstType.getShape(), dstType.getElementType(), dstParentMma);
+      auto tmpType = RankedTensorType::get(
+          dstType.getShape(), dstType.getElementType(), dstParentMma);
       auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
           convert.getLoc(), tmpType, convert.getOperand());
       auto newConvert = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -601,10 +604,9 @@ mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
   }
 }
 
-
 SmallVector<unsigned, 2> warpsPerTileV1(triton::DotOp dotOp,
                                         const ArrayRef<int64_t> shape,
                                         int numWarps) {
   SmallVector<unsigned, 2> ret = {1, 1};
   SmallVector<int64_t, 2> shapePerWarp =
       mmaVersionToShapePerWarp(1, shape, numWarps);
@@ -624,35 +626,37 @@ SmallVector<unsigned, 2> warpsPerTileV1(triton::DotOp dotOp,
   }
 }
 
 SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
                                         const ArrayRef<int64_t> shape,
                                         int numWarps) {
-  SetVector<Operation*> slices;
+  SetVector<Operation *> slices;
   mlir::getForwardSlice(dotOp.getResult(), &slices);
-  if(llvm::find_if(slices, [](Operation *op) { return isa<triton::DotOp>(op); }) != slices.end())
-    return {(unsigned)numWarps, 1};
+  if (llvm::find_if(slices, [](Operation *op) {
+        return isa<triton::DotOp>(op);
+      }) != slices.end())
+    return {(unsigned)numWarps, 1};
 
   SmallVector<unsigned, 2> ret = {1, 1};
   SmallVector<int64_t, 2> shapePerWarp = {16, 8};
   bool changed = false;
   // TODO (@daadaada): double-check.
   // original logic in
   // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
   // seems buggy for shape = [32, 16] ?
   do {
     changed = false;
     if (ret[0] * ret[1] >= numWarps)
       break;
     if (shape[0] / shapePerWarp[0] / ret[0] >=
         shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
       if (ret[0] < shape[0] / shapePerWarp[0]) {
         ret[0] *= 2;
       } else
         ret[1] *= 2;
     } else {
       ret[1] *= 2;
     }
   } while (true);
   return ret;
 }
 
 } // namespace
@@ -133,7 +133,7 @@ LogicalResult Prefetcher::initialize() {
 
   // TODO: segfault (original for still has uses)
   // when used in flash attention that has 2 dots in the loop
-  if(dotsInFor.size() > 1)
+  if (dotsInFor.size() > 1)
     return failure();
 
   // returns source of cvt
@@ -1251,13 +1251,12 @@ void init_triton_ir(py::module &&m) {
                              llvm::StringRef(prefix)),
                          values);
            })
       // Undef
       .def("create_undef",
            [](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
              auto loc = self.getUnknownLoc();
              return self.create<::mlir::LLVM::UndefOp>(loc, type);
-           })
-  ;
+           });
 
   py::class_<mlir::PassManager>(m, "pass_manager")
       .def(py::init<mlir::MLIRContext *>())
@@ -261,9 +261,9 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
   %cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
   %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
   %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
-    %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+    %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
     // CHECK-NEXT: Membar 6
-    %cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+    %cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
     scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
   }
   // CHECK-NEXT: Membar 9
@@ -271,4 +271,48 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
   return
 }
 
+// Although cst2 is not an argument of scf.yield, its memory is reused by cst1.
+// So we need a barrier both before and after cst1
+// CHECK-LABEL: for_reuse
+func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+  %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  // CHECK-NEXT: Membar 2
+  %cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+  %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
+    // CHECK-NEXT: Membar 5
+    %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+    // CHECK-NEXT: Membar 7
+    %cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+    scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
+  }
+  // CHECK-NEXT: Membar 10
+  %cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
+  return
+}
+
+
+// CHECK-LABEL: for_reuse_nested
+func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+  %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  // CHECK-NEXT: Membar 2
+  %cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+  %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
+    // CHECK-NEXT: Membar 5
+    %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+    %a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
+      // CHECK-NEXT: Membar 7
+      %cst2 = tt.cat %a_shared_nested, %b_shared_nested {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+      scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
+    }
+    scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
+  }
+  // CHECK-NEXT: Membar 11
+  %cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
+  return
+}
+
 }
@@ -26,7 +26,9 @@ struct TestMembarPass
     auto op_name = SymbolTable::getSymbolName(operation).getValue().str();
     os << op_name << "\n";
     Allocation allocation(operation);
-    MembarAnalysis analysis(&allocation);
+    MembarAnalysis membarPass(&allocation);
+    membarPass.run();
+
     size_t operationId = 0;
     operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
       if (isa<gpu::BarrierOp>(op)) {