[Triton-MLIR][Backend] Fix many problems to get the pipeline working (#809)

1. Rewrite code generation of insert_slice_async.
2. Correct the wrong index passed to extract_slice in pipeline.
3. Add an epilogue in the pipeline to wait for dangling cp.asyncs.
4. Move the scf-to-cf conversion inside TritonGPUToLLVM because membar analysis
has to run before scf is lowered to cf. This is not a fundamental limitation
and could be lifted by a more general membar analysis.
5. Use an attribute to memoize the shared memory size and support
dynamic shared memory.
6. Prevent the combine pass from reordering insert_slice and extract_slice
across async_wait.

Co-authored-by: Superjomn <yanchunwei@outlook.com>
Author: Keren Zhou
Date: 2022-10-27 22:09:06 -07:00
Committed by: GitHub
Parent: 42db3538e4
Commit: 3b80801dff
10 changed files with 122 additions and 70 deletions


@@ -292,13 +292,11 @@ struct PTXCpAsyncWaitGroupInstr : public PTXCpAsyncInstrBase {
 struct PTXCpAsyncLoadInstr : public PTXCpAsyncInstrBase {
   explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
-                               triton::CacheModifier modifier,
-                               triton::EvictionPolicy policy)
+                               triton::CacheModifier modifier)
       : PTXCpAsyncInstrBase(builder) {
     o(triton::stringifyCacheModifier(modifier).str());
     o("shared");
     o("global");
-    o("L2::" + triton::stringifyEvictionPolicy(policy).str());
   }
 };
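Note: with the eviction-policy operand gone, the generated opcode is plain cp.async.ca.shared.global or cp.async.cg.shared.global, matching the updated CHECK lines in the lit tests below. As a rough, standalone sketch of how the dot-separated qualifiers compose (this is not the real PTXBuilder API; the helper name is made up):

#include <string>
#include <vector>

// Standalone sketch: each o(...) call in the real builder appends one
// qualifier; the final opcode is the dot-joined list.
std::string cpAsyncOpcode(const std::string &cacheModifier /* "ca" or "cg" */) {
  std::vector<std::string> parts = {"cp.async", cacheModifier, "shared", "global"};
  std::string opcode;
  for (const auto &p : parts)
    opcode += (opcode.empty() ? "" : ".") + p;
  return opcode; // e.g. "cp.async.cg.shared.global"
}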


@@ -42,18 +42,14 @@ void MembarAnalysis::dfsOperation(Operation *operation,
 void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
                               OpBuilder *builder) {
-  if (op->getNumResults() < 1)
-    return;
   if (isa<scf::ForOp>(op) || isa<scf::IfOp>(op) || isa<scf::YieldOp>(op) ||
       isa<triton::gpu::ExtractSliceOp>(op) ||
-      isa<triton::gpu::InsertSliceAsyncOp>(op) ||
       isa<triton::gpu::AllocTensorOp>(op)) {
     // Do not insert barriers before control flow operations and
     // alloc/extract/insert
     // alloc is an allocation op without memory write.
     // In contrast, arith.constant is an allocation op with memory write.
-    // FIXME(Keren): extract and insert are always alias for now
+    // FIXME(Keren): extract is always alias for now
     return;
   }


@@ -6,6 +6,7 @@
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
@@ -14,6 +15,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "triton/Analysis/Allocation.h"
 #include "triton/Analysis/AxisInfo.h"
+#include "triton/Analysis/Membar.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Conversion/MLIRTypes.h"
 #include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
@@ -108,7 +110,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
 #define i32_ty rewriter.getIntegerType(32)
 #define f32_ty rewriter.getF32Type()
 #define vec_ty(type, num) VectorType::get(num, type)
-#define void_ty LLVM::LLVMVoidType::get(ctx)
+#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
 #define struct_ty(...) LLVM::LLVMStructType::getLiteral(__VA_ARGS__)
 // Creator for constant
@@ -360,7 +362,7 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
   auto *ptrOpr = builder.newAddrOperand(ptr, "r");
   auto *valOpr = builder.newOperand(val, c);
   st(ptrOpr, valOpr).predicate(pred, "b");
-  return builder.launch(rewriter, loc, void_ty);
+  return builder.launch(rewriter, loc, void_ty(ctx));
 }
 struct ConvertTritonGPUOpToLLVMPatternBase {
@@ -1151,7 +1153,7 @@ struct StoreOpConversion
       llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
       argTys.insert(argTys.end(), nWords, valArgTy);
-      auto ASMReturnTy = LLVM::LLVMVoidType::get(ctx);
+      auto ASMReturnTy = void_ty(ctx);
       ptxBuilder.launch(rewriter, loc, ASMReturnTy);
     }
@@ -3869,7 +3871,7 @@ struct AsyncWaitOpConversion
     auto ctx = op.getContext();
     auto loc = op.getLoc();
-    auto voidTy = LLVM::LLVMVoidType::get(ctx);
+    auto voidTy = void_ty(ctx);
     auto ret = ptxBuilder.launch(rewriter, loc, voidTy);
     // Safe to remove the op since it doesn't have any return value.
@@ -3907,7 +3909,7 @@ struct InsertSliceAsyncOpConversion
     auto srcTy = src.getType().cast<RankedTensorType>();
     auto resTy = dst.getType().cast<RankedTensorType>();
-    auto resElemTy = resTy.getElementType();
+    auto resElemTy = getTypeConverter()->convertType(resTy.getElementType());
     auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
     auto resSharedLayout = resTy.getEncoding().cast<SharedEncodingAttr>();
     auto srcShape = srcTy.getShape();
@@ -3928,7 +3930,7 @@ struct InsertSliceAsyncOpConversion
     assert(axis == 0 && "insert_slice_async: Only axis=0 is supported for now");
     auto dstBase = createIndexAttrConstant(rewriter, loc,
                                            getTypeConverter()->getIndexType(),
-                                           product<int64_t>(resTy.getShape()));
+                                           product<int64_t>(srcTy.getShape()));
     Value offset = mul(llIndex, dstBase);
     auto dstPtrTy = LLVM::LLVMPointerType::get(
         getTypeConverter()->convertType(resTy.getElementType()), 3);
@@ -4039,40 +4041,42 @@ struct InsertSliceAsyncOpConversion
       auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth();
       // XXX(Keren): Tune CG and CA here.
+      auto byteWidth = bitWidth / 8;
       CacheModifier srcCacheModifier =
-          bitWidth == 128 ? CacheModifier::CG : CacheModifier::CA;
-      assert(bitWidth == 128 || bitWidth == 64 || bitWidth == 32);
+          byteWidth == 16 ? CacheModifier::CG : CacheModifier::CA;
+      assert(byteWidth == 16 || byteWidth == 8 || byteWidth == 4);
+      auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8;

-      for (int wordIdx = 0; wordIdx < numWords; ++wordIdx) {
-        PTXBuilder ptxBuilder;
-        auto &copyAsyncOp = *ptxBuilder.create<PTXCpAsyncLoadInstr>(
-            srcCacheModifier, op.evict());
       auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
-        auto *dstOperand =
-            ptxBuilder.newAddrOperand(tileOffset, "r", baseOffset);
-        auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[vecIdx], "l");
-        auto *copySize = ptxBuilder.newConstantOperand(bitWidth);
+      for (unsigned wordIdx = 0; wordIdx < numWords; ++wordIdx) {
+        PTXBuilder ptxBuilder;
+        auto wordElemIdx = wordIdx * numWordElems;
+        auto &copyAsyncOp =
+            *ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
+        auto *dstOperand = ptxBuilder.newAddrOperand(
+            tileOffset, "r", (wordElemIdx + baseOffset) * resByteWidth);
+        auto *srcOperand =
+            ptxBuilder.newAddrOperand(srcElems[elemIdx + wordElemIdx], "l");
+        auto *copySize = ptxBuilder.newConstantOperand(byteWidth);
         auto *srcSize = copySize;
         if (op.mask()) {
           // We don't use predicate in this case, setting src-size to 0
           // if there's any mask. cp.async will automatically fill the
           // remaining slots with 0 if cp-size > src-size.
           // XXX(Keren): Always assume other = 0 for now.
-          auto selectOp = select(maskElems[vecIdx + wordIdx * numWordElems],
-                                 i32_val(bitWidth), i32_val(0));
+          auto selectOp = select(maskElems[elemIdx + wordElemIdx],
+                                 i32_val(byteWidth), i32_val(0));
           srcSize = ptxBuilder.newOperand(selectOp, "r");
         }
         copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
-        ptxBuilder.launch(rewriter, loc, LLVM::LLVMVoidType::get(getContext()));
+        ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
       }
     }

     PTXBuilder ptxBuilder;
     ptxBuilder.create<PTXCpAsyncCommitGroupInstr>()->operator()();
-    ptxBuilder.launch(rewriter, loc, LLVM::LLVMVoidType::get(getContext()));
-    rewriter.replaceOp(op, llDst);
+    auto ret = ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
+    rewriter.replaceOp(op, ret);
     return success();
   }
 };
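Note: the copy size and destination offset are now expressed in bytes rather than bits/elements. A self-contained sketch of the arithmetic under illustrative assumptions (f32 elements, one 128-bit cp.async word), which reproduces the 0x10 sizes and the "+ 16" offset in the updated lit tests below; the variable names mirror the ones above, but the values are examples only:

#include <cassert>
#include <cstdint>

int main() {
  // Illustrative values: f32 elements, 128-bit (vectorized) cp.async word.
  int64_t elemBits = 32;                       // resElemTy bit width
  int64_t bitWidth = 128;                      // bits moved per cp.async
  int64_t byteWidth = bitWidth / 8;            // 16 -> the 0x10 in the tests
  int64_t numWordElems = bitWidth / elemBits;  // 4 elements per word
  int64_t resByteWidth = elemBits / 8;         // 4 bytes per element
  int64_t baseOffset = 0;

  // A 16-byte copy selects the .cg cache modifier; smaller copies use .ca.
  bool useCG = (byteWidth == 16);
  assert(useCG);

  // Destination offset of the second word (wordIdx == 1), in bytes:
  int64_t wordElemIdx = 1 * numWordElems;                            // 4
  int64_t dstByteOffset = (wordElemIdx + baseOffset) * resByteWidth; // 16
  assert(dstByteOffset == 16); // the "+ 16" in the updated CHECK lines
  return 0;
}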
@@ -4172,21 +4176,39 @@ public:
     int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);

-    // step 1: Convert FuncOp to LLVMFuncOp via partial conversion
-    // step 2: Allocate for shared memories
-    // step 3: Convert the rest of ops via partial conversion
-    // The reason for a seperation between 1/3 is that, step 2 is out of
+    // step 1: Allocate shared memories and insert barriers
+    // setp 2: Convert SCF to CFG
+    // step 3: Convert FuncOp to LLVMFuncOp via partial conversion
+    // step 4: Convert the rest of ops via partial conversion
+    // The reason for putting step 1 before step 2 is that the membar analysis
+    // currently only supports SCF but not CFG.
+    // The reason for a seperation between 1/4 is that, step 3 is out of
     // the scope of Dialect Conversion, thus we need to make sure the smem
-    // is not revised during the conversion of step 3.
+    // is not revised during the conversion of step 4.
+    Allocation allocation(mod);
+    MembarAnalysis membar(&allocation);
+
+    RewritePatternSet scf_patterns(context);
+    mlir::populateLoopToStdConversionPatterns(scf_patterns);
+    mlir::ConversionTarget scf_target(*context);
+    scf_target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp,
+                            scf::WhileOp, scf::ExecuteRegionOp>();
+    scf_target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+    if (failed(
+            applyPartialConversion(mod, scf_target, std::move(scf_patterns))))
+      return signalPassFailure();
+
     RewritePatternSet func_patterns(context);
     func_patterns.add<FuncOpConversion>(typeConverter, numWarps, 1 /*benefit*/);
     if (failed(
             applyPartialConversion(mod, funcTarget, std::move(func_patterns))))
       return signalPassFailure();

-    Allocation allocation(mod);
     auto axisAnalysis = runAxisAnalysis(mod);
     initSharedMemory(allocation.getSharedMemorySize(), typeConverter);
+    mod->setAttr("triton_gpu.shared",
+                 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
+                                        allocation.getSharedMemorySize()));

     // We set a higher benefit here to ensure triton's patterns runs before
     // arith patterns for some encoding not supported by the community
@@ -4229,9 +4251,11 @@ void ConvertTritonGPUToLLVM::initSharedMemory(
   OpBuilder b(mod.getBodyRegion());
   auto loc = mod.getLoc();
   auto elemTy = typeConverter.convertType(b.getIntegerType(8));
-  auto arrayTy = LLVM::LLVMArrayType::get(elemTy, size);
+  // Set array size 0 and external linkage indicates that we use dynamic shared
+  // allocation to allow a larger shared memory size for each kernel.
+  auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0);
   auto global = b.create<LLVM::GlobalOp>(
-      loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::Internal,
+      loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External,
       "global_smem", /*value=*/Attribute(),
       /*alignment=*/0, mlir::gpu::GPUDialect::getWorkgroupAddressSpace());
   SmallVector<LLVM::LLVMFuncOp> funcs;
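Note: because @global_smem is now a zero-sized, externally linked array, the shared-memory footprint is no longer baked into the generated code; it must be supplied at launch time as dynamic shared memory, using the size memoized in the triton_gpu.shared attribute (read back by the Python binding further down). Triton's actual launch path is not part of this diff; the following is only a hedged illustration using the CUDA driver API, with made-up function and variable names:

#include <cuda.h>

// Sketch: launch a Triton-generated kernel with dynamic shared memory.
// `sharedMemBytes` would come from the module's "triton_gpu.shared" attribute;
// all names here are illustrative, not Triton's real launcher.
CUresult launchWithDynamicSmem(CUfunction kernel, unsigned sharedMemBytes,
                               unsigned gridX, unsigned blockX,
                               CUstream stream, void **params) {
  // Kernels needing more than the default dynamic-shared limit must opt in.
  cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                     static_cast<int>(sharedMemBytes));
  return cuLaunchKernel(kernel, gridX, 1, 1, blockX, 1, 1,
                        /*sharedMemBytes=*/sharedMemBytes, stream, params,
                        /*extra=*/nullptr);
}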


@@ -103,16 +103,21 @@ public:
     if (!arg)
       return mlir::failure();
     // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
-    // cvt(insert_slice(x), type2) -> extract_slice(cvt(x, type2))
     auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
     if (alloc_tensor) {
       rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
           op, op->getResult(0).getType());
       return mlir::success();
     }
+    // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
     auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
     if (insert_slice) {
       auto newType = op->getResult(0).getType();
+      // Ensure that the new insert_slice op is placed in the same place as the
+      // old insert_slice op. Otherwise, the new insert_slice op may be placed
+      // after the async_wait op, which is not allowed.
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPoint(insert_slice);
       auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
           op->getLoc(), newType, insert_slice.dst());
       rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
@@ -126,6 +131,11 @@ public:
     auto extract_slice = dyn_cast<triton::gpu::ExtractSliceOp>(arg);
     if (extract_slice) {
       auto origType = extract_slice.src().getType().cast<RankedTensorType>();
+      // Ensure that the new extract_slice op is placed in the same place as the
+      // old extract_slice op. Otherwise, the new extract_slice op may be placed
+      // after the async_wait op, which is not allowed.
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPoint(extract_slice);
       auto newType = RankedTensorType::get(
           origType.getShape(), origType.getElementType(),
           op->getResult(0).getType().cast<RankedTensorType>().getEncoding());


@@ -78,6 +78,9 @@ public:
   /// emit pipelined loads (before loop body)
   void emitPrologue();

+  /// emit pipelined loads (after loop body)
+  void emitEpilogue();
+
   /// create the new ForOp (add new args & insert prefetched ops)
   scf::ForOp createNewForOp();
@@ -362,6 +365,23 @@ void LoopPipeliner::emitPrologue() {
         loadStageBuffer[loadOp][numStages - 1], loopIterIdx, /*axis*/ 0);
     loadsExtract[loadOp] = extractSlice;
   }
+  // bump up loopIterIdx, this is used for getting the correct slice for the
+  // *next* iteration
+  loopIterIdx = builder.create<arith::AddIOp>(
+      loopIterIdx.getLoc(), loopIterIdx,
+      builder.create<arith::ConstantIntOp>(loopIterIdx.getLoc(), 1, 32));
+}
+
+void LoopPipeliner::emitEpilogue() {
+  // If there's any outstanding async copies, we need to wait for them.
+  // TODO(Keren): We may want to completely avoid the async copies in the last
+  // few iterations by setting is_masked attribute to true. We don't want to use
+  // the mask operand because it's a tensor but not a scalar.
+  OpBuilder builder(forOp);
+  OpBuilder::InsertionGuard g(builder);
+  builder.setInsertionPointAfter(forOp);
+  Operation *asyncWait =
+      builder.create<triton::gpu::AsyncWaitOp>(forOp.getLoc(), 0);
 }

 scf::ForOp LoopPipeliner::createNewForOp() {
scf::ForOp LoopPipeliner::createNewForOp() { scf::ForOp LoopPipeliner::createNewForOp() {
@@ -581,6 +601,8 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
     scf::ForOp newForOp = pipeliner.createNewForOp();
+    pipeliner.emitEpilogue();
+
     // replace the original loop
     for (unsigned i = 0; i < forOp->getNumResults(); ++i)
       forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));


@@ -136,10 +136,12 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
       /*printAfterOnlyOnChange=*/true,
       /*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);

-  pm.addPass(mlir::createLowerToCFGPass());
   pm.addPass(createConvertTritonGPUToLLVMPass());
   // Conanicalize to eliminate the remaining UnrealizedConversionCastOp
   pm.addPass(mlir::createCanonicalizerPass());
-  pm.addPass(mlir::createCSEPass());
+  // Simplify the IR to improve readability.
+  pm.addPass(mlir::createSymbolDCEPass());
+  pm.addPass(mlir::createCanonicalizerPass());

   if (failed(pm.run(module))) {
     llvm::errs() << "Pass execution failed";


@@ -1257,8 +1257,8 @@ void init_triton_translation(py::module &m) {
   using ret = py::return_value_policy;
   m.def("get_shared_memory_size", [](mlir::ModuleOp module) {
-    auto pass = std::make_unique<mlir::Allocation>(module);
-    return pass->getSharedMemorySize();
+    return module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared")
+        .getInt();
   });
   m.def(


@@ -875,7 +875,7 @@ def optimize_tritongpu_ir(mod, num_stages):
     pm.enable_debug()
     # Get error in backend due to wrong conversion in expanding async-related instruction.
     # TODO[Superjomn]: Open it when fixed.
-    # pm.add_tritongpu_pipeline_pass(num_stages)
+    pm.add_tritongpu_pipeline_pass(num_stages)
     pm.add_canonicalizer_pass()
     pm.add_cse_pass()
     pm.add_coalesce_pass()


@@ -326,7 +326,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
-  // CHECK: llvm.mlir.global internal @global_smem
+  // CHECK: llvm.mlir.global external @global_smem
   // CHECK-LABEL: basic_alloc_tensor
   func @basic_alloc_tensor() {
     // CHECK: llvm.mlir.addressof @global_smem
@@ -343,7 +343,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
-  // CHECK: llvm.mlir.global internal @global_smem
+  // CHECK: llvm.mlir.global external @global_smem
   // CHECK-LABEL: basic_extract_slice
   func @basic_extract_slice() {
     // CHECK: %[[BASE0:.*]] = llvm.mlir.addressof @global_smem
@@ -382,10 +382,10 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
 #slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
 #AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
-#A = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}>
+#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: basic_insert_slice_async_v4
-  func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
+  func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}) {
     %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
     %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
     %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
@@ -404,9 +404,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     %index = arith.constant 1 : i32
     // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
-    // CHECK-SAME: cp.async.cg.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x80, 0x80
+    // CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10
     // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
-    // CHECK-SAME: cp.async.cg.shared.global.L2::evict_normal [ ${{.*}} + 8 ], [ ${{.*}} + 0 ], 0x80, 0x80
+    // CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10
    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
    // CHECK-SAME: cp.async.commit_group
    %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f32>, #AL> -> tensor<2x16x64xf32, #A>
@@ -445,13 +445,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     %index = arith.constant 1 : i32
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
     // CHECK-SAME: cp.async.commit_group
     %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x32x!tt.ptr<f32>, #AL> -> tensor<2x16x32xf32, #A>
@@ -489,21 +489,21 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     %index = arith.constant 1 : i32
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
     // CHECK-SAME: cp.async.commit_group
     %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32x!tt.ptr<f32>, #AL> -> tensor<2x32x32xf32, #A>
@@ -545,7 +545,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
-  // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<1088 x i8>
+  // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
   // CHECK-LABEL: convert_layout_blocked_blocked
   func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
     // CHECK: llvm.mlir.addressof @global_smem
@@ -593,7 +593,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 1], order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
-  // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<1280 x i8>
+  // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
   // CHECK-LABEL: convert_layout_blocked_blocked_vec
   func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) {
     // CHECK: llvm.mlir.addressof @global_smem
@@ -617,7 +617,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
-  // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<640 x i8>
+  // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
   // CHECK-LABEL: convert_layout_blocked_blocked_multi_rep
   func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
     // CHECK: llvm.mlir.addressof @global_smem
@@ -682,7 +682,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
 #mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
-  // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<2560 x i8>
+  // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
   // CHECK-LABEL: convert_layout_mma_block
   func @convert_layout_mma_blocked(%arg0: tensor<32x16xf32, #mma>) {
     // CHECK: nvvm.barrier0
@@ -703,7 +703,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
 #shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
-  // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<16384 x i8>
+  // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
   // CHECK-LABEL: convert_layout_blocked_shared
   func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) {
     // CHECK: llvm.store


@@ -22,7 +22,7 @@
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
 // CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]], %[[CONSTANT_0]]
 // CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]], %[[CONSTANT_0]]
-// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]]
+// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
 // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
 // CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
 // CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
@@ -78,7 +78,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
 // CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]], %[[CONSTANT_0]]
 // CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]], %[[CONSTANT_0]]
-// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]]
+// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
 // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
 // CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
 // CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
@@ -131,7 +131,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f
 // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
 // CHECK: triton_gpu.async_wait {num = 1 : i32}
 // CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]], %[[CONSTANT_0]]
-// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]]
+// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
 // CHECK: tt.dot {{.*}}, %[[arg_b0]], {{.*}}
 // CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
 // CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]